#pragma once

#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/distributed/autograd/context/context.h>

namespace torch {
namespace distributed {
namespace autograd {

// Forward declaration. DistEngine names this guard as a friend (below) so the
// guard can call the private DistEngine::cleanupBackwardPass().
class BackwardPassCleanupGuard;

// This is a singleton class responsible for running distributed backward
// passes. This engine relies heavily on the vanilla autograd engine and tries
// to re-use it as much as possible. This class is mostly responsible for the
// distributed aspects of autograd and tries to hook into the autograd engine
// where convenient.

// Unlike the vanilla autograd engine, the distributed autograd engine
// accumulates the gradients in the appropriate DistAutogradContext. This avoids
// multiple trainer nodes stomping on each other's gradients.
//
// Copy and move operations are deleted and the constructor/destructor are
// private, so the only way to obtain an instance is getInstance().
class TORCH_API DistEngine {
 public:
  // Retrieve the singleton instance.
  static DistEngine& getInstance();

  // Given a list of root variables, start the distributed backward pass from
  // these variables and accumulate all the gradients in the autograd context
  // identified by `context_id` on each node. This method is used to kick off
  // distributed autograd on a single node.
  //
  // `retainGraph` presumably has the same meaning as `retain_graph` in the
  // vanilla engine (keep the graph alive after the pass) -- confirm in .cpp.
  void execute(
      int64_t context_id,
      const torch::autograd::variable_list& roots,
      bool retainGraph);

  // Given a send function to execute in the autograd engine, ensures we compute
  // dependencies once for this node and enqueues the send function for
  // execution in the engine.
  // This method is used to kick off the autograd computation on a node when it
  // receives gradients from the corresponding 'recv' method on another node.
  // The gradients are accumulated in the provided autograd context.
  //
  // Returns a Future; presumably it completes when the local work started by
  // this send function is done -- confirm in the implementation.
  c10::intrusive_ptr<c10::ivalue::Future> executeSendFunctionAsync(
      const ContextPtr& autogradContext,
      const std::shared_ptr<SendRpcBackward>& sendFunction,
      bool retainGraph);

  // Number of backward passes currently running for the Distributed Engine.
  size_t numBackwardPasses() const;

  // Returns key-value pairs consisting of useful debugging information related
  // to distributed autograd.
  std::unordered_map<std::string, int> getDebugInfo() const;

  DistEngine(const DistEngine&) = delete;
  DistEngine& operator=(const DistEngine&) = delete;
  DistEngine(DistEngine&&) = delete;
  DistEngine& operator=(DistEngine&&) = delete;

 private:
  // Make sure this is a singleton.
  DistEngine();
  ~DistEngine();

  // Validates the input roots for the backward computation and retrieves the
  // appropriate root edges and corresponding gradients. Populates `rootEdges`
  // with the gradient edges of the roots and `grads` with the gradient for
  // each of those edges (output parameters).
  void validateRootsAndRetrieveEdges(
      const torch::autograd::variable_list& roots,
      torch::autograd::edge_list& rootEdges,
      torch::autograd::variable_list& grads);

  // Given the autograd context, root edges and grads, we compute dependencies
  // for the local node and fill out the GraphTask (presumably the one owned by
  // `context` -- confirm) and `graphRoot` with the information the local
  // autograd engine needs.
  // We also determine all leaf nodes (functions) in the graph and accumulate
  // their edges in `outputEdges` (output parameter).
  void computeDependencies(
      const ContextPtr& context,
      const torch::autograd::edge_list& rootEdges,
      const torch::autograd::variable_list& grads,
      const std::shared_ptr<torch::autograd::Node>& graphRoot,
      torch::autograd::edge_list& outputEdges,
      bool retainGraph);

  // Given a pre-populated GraphTask and a root node, compute the backward pass
  // for the autograd graph until the graph task ready queue is empty.
  //
  // This method assumes that the appropriate GraphTask has already been
  // initialized. It constructs a local ready queue to traverse the GraphTask
  // instead of using the GraphTask's embedded cpu_ready_queue, because the
  // dist engine might run the same GraphTask from different SendFunctions
  // concurrently in different threads. The method only marks the GraphTask as
  // completed when it needs to, so it might not mark it as completed on every
  // call: the dist engine keeps the GraphTask alive until it has received all
  // gradients.
  //
  // When `incrementOutstandingTasks=false`, the function does not increment
  // 'outstanding_tasks_' in the appropriate GraphTask. It is assumed we've
  // already done this beforehand for this task (to ensure we don't pre-mark
  // this graph_task as completed). This is useful in the distributed autograd
  // case where we need to increment 'outstanding_tasks_' first to tell the
  // local autograd engine the graph task is not completed until it receives
  // the signals from other workers over the network.
  //
  // XXX: calling this function assumes that NO GPU NodeTasks will be executed
  // for the graph_task; the caller of this function needs to ensure this,
  // otherwise the behavior is undefined. A correct way to fix this is to
  // re-design the autograd engine so that the GPU worker thread behaves the
  // same as the CPU caller thread, records the operation/thread for the
  // device, and reuses it in backward.
  // TODO: 1. Add assert in the dist engine to ensure no GPU NodeTasks during
  // backward
  //       2. properly setup the thread local ready queue to enable reentrant
  //       backwards
  void execute_graph_task_until_ready_queue_empty(
      torch::autograd::NodeTask&& node_task,
      bool incrementOutstandingTasks = true);

  // Run the local autograd engine starting from `graphRoot` and accumulate the
  // gradients that reach `outputEdges` in the provided autograd context.
  // `incrementOutstandingTasks` presumably has the same meaning as in
  // execute_graph_task_until_ready_queue_empty() -- confirm in the .cpp.
  c10::intrusive_ptr<c10::ivalue::Future> runEngineAndAccumulateGradients(
      const ContextPtr& autogradContext,
      const std::shared_ptr<torch::autograd::Node>& graphRoot,
      const torch::autograd::edge_list& outputEdges,
      bool incrementOutstandingTasks = true);

  // Run after the backward pass is done to appropriately clean up structures.
  // Invoked by BackwardPassCleanupGuard's destructor.
  void cleanupBackwardPass(const ContextPtr& autogradContext);

  // Global thread to execute CPU continuations.
  void globalCpuThread(
      const std::shared_ptr<torch::autograd::ReadyQueue>& ready_queue);

  // Set of autograd context_ids, which we have already initialized for
  // distributed autograd on this node (e.g.: already computed dependencies)
  std::unordered_set<int64_t> initializedContextIds_;

  // Presumably guards initializedContextIds_; declared mutable so it can be
  // locked from const member functions -- confirm in the .cpp.
  mutable std::mutex initializedContextIdsLock_;

  // Reference to local autograd engine.
  torch::autograd::Engine& engine_;

  // Ready queue used by the CPU thread in distributed engine.
  // See Note [GPU to CPU continuations]
  std::shared_ptr<torch::autograd::ReadyQueue> global_cpu_ready_queue_;

  // See Note [GPU to CPU continuations]
  std::thread global_cpu_thread_;

  friend class BackwardPassCleanupGuard;
};

// Guard to clean up resources once the backward pass is done.
// Guard to clean up resources once the backward pass is done: its destructor
// calls DistEngine::cleanupBackwardPass() with the context it was built with.
//
// The guard is non-copyable and non-movable. Because the destructor is
// user-declared, no implicit move constructor exists, so without these
// deletions any "move" would silently copy. Every copy would then run
// cleanupBackwardPass() again for the same context.
//
// NOTE: destructors are implicitly noexcept, so an exception escaping
// cleanupBackwardPass() here results in std::terminate.
class BackwardPassCleanupGuard {
 public:
  explicit BackwardPassCleanupGuard(ContextPtr autogradContext)
      : autogradContext_(std::move(autogradContext)) {}

  BackwardPassCleanupGuard(const BackwardPassCleanupGuard&) = delete;
  BackwardPassCleanupGuard& operator=(const BackwardPassCleanupGuard&) = delete;
  BackwardPassCleanupGuard(BackwardPassCleanupGuard&&) = delete;
  BackwardPassCleanupGuard& operator=(BackwardPassCleanupGuard&&) = delete;

  ~BackwardPassCleanupGuard() {
    DistEngine::getInstance().cleanupBackwardPass(autogradContext_);
  }

 private:
  // Context whose backward-pass state is cleaned up on destruction.
  ContextPtr autogradContext_;
};

} // namespace autograd
} // namespace distributed
} // namespace torch
