#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/library.h>

namespace c10d {
namespace {

// Declares the `c10d` operator library: three custom classes plus the schema
// of every collective / point-to-point op. Only schemas are declared here; the
// per-device kernels are attached further down in this file with
// TORCH_LIBRARY_IMPL.
TORCH_LIBRARY(c10d, m) {
  // The following ProcessGroup, Work, and ReduceOp definitions are more like
  // declarations. They don't expose the details of these classes into
  // TorchScript.
  // The classes are registered before the m.def() calls because the schemas
  // below refer to them by their `__torch__.torch.classes.c10d.*` names.
  m.class_<ProcessGroup>("ProcessGroup").def(torch::init<int64_t, int64_t>());
  m.class_<Work>("Work")
      .def(torch::init<>())
      .def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
  m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
  // Op schemas. Ops whose schema returns a tuple hand back their output
  // tensors together with the Work handle; the remaining ops return only the
  // Work handle (monitored_barrier_ returns nothing).
  m.def(
      "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
  m.def(
      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
  m.def(
      "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
  m.def(
      "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work");
  // barrier and monitored_barrier_ take a Tensor argument that their kernels
  // in this file leave unnamed and unused.
  // NOTE(review): presumably it exists only so the dispatcher has a tensor to
  // pick a device key from -- confirm against the callers.
  m.def(
      "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()");
  m.def(
      "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int dst, int tag) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int src, int tag) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int tag) -> __torch__.torch.classes.c10d.Work");
}
} // namespace

namespace ops {

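// Below are ProcessGroup's corresponding ops for each backend. Ops are
// routed through the dispatcher so that they reach the appropriate backend.
// Currently a no-op as the process group does not have a list of backends.
// NOTE(review): the sentence above looks stale -- every kernel below already
// asks the process group for a per-device backend via getBackend(). Confirm
// and drop it if so.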

namespace {

// Kernel for the `send` op, generated once per device type. It copies the
// tensor list into a vector and forwards it to the backend that
// `process_group` returns for that device; the 64-bit rank and tag are
// narrowed to int on the way. Returns the backend's Work handle.
#define IMPL_SEND(DEV)                                                  \
  c10::intrusive_ptr<Work> send##DEV(                                   \
      at::TensorList tensors,                                           \
      const c10::intrusive_ptr<ProcessGroup>& process_group,            \
      int64_t dstRank,                                                  \
      int64_t tag) {                                                    \
    std::vector<at::Tensor> payload = tensors.vec();                    \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);     \
    return backend->send(                                               \
        payload, static_cast<int>(dstRank), static_cast<int>(tag));     \
  }

IMPL_SEND(CPU)
IMPL_SEND(CUDA)
IMPL_SEND(PrivateUse1)

// Kernel for the `recv_` op, generated once per device type. It copies the
// tensor list into a vector and asks the device's backend to receive into it
// from `srcRank`; the 64-bit rank and tag are narrowed to int. Returns the
// backend's Work handle.
#define IMPL_RECV(DEV)                                                  \
  c10::intrusive_ptr<Work> recv_##DEV(                                  \
      at::TensorList tensors,                                           \
      const c10::intrusive_ptr<ProcessGroup>& process_group,            \
      int64_t srcRank,                                                  \
      int64_t tag) {                                                    \
    std::vector<at::Tensor> buffers = tensors.vec();                    \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);     \
    return backend->recv(                                               \
        buffers, static_cast<int>(srcRank), static_cast<int>(tag));     \
  }

IMPL_RECV(CPU)
IMPL_RECV(CUDA)
IMPL_RECV(PrivateUse1)

// Kernel for the `recv_any_source_` op, generated once per device type. Same
// shape as the `recv_` kernel but without a source rank: only the tag
// (narrowed to int) is passed to the backend's recvAnysource().
#define IMPL_RECV_ANY_SOURCE(DEV)                                       \
  c10::intrusive_ptr<Work> recv_any_source_##DEV(                       \
      at::TensorList tensors,                                           \
      const c10::intrusive_ptr<ProcessGroup>& process_group,            \
      int64_t tag) {                                                    \
    std::vector<at::Tensor> buffers = tensors.vec();                    \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);     \
    return backend->recvAnysource(buffers, static_cast<int>(tag));      \
  }

IMPL_RECV_ANY_SOURCE(CPU)
IMPL_RECV_ANY_SOURCE(CUDA)
IMPL_RECV_ANY_SOURCE(PrivateUse1)

// Kernel for the `reduce_` op, generated once per device type. Packs the
// reduction op, root rank, root tensor index and timeout (interpreted as
// milliseconds) into a ReduceOptions and forwards the tensors to the device's
// backend. Returns the backend's Work handle.
#define IMPL_REDUCE(DEV)                                                \
  c10::intrusive_ptr<Work> reduce_##DEV(                                \
      at::TensorList tensors,                                           \
      const c10::intrusive_ptr<ProcessGroup>& process_group,            \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                    \
      int64_t root_rank,                                                \
      int64_t root_tensor,                                              \
      int64_t timeout) {                                                \
    std::vector<at::Tensor> buffers = tensors.vec();                    \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);     \
    auto opts = ReduceOptions{                                          \
        *reduce_op,                                                     \
        root_rank,                                                      \
        root_tensor,                                                    \
        std::chrono::milliseconds(timeout)};                            \
    return backend->reduce(buffers, opts);                              \
  }

IMPL_REDUCE(CPU)
IMPL_REDUCE(CUDA)
IMPL_REDUCE(PrivateUse1)

// Kernel for the `broadcast_` op, generated once per device type. Builds a
// BroadcastOptions from the root rank, root tensor index, timeout
// (milliseconds) and asyncOp flag, runs the backend broadcast on a vector
// copy of the tensor list, and returns that vector together with the Work
// handle.
#define IMPL_BROADCAST(DEV)                                             \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>         \
      broadcast_##DEV(                                                  \
          at::TensorList tensors,                                       \
          const c10::intrusive_ptr<ProcessGroup>& process_group,        \
          int64_t root_rank,                                            \
          int64_t root_tensor,                                          \
          bool asyncOp,                                                 \
          int64_t timeout) {                                            \
    std::vector<at::Tensor> buffers = tensors.vec();                    \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);     \
    auto opts = BroadcastOptions{                                       \
        root_rank,                                                      \
        root_tensor,                                                    \
        std::chrono::milliseconds(timeout),                             \
        asyncOp};                                                       \
    auto handle = backend->broadcast(buffers, opts);                    \
    return std::make_tuple(std::move(buffers), handle);                 \
  }

IMPL_BROADCAST(CPU)
IMPL_BROADCAST(CUDA)
IMPL_BROADCAST(PrivateUse1)

// Kernel for the `allreduce_` op, generated once per device type.
// The input tensors are handed back as the first tuple element so that the
// in-place allreduce looks like a functional API; this lets make_fx build the
// dependencies in the graph correctly later on.
// `sparse_indices` is part of the signature but is not forwarded to the
// backend by this kernel.
#define IMPL_ALLREDUCE(DEV)                                             \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>         \
      allreduce_##DEV(                                                  \
          at::TensorList tensors,                                       \
          const c10::intrusive_ptr<ProcessGroup>& process_group,        \
          const c10::intrusive_ptr<ReduceOp>& reduce_op,                \
          const std::optional<at::Tensor>& sparse_indices,              \
          int64_t timeout) {                                            \
    std::vector<at::Tensor> buffers = tensors.vec();                    \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);     \
    auto opts = AllreduceOptions{                                       \
        *reduce_op, std::chrono::milliseconds(timeout)};                \
    auto handle = backend->allreduce(buffers, opts);                    \
    return std::make_tuple(std::move(buffers), handle);                 \
  }

IMPL_ALLREDUCE(CPU)
IMPL_ALLREDUCE(CUDA)
IMPL_ALLREDUCE(PrivateUse1)

// Kernel for the `allreduce_coalesced_` op, generated once per device type.
// The options object is value-initialized and then filled field by field with
// the reduction op and the timeout (milliseconds) before the tensors are
// forwarded to the device's backend. Returns the backend's Work handle.
#define IMPL_ALLREDUCE_COALESCED(DEV)                                   \
  c10::intrusive_ptr<Work> allreduce_coalesced_##DEV(                   \
      at::TensorList tensors,                                           \
      const c10::intrusive_ptr<ProcessGroup>& process_group,            \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                    \
      int64_t timeout) {                                                \
    std::vector<at::Tensor> buffers = tensors.vec();                    \
    AllreduceCoalescedOptions coalesced_opts{};                         \
    coalesced_opts.reduceOp = *reduce_op;                               \
    coalesced_opts.timeout = std::chrono::milliseconds(timeout);        \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);     \
    return backend->allreduce_coalesced(buffers, coalesced_opts);       \
  }

IMPL_ALLREDUCE_COALESCED(CPU)
IMPL_ALLREDUCE_COALESCED(CUDA)
IMPL_ALLREDUCE_COALESCED(PrivateUse1)

// Kernel for the `allgather_` op, generated once per device type.
// The returned tuple carries a copy of the output tensor lists (the tensor
// handles, not their storage) so that the op can be used in a functional
// manner. The backend call takes the outputs by non-const reference, hence
// the const_cast on the const-qualified parameter.
#define IMPL_ALLGATHER(DEV)                                                   \
  std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>  \
      allgather_##DEV(                                                        \
          const std::vector<std::vector<at::Tensor>>& output_tensors,         \
          at::TensorList input_tensors,                                       \
          const c10::intrusive_ptr<ProcessGroup>& process_group,              \
          int64_t timeout) {                                                  \
    std::vector<at::Tensor> inputs = input_tensors.vec();                     \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);           \
    auto& mutable_outputs =                                                   \
        const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors);    \
    auto handle = backend->allgather(                                         \
        mutable_outputs,                                                      \
        inputs,                                                               \
        AllgatherOptions{std::chrono::milliseconds(timeout)});                \
    return std::make_tuple(output_tensors, handle);                           \
  }

// NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast)
IMPL_ALLGATHER(CPU)
IMPL_ALLGATHER(CUDA)
IMPL_ALLGATHER(PrivateUse1)

// Kernel for the `_allgather_base_` op, generated once per device type.
// Single-tensor variant: forwards the output and input tensors to the
// backend with the timeout (milliseconds) and asyncOp flag, then returns the
// output tensor together with the Work handle.
#define IMPL__ALLGATHER_BASE(DEV)                                             \
  std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _allgather_base_##DEV(     \
      at::Tensor& output_tensor,                                              \
      at::Tensor& input_tensor,                                               \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                  \
      bool asyncOp,                                                           \
      int64_t timeout) {                                                      \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);           \
    auto opts =                                                               \
        AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp};        \
    auto handle =                                                             \
        backend->_allgather_base(output_tensor, input_tensor, opts);          \
    return std::make_tuple(output_tensor, handle);                            \
  }

IMPL__ALLGATHER_BASE(CPU)
IMPL__ALLGATHER_BASE(CUDA)
IMPL__ALLGATHER_BASE(PrivateUse1)

// Kernel for the `allgather_coalesced_` op, generated once per device type.
// Forwards the output lists and a vector copy of the input list to the
// device's backend with default options. The backend call takes the outputs
// by non-const reference, hence the const_cast. Returns the Work handle.
#define IMPL_ALLGATHER_COALESCED(DEV)                                         \
  c10::intrusive_ptr<Work> allgather_coalesced_##DEV(                         \
      const std::vector<std::vector<at::Tensor>>& output_lists,               \
      const at::TensorList& input_list,                                       \
      const c10::intrusive_ptr<ProcessGroup>& process_group) {                \
    std::vector<at::Tensor> inputs = input_list.vec();                        \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);           \
    auto& mutable_outputs =                                                   \
        const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists);      \
    return backend->allgather_coalesced(mutable_outputs, inputs);             \
  }

IMPL_ALLGATHER_COALESCED(CPU)
IMPL_ALLGATHER_COALESCED(CUDA)
IMPL_ALLGATHER_COALESCED(PrivateUse1)

// Kernel for the `allgather_into_tensor_coalesced_` op, generated once per
// device type. Copies both tensor lists into vectors and forwards them to the
// device's backend with default options. Returns the Work handle.
#define IMPL_ALLGATHER_INTO_TENSOR_COALESCED(DEV)                             \
  c10::intrusive_ptr<c10d::Work> allgather_into_tensor_coalesced_##DEV(       \
      at::TensorList outputs,                                                 \
      at::TensorList inputs,                                                  \
      const c10::intrusive_ptr<ProcessGroup>& process_group) {                \
    std::vector<at::Tensor> dst = outputs.vec();                              \
    std::vector<at::Tensor> src = inputs.vec();                               \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);           \
    return backend->allgather_into_tensor_coalesced(dst, src);                \
  }

IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU)
IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CUDA)
IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)

// Kernel for the `reduce_scatter_` op, generated once per device type.
// Forwards a vector copy of the output tensors and the input tensor lists to
// the device's backend, with the reduction op and timeout (milliseconds)
// packed into a ReduceScatterOptions. The backend call takes the inputs by
// non-const reference, hence the const_cast.
// Returns the output tensor vector together with the Work handle. The local
// vector is not used again after the backend call, so it is moved into the
// returned tuple rather than copied, matching the broadcast / allreduce /
// scatter / alltoall kernels in this file.
#define IMPL_REDUCE_SCATTER(DEV)                                              \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>               \
      reduce_scatter_##DEV(                                                   \
          const at::TensorList& output_tensors,                               \
          const std::vector<std::vector<at::Tensor>>& input_tensors,          \
          const c10::intrusive_ptr<ProcessGroup>& process_group,              \
          const c10::intrusive_ptr<ReduceOp>& reduce_op,                      \
          int64_t timeout) {                                                  \
    auto output_tensors_vec = output_tensors.vec();                           \
    auto work =                                                               \
        process_group->getBackend(c10::DeviceType::DEV) -> reduce_scatter(    \
            output_tensors_vec,                                               \
            const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors), \
            ReduceScatterOptions{                                             \
                *reduce_op.get(), std::chrono::milliseconds(timeout)});       \
    return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(     \
        std::move(output_tensors_vec), work);                                 \
  }

IMPL_REDUCE_SCATTER(CPU)
IMPL_REDUCE_SCATTER(CUDA)
IMPL_REDUCE_SCATTER(PrivateUse1)

#define IMPL__REDUCE_SCATTER_BASE(DEV)                                         \
  std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _reduce_scatter_base_##DEV( \
      at::Tensor& output_tensor,                                               \
      at::Tensor& input_tensor,                                                \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                           \
      bool asyncOp,                                                            \
      int64_t timeout) {                                                       \
    auto work = process_group->getBackend(c10::DeviceType::DEV)                \
                    -> _reduce_scatter_base(                                   \
                        output_tensor,                                         \
                        input_tensor,                                          \
                        ReduceScatterOptions{                                  \
                            *reduce_op.get(),                                  \
                            std::chrono::milliseconds(timeout),                \
                            asyncOp});                                         \
    return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(                   \
        output_tensor, work);                                                  \
  }

IMPL__REDUCE_SCATTER_BASE(CPU)
IMPL__REDUCE_SCATTER_BASE(CUDA)
IMPL__REDUCE_SCATTER_BASE(PrivateUse1)

// Kernel for the `reduce_scatter_tensor_coalesced_` op, generated once per
// device type. Copies both tensor lists into vectors and forwards them to the
// device's backend with the reduction op and timeout (milliseconds). Returns
// the Work handle.
#define IMPL_REDUCE_SCATTER_TENSOR_COALESCED(DEV)                             \
  c10::intrusive_ptr<c10d::Work> reduce_scatter_tensor_coalesced_##DEV(       \
      at::TensorList outputs,                                                 \
      at::TensorList inputs,                                                  \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                  \
      const c10::intrusive_ptr<ReduceOp>& reduce_op,                          \
      int64_t timeout) {                                                      \
    std::vector<at::Tensor> dst = outputs.vec();                              \
    std::vector<at::Tensor> src = inputs.vec();                               \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);           \
    auto opts = ReduceScatterOptions{                                         \
        *reduce_op, std::chrono::milliseconds(timeout)};                      \
    return backend->reduce_scatter_tensor_coalesced(dst, src, opts);          \
  }

IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU)
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CUDA)
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1)

// Kernel for the `gather_` op, generated once per device type. Forwards the
// output lists and a vector copy of the input tensors to the device's
// backend, with the root rank and timeout (milliseconds) packed into a
// GatherOptions. The backend call takes the outputs by non-const reference,
// hence the const_cast. Returns the Work handle.
#define IMPL_GATHER(DEV)                                                      \
  c10::intrusive_ptr<Work> gather_##DEV(                                      \
      const std::vector<std::vector<at::Tensor>>& output_tensors,             \
      const at::TensorList& input_tensors,                                    \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                  \
      int64_t root_rank,                                                      \
      int64_t timeout) {                                                      \
    std::vector<at::Tensor> inputs = input_tensors.vec();                     \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);           \
    auto& mutable_outputs =                                                   \
        const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors);    \
    auto opts =                                                               \
        GatherOptions{root_rank, std::chrono::milliseconds(timeout)};         \
    return backend->gather(mutable_outputs, inputs, opts);                    \
  }

IMPL_GATHER(CPU)
IMPL_GATHER(CUDA)
IMPL_GATHER(PrivateUse1)

// Kernel for the `scatter_` op, generated once per device type. Forwards a
// vector copy of the output tensors and the input tensor lists to the
// device's backend, with the root rank, timeout (milliseconds) and asyncOp
// flag packed into a ScatterOptions. The backend call takes the inputs by
// non-const reference, hence the const_cast. Returns the output vector
// together with the Work handle.
#define IMPL_SCATTER(DEV)                                                      \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_##DEV( \
      const at::TensorList& output_tensors,                                    \
      const std::vector<std::vector<at::Tensor>>& input_tensors,               \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                   \
      int64_t root_rank,                                                       \
      bool asyncOp,                                                            \
      int64_t timeout) {                                                       \
    std::vector<at::Tensor> outputs = output_tensors.vec();                    \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);            \
    auto& mutable_inputs =                                                     \
        const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors);      \
    auto opts = ScatterOptions{                                                \
        root_rank, std::chrono::milliseconds(timeout), asyncOp};               \
    auto handle = backend->scatter(outputs, mutable_inputs, opts);             \
    return std::make_tuple(std::move(outputs), handle);                        \
  }

IMPL_SCATTER(CPU)
IMPL_SCATTER(CUDA)
IMPL_SCATTER(PrivateUse1)

// Kernel for the `alltoall_` op, generated once per device type. Copies both
// tensor lists into vectors, runs the backend alltoall with the timeout
// (milliseconds), and returns the output vector together with the Work
// handle.
#define IMPL_ALLTOALL(DEV)                                              \
  std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>         \
      alltoall_##DEV(                                                   \
          const at::TensorList& output_tensors,                         \
          const at::TensorList& input_tensors,                          \
          const c10::intrusive_ptr<ProcessGroup>& process_group,        \
          int64_t timeout) {                                            \
    std::vector<at::Tensor> outputs = output_tensors.vec();             \
    std::vector<at::Tensor> inputs = input_tensors.vec();               \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);     \
    auto opts = AllToAllOptions{std::chrono::milliseconds(timeout)};    \
    auto handle = backend->alltoall(outputs, inputs, opts);             \
    return std::make_tuple(std::move(outputs), handle);                 \
  }

IMPL_ALLTOALL(CPU)
IMPL_ALLTOALL(CUDA)
IMPL_ALLTOALL(PrivateUse1)

// Kernel for the `alltoall_base_` op, generated once per device type.
// Single-tensor variant: forwards the output/input tensors and their split
// size lists to the device's backend with the timeout (milliseconds).
// Returns the Work handle.
#define IMPL_ALLTOALL_BASE(DEV)                                         \
  c10::intrusive_ptr<Work> alltoall_base_##DEV(                         \
      at::Tensor& output,                                               \
      at::Tensor& input,                                                \
      const c10::intrusive_ptr<ProcessGroup>& process_group,            \
      std::vector<int64_t> output_split_sizes,                          \
      std::vector<int64_t> input_split_sizes,                           \
      int64_t timeout) {                                                \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);     \
    auto opts = AllToAllOptions{std::chrono::milliseconds(timeout)};    \
    return backend->alltoall_base(                                      \
        output, input, output_split_sizes, input_split_sizes, opts);    \
  }

IMPL_ALLTOALL_BASE(CPU)
IMPL_ALLTOALL_BASE(CUDA)
IMPL_ALLTOALL_BASE(PrivateUse1)

// Kernel for the `barrier` op, generated once per device type. The tensor
// argument is unnamed and unused; the device ids and timeout (milliseconds)
// are packed into a BarrierOptions for the device's backend. Returns the
// Work handle.
#define IMPL_BARRIER(DEV)                                                     \
  c10::intrusive_ptr<Work> barrier##DEV(                                      \
      at::Tensor /* unused */,                                                \
      const c10::intrusive_ptr<ProcessGroup>& process_group,                  \
      const std::vector<int64_t>& device_ids,                                 \
      int64_t timeout) {                                                      \
    auto backend = process_group->getBackend(c10::DeviceType::DEV);           \
    auto opts =                                                               \
        BarrierOptions{device_ids, std::chrono::milliseconds(timeout)};       \
    return backend->barrier(opts);                                            \
  }

IMPL_BARRIER(CPU)
IMPL_BARRIER(CUDA)
IMPL_BARRIER(PrivateUse1)
// NOLINTEND(cppcoreguidelines-pro-type-const-cast)

// CPU kernel for the `monitored_barrier_` op. The tensor argument is unnamed
// and unused; the device ids and timeout (milliseconds) are packed into a
// BarrierOptions and passed, along with `wait_all_ranks`, to the CPU
// backend's monitoredBarrier(). Nothing is returned.
void monitored_barrier_CPU(
    at::Tensor /* unused */,
    const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
    int64_t timeout,
    bool wait_all_ranks) {
  auto backend = process_group->getBackend(c10::DeviceType::CPU);
  auto opts = BarrierOptions{device_ids, std::chrono::milliseconds(timeout)};
  backend->monitoredBarrier(opts, wait_all_ranks);
}

// Sparse allreduce kernel for CUDA. Unlike the per-device allreduce_ kernels
// generated above, this one forwards `sparse_indices` inside the
// AllreduceOptions and calls the CUDA backend's allreduce_sparse(). Returns
// the tensor vector together with the Work handle.
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>
allreduce_sparse_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    const std::optional<at::Tensor>& sparse_indices,
    int64_t timeout) {
  std::vector<at::Tensor> buffers = tensors.vec();
  auto backend = process_group->getBackend(c10::DeviceType::CUDA);
  auto opts = AllreduceOptions{
      *reduce_op, std::chrono::milliseconds(timeout), sparse_indices};
  auto handle = backend->allreduce_sparse(buffers, opts);
  return std::make_tuple(std::move(buffers), handle);
}
} // namespace

// register functions to dispatcher
namespace {

// 2nd level expansion
// FUNC: op name
// DEV: device
// Emits one TORCH_LIBRARY_IMPL block that binds the schema named by the
// stringified FUNC to the kernel whose identifier is FUNC with DEV pasted on
// the end (e.g. FUNC = recv_, DEV = CPU binds "recv_" to recv_CPU). The
// kernel names produced by the IMPL_* macros above follow this convention.
#define REGISTER_C10D_OP1(FUNC, DEV) \
  TORCH_LIBRARY_IMPL(c10d, DEV, m) { \
    m.impl(#FUNC, FUNC##DEV);        \
  }

// 1st level expansion
// Registers FUNC once for each of the three supported device keys by
// delegating to REGISTER_C10D_OP1.
#define REGISTER_C10D_OP(FUNC)  \
  REGISTER_C10D_OP1(FUNC, CPU)  \
  REGISTER_C10D_OP1(FUNC, CUDA) \
  REGISTER_C10D_OP1(FUNC, PrivateUse1)

// Now we start to register ops with the three device keys
// Each name below must match both a schema declared in TORCH_LIBRARY(c10d)
// and the prefix of the kernels generated by the corresponding IMPL_* macro.

REGISTER_C10D_OP(send)
REGISTER_C10D_OP(recv_)
REGISTER_C10D_OP(recv_any_source_)
REGISTER_C10D_OP(reduce_)
REGISTER_C10D_OP(broadcast_)
REGISTER_C10D_OP(allreduce_)
REGISTER_C10D_OP(allreduce_coalesced_)
REGISTER_C10D_OP(allgather_)
REGISTER_C10D_OP(_allgather_base_)
REGISTER_C10D_OP(allgather_coalesced_)
REGISTER_C10D_OP(allgather_into_tensor_coalesced_)
REGISTER_C10D_OP(reduce_scatter_)
REGISTER_C10D_OP(_reduce_scatter_base_)
REGISTER_C10D_OP(reduce_scatter_tensor_coalesced_)
REGISTER_C10D_OP(gather_)
REGISTER_C10D_OP(scatter_)
REGISTER_C10D_OP(alltoall_)
REGISTER_C10D_OP(alltoall_base_)
REGISTER_C10D_OP(barrier)

// The following ops are specialized, register them separately

// monitored_barrier_ only has a CPU kernel in this file, so it is registered
// for the CPU key alone instead of going through REGISTER_C10D_OP.
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("monitored_barrier_", monitored_barrier_CPU);
}

// TODO: The SparseCPU/SparseCUDA dispatched methods are only used to support
// sparse all_reduce in the Gloo backend
// The SparseCPU key reuses the dense CPU allreduce kernel.
TORCH_LIBRARY_IMPL(c10d, SparseCPU, m) {
  m.impl("allreduce_", allreduce_CPU);
}

// The SparseCUDA key gets its own kernel, which forwards sparse_indices and
// calls the backend's allreduce_sparse().
TORCH_LIBRARY_IMPL(c10d, SparseCUDA, m) {
  m.impl("allreduce_", allreduce_sparse_cuda_);
}

} // namespace

} // namespace ops
} // namespace c10d
