import operator
import threading
import time
from functools import reduce

import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical


# Name template for observer RPC workers, formatted with an integer rank
# (AgentBase.set_world looks observers up starting at rank 2).
OBSERVER_NAME = "observer{}"


class Policy(nn.Module):
    r"""
    Small feed-forward policy mapping a batch of states to per-row action
    probabilities.
    """

    def __init__(self, in_features, nlayers, out_features):
        r"""
        Inits policy class
        Args:
            in_features (int): Number of input features the model takes, i.e. the
                size of one state after flattening every dimension but the first
            nlayers (int): Number of extra ``out_features -> out_features`` linear
                layers stacked after the input layer (``nlayers + 1`` layers total)
            out_features (int): Number of features the model outputs (one score
                per action)
        """
        super().__init__()

        self.model = nn.Sequential(
            nn.Flatten(1, -1),
            nn.Linear(in_features, out_features),
            *[nn.Linear(out_features, out_features) for _ in range(nlayers)],
        )
        # Normalize over the action (last) dimension so every row of the batch is
        # its own probability distribution. Normalizing over dim 0 would mix the
        # batch rows together, and for a single state would make every
        # probability exactly 1.
        self.dim = -1

    def forward(self, x):
        r"""
        Computes action probabilities for a batch of states.
        Args:
            x (Tensor): states; dimension 0 is kept as the batch dimension and all
                remaining dimensions are flattened into ``in_features``
        Returns:
            Tensor: ``(batch, out_features)`` probabilities, each row summing to 1
        """
        action_scores = self.model(x)
        return F.softmax(action_scores, dim=self.dim)


class AgentBase:
    r"""
    Agent side of the RPC benchmark: owns a ``Policy`` and answers action requests
    from observers, either one request at a time (``select_action_non_batch``) or
    in batches of ``batch_size`` requests (``select_action_batch``), while
    recording latency and throughput metrics.
    """

    def __init__(self):
        r"""
        Inits agent class.

        Requires an initialized RPC worker (the agent records its own worker id).
        ``set_world`` must be called afterwards to build the policy and state buffers.
        """
        self.id = rpc.get_worker_info().id
        self.running_reward = 0
        self.eps = 1e-7

        # Keyed by observer worker id; set_world fills each key with an empty list.
        self.rewards = {}

        # Completed with the sampled actions of the current batch, then replaced
        # by a fresh Future for the next batch (see select_action_batch).
        self.future_actions = torch.futures.Future()
        # Guards pending_states, future_actions and the batched forward pass.
        self.lock = threading.Lock()

        self.agent_latency_start = None
        self.agent_latency_end = None
        self.agent_latency = []
        self.agent_throughput = []

    def reset_metrics(self):
        r"""
        Sets all benchmark metrics to their empty values
        """
        self.agent_latency_start = None
        self.agent_latency_end = None
        self.agent_latency = []
        self.agent_throughput = []

    def _record_latency(self, num_actions, start, end):
        r"""
        Appends one latency sample and the matching throughput sample.
        Args:
            num_actions (int): number of actions produced during the interval
            start (float): ``time.perf_counter()`` reading at the interval start
            end (float): ``time.perf_counter()`` reading at the interval end
        """
        latency = end - start
        self.agent_latency.append(latency)
        # A zero-length interval must not raise ZeroDivisionError inside an RPC
        # handler; report it as unbounded throughput instead.
        self.agent_throughput.append(
            num_actions / latency if latency > 0 else float("inf")
        )

    def set_world(self, batch_size, state_size, nlayers, out_features, batch=True):
        r"""
        Further initializes agent to be aware of rpc environment
        Args:
            batch_size (int): number of observers, which is also the size of the
                batches of observer requests processed together in batch mode
            state_size (list): List of ints dictating the dimensions of the state
            nlayers (int): Number of extra hidden layers in the model
            out_features (int): Number of out features in the model
            batch (bool): Whether to process and respond to observer requests as a batch or 1 at a time
        """
        self.batch = batch
        self.policy = Policy(reduce(operator.mul, state_size), nlayers, out_features)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)

        self.batch_size = batch_size
        # Observers are named observer2 .. observer{batch_size + 1}; the
        # select_action_* methods subtract 2 from the observer id, which assumes
        # each observer's worker id equals the number in its name.
        for rank in range(batch_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2))

            self.rewards[ob_info.id] = []

        # Batch mode appends one log-prob tensor per processed batch; non-batch
        # mode keeps a separate list per observer index.
        self.saved_log_probs = (
            [] if self.batch else {k: [] for k in range(self.batch_size)}
        )

        self.pending_states = self.batch_size
        self.state_size = state_size
        # One row per observer, written by select_action_batch.
        self.states = torch.zeros(self.batch_size, *state_size)

    @staticmethod
    @rpc.functions.async_execution
    def select_action_batch(agent_rref, observer_id, state):
        r"""
        Receives state from an observer to select action for.  Queues the observer's request
        for an action until queue size equals batch size named during Agent initiation, at which point
        actions are selected for all pending observer requests and communicated back to observers
        Args:
            agent_rref (RRef): RRef of this agent
            observer_id (int): Observer id of observer calling this function
            state (Tensor): Tensor representing current state held by observer
        Returns:
            Future: completes with this observer's action (int) once the whole
            batch has been processed
        """
        self = agent_rref.local_value()
        observer_id -= 2

        # Each observer writes only its own row, so these copies do not overlap.
        self.states[observer_id].copy_(state)

        with self.lock:
            # Attach to the current batch's future while holding the lock so the
            # registration cannot interleave with the future being swapped below.
            future_action = self.future_actions.then(
                lambda future_actions: future_actions.wait()[observer_id].item()
            )

            if self.pending_states == self.batch_size:
                self.agent_latency_start = time.perf_counter()
            self.pending_states -= 1
            if self.pending_states == 0:
                self.pending_states = self.batch_size
                probs = self.policy(self.states)
                m = Categorical(probs)
                actions = m.sample()
                self.saved_log_probs.append(m.log_prob(actions).t())
                # Swap in a fresh future before completing the old one so that
                # requests for the next batch attach to the new future.
                future_actions = self.future_actions
                self.future_actions = torch.futures.Future()
                future_actions.set_result(actions)

                self.agent_latency_end = time.perf_counter()
                self._record_latency(
                    self.batch_size, self.agent_latency_start, self.agent_latency_end
                )

        return future_action

    @staticmethod
    def select_action_non_batch(agent_rref, observer_id, state):
        r"""
        Select actions based on observer state and communicates back to observer
        Args:
            agent_rref (RRef): RRef of this agent
            observer_id (int): Observer id of observer calling this function
            state (Tensor): Tensor representing current state held by observer
        Returns:
            int: the sampled action
        """
        self = agent_rref.local_value()
        observer_id -= 2
        agent_latency_start = time.perf_counter()

        # Add a leading batch dimension of size 1 for the policy.
        state = state.float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[observer_id].append(m.log_prob(action))

        self._record_latency(1, agent_latency_start, time.perf_counter())

        return action.item()

    def finish_episode(self, rets):
        r"""
        Finishes the episode and reports the benchmark metrics.
        Args:
            rets (list): List containing rewards generated by select action calls
                during the episode run; not used by this base implementation
        Returns:
            tuple: ``(agent_latency, agent_throughput)`` lists of recorded metrics
        """
        return self.agent_latency, self.agent_throughput
