import functools
import time
from abc import ABC, abstractmethod

from metrics.MetricsLogger import MetricsLogger

import torch


class TrainerBase(ABC):
    r"""
    Abstract base class for benchmark trainers.

    Owns a ``MetricsLogger`` for the worker rank and exposes helpers that record
    start/end events for batch-level, forward, backward and per-method metrics.
    Subclasses implement :meth:`train`.
    """

    # Group id ("type") / description ("name") pairs used by the helpers below.
    BATCH_LEVEL_METRIC = "batch_level_metric"
    BATCH_ALL = "batch_all"
    FORWARD_METRIC = "forward_metric"
    FORWARD_PASS = "forward_pass"
    BACKWARD_METRIC = "backward_metric"
    BACKWARD = "backward"

    def __init__(self, rank):
        r"""
        Inits TrainerBase class.
        Args:
            rank (int): worker rank
        """
        self.__metrics_logger = MetricsLogger(rank)

    @abstractmethod
    def train(self, *args, **kwargs):
        r"""
        A method to be implemented by child class that will train a neural network.
        Subclasses may define whatever arguments they need (e.g. a model and data).
        """
        return

    def record_start(self, type, key, name, cuda=True):
        r"""
        A method that records the start event for a metric.
        Args:
            type (str): group id for metric
            key (str): unique id for metric within a group
            name (str): description of the metric
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(type, key, name, cuda)

    def record_end(self, type, key):
        r"""
        A method that records the end event for a metric.
        Args:
            type (str): group id for metric
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(type, key)

    def record_batch_start(self, key, cuda=True):
        r"""
        A helper method that records a batch metric for the
        given key. A user should call this at the start of an
        iteration step during training.
        Args:
            key (str): unique id for metric within a group
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(
            self.BATCH_LEVEL_METRIC, key, self.BATCH_ALL, cuda
        )

    def record_batch_end(self, key):
        r"""
        A helper method that records a batch metric for the
        given key. A user should call this at the end of an
        iteration step during training.
        Args:
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(self.BATCH_LEVEL_METRIC, key)

    def record_forward_start(self, key, cuda=True):
        r"""
        A helper method that records a forward metric
        for the given key. A user should call this before
        their neural network forward.
        Args:
            key (str): unique id for metric within a group
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(
            self.FORWARD_METRIC, key, self.FORWARD_PASS, cuda
        )

    def record_forward_end(self, key):
        r"""
        A helper method that records a forward metric
        for the given key. A user should call this after their
        neural network forward.
        Args:
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(self.FORWARD_METRIC, key)

    def record_backward_start(self, key, cuda=True):
        r"""
        A helper method that records a backward metric
        for the given key. A user should call this before
        their .backward() call.
        Args:
            key (str): unique id for metric within a group
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(
            self.BACKWARD_METRIC, key, self.BACKWARD, cuda
        )

    def record_backward_end(self, key):
        r"""
        A helper method that records a backward metric
        for the given key. A user should call this after
        .backward().
        Args:
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(self.BACKWARD_METRIC, key)

    @staticmethod
    def methodmetric(name, type="method_metric", cuda=True):
        r"""
        A decorator that records a metric for the decorated method.

        Each call records a start event before the method runs and an end
        event after it returns or raises. Positional and keyword arguments
        are forwarded unchanged to the decorated method.
        Args:
            name (str): description of the metric
            type (str): group id for metric
            cuda (bool): indicator to determine if this is a CUDA metric
        """

        def decorator(function):
            @functools.wraps(function)
            def wrapper(self, *args, **kwargs):
                # The wall-clock timestamp serves as the per-call key. It is
                # presumably unique enough to separate repeated calls.
                key = time.time()
                # This code sits lexically inside the TrainerBase class body,
                # so ``self.__metrics_logger`` is name-mangled to
                # ``self._TrainerBase__metrics_logger``, the attribute that
                # ``__init__`` sets.
                self.__metrics_logger.record_start(type, key, name, cuda)
                try:
                    return function(self, *args, **kwargs)
                finally:
                    # Always close the metric, even if the method raises, so
                    # the logger never holds an unmatched start event.
                    self.__metrics_logger.record_end(type, key)

            return wrapper

        return decorator

    def get_metrics(self):
        r"""
        A method that returns metrics captured by the __metrics_logger.
        """
        return self.__metrics_logger.get_processed_metrics()

    def clear_metrics(self):
        r"""
        A method that clears __metrics_logger recorded metrics.
        """
        return self.__metrics_logger.clear_metrics()


class DdpTrainer(TrainerBase):
    r"""
    Trainer that runs a DistributedDataParallel training loop, delegating
    model construction, data preprocessing, loss creation and the per-batch
    step to user-supplied callables.
    """

    def __init__(
        self,
        process_group,
        use_cuda_rpc,
        server_rref,
        backend,
        epochs,
        preprocess_data,
        create_criterion,
        create_ddp_model,
        hook_state_class,
        hook,
        iteration_step,
    ):
        r"""
        A trainer that implements a DDP training algorithm using a simple hook that performs allreduce
        using the process_group implementation.
        Args:
            process_group (ProcessGroup): distributed process group
            use_cuda_rpc (bool): indicator for CUDA RPC
            server_rref (RRef): remote reference to the server
            backend (str): distributed communication backend
            epochs (int): epoch count for training
            preprocess_data (function): preprocesses data passed
                to the trainer before starting training
            create_criterion (function): creates a criterion to calculate loss
            create_ddp_model (function): creates a ddp model for the trainer and
                returns it together with the hook state
            hook_state_class (class): class that will be used to keep track of state
                during training.
            hook (function): ddp communication hook
            iteration_step (function): will perform 1 step of training
        """
        worker_rank = process_group.rank()
        super().__init__(worker_rank)

        # Distributed / communication configuration.
        self.process_group = process_group
        self.use_cuda_rpc = use_cuda_rpc
        self.server_rref = server_rref
        self.backend = backend

        # Training schedule and user-provided callbacks.
        self.epochs = epochs
        self.preprocess_data = preprocess_data
        self.create_criterion = create_criterion
        self.create_ddp_model = create_ddp_model
        self.hook_state_class = hook_state_class
        self.hook = hook
        self.iteration_step = iteration_step

        # Position of this worker within the process group.
        self.rank = worker_rank
        self.trainer_count = process_group.size()

    def epoch_key(self, epoch, index):
        r"""
        Build the metric key for one training iteration.
        Args:
            epoch (int): epoch index
            index (int): iteration index
        Returns:
            str: the epoch and iteration index joined by a comma, e.g. "2,7"
        """
        return "{},{}".format(epoch, index)

    def train(self, model, data):
        r"""
        Run the training loop for ``self.epochs`` epochs on this rank's GPU.
        Args:
            model (nn.Module): neural network model
            data (list): training examples
        """
        device_model = model.cuda(self.rank)
        batches = self.preprocess_data(self.rank, data)
        criterion = self.create_criterion(self.rank)
        ddp_model, hook_state = self.create_ddp_model(
            self,
            self.rank,
            device_model,
            self.process_group,
            self.hook_state_class,
            self.hook,
        )
        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1e-4)

        for epoch in range(self.epochs):
            # Only rank 0 reports progress, and only every fifth epoch.
            if self.rank == 0 and epoch % 5 == 0:
                print(f"train epoch={epoch}")
            for batch_index, batch in enumerate(batches):
                self.iteration_step(
                    self,
                    ddp_model,
                    criterion,
                    optimizer,
                    hook_state,
                    epoch,
                    batch_index,
                    batch,
                )
        # Wait for all outstanding CUDA work on this rank's device to finish.
        torch.cuda.synchronize(self.rank)
