#ifdef USE_C10D_UCC

#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>
#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>

namespace c10d {

namespace {

const std::map<c10::DeviceType, ucc_memory_type_t> ucc_mtype_map = {
    {c10::kCPU, UCC_MEMORY_TYPE_HOST},
    {c10::kCUDA, UCC_MEMORY_TYPE_CUDA},
};

ucc_memory_type_t to_ucc_memType(c10::DeviceType _c10_type) {
  if (ucc_mtype_map.find(_c10_type) != ucc_mtype_map.end())
    return ucc_mtype_map.at(_c10_type);
  else
    return UCC_MEMORY_TYPE_UNKNOWN;
}

const std::map<at::ScalarType, ucc_datatype_t> ucc_dtype_map = {
    {at::kByte, UCC_DT_UINT8},
    {at::kChar, UCC_DT_INT8},
    {at::kHalf, UCC_DT_FLOAT16},
    {at::kBFloat16, UCC_DT_BFLOAT16},
    {at::kDouble, UCC_DT_FLOAT64},
    {at::kFloat, UCC_DT_FLOAT32},
    {at::kInt, UCC_DT_INT32},
    {at::kLong, UCC_DT_INT64},
    {at::kBool, UCC_DT_UINT8},
};

ucc_datatype_t to_ucc_dType(at::Tensor _tensor) {
  if (_tensor.scalar_type() == at::kBool && _tensor.element_size() != 1) {
    TORCH_CHECK(
        false, "Size of Boolean type larger than 1 is not supported in UCC");
  }
  try {
    return ucc_dtype_map.at(_tensor.scalar_type());
  } catch (const std::out_of_range&) {
    TORCH_CHECK(false, "Not supported data type for UCC");
  }
}

const std::map<ReduceOp, ucc_reduction_op_t> ucc_op_map = {
    {ReduceOp::SUM, UCC_OP_SUM},
    {ReduceOp::PRODUCT, UCC_OP_PROD},
    {ReduceOp::MIN, UCC_OP_MIN},
    {ReduceOp::MAX, UCC_OP_MAX},
    {ReduceOp::BAND, UCC_OP_BAND},
    {ReduceOp::BOR, UCC_OP_BOR},
    {ReduceOp::BXOR, UCC_OP_BXOR},
    {ReduceOp::AVG, UCC_OP_AVG},
};

ucc_reduction_op_t to_ucc_reduceOp(
    const ReduceOp _op,
    const at::ScalarType _dt) {
  if (_dt == at::kBool) {
    if (_op == ReduceOp::SUM) {
      // bitwise or
      return UCC_OP_MAX;
    } else if (_op == ReduceOp::PRODUCT) {
      // bitwise and
      return UCC_OP_MIN;
    } else if (_op == ReduceOp::AVG) {
      TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs");
    }
  }

  try {
    return ucc_op_map.at(_op);
  } catch (const std::out_of_range&) {
    TORCH_CHECK(false, "Not supported ReduceOp for UCC");
  }
}

struct torch_ucc_config_t {
  c10::once_flag flag;
  std::array<bool, 32> blocking_wait;
  bool enable_comms_logger;
  bool use_future;
  // Sharing UCC communicator among multiple PGs to save resource.
  bool shared_comm;
  // Using allgatherv to achieve allgather, without flattening the list of
  // (potentially non-contiguous) tensors.
  bool use_allgatherv;
  bool enable_health_check;
} torch_ucc_config;

std::unordered_map<std::string, std::string> torch_ucc_envs_map = {
    // TORCH_UCC_BLOCKING_WAIT allowed syntax:
    // - TORCH_UCC_BLOCKING_WAIT=none --> blocking wait completely disabled
    // - TORCH_UCC_BLOCKING_WAIT=all --> blocking wait completely enabled
    // - TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv --> blocking wait enabled
    //                                                   on selected operations
    // Supported operations:
    // [allgather,allgather_base,allreduce,alltoall,broadcast,
    //  gather,reduce,reduce_scatter,scatter,send,recv]
    {"TORCH_UCC_BLOCKING_WAIT", "none"},

    {"TORCH_UCC_USE_FUTURE", "1"},
    {"TORCH_UCC_PROFILING_ENABLE", "0"},
    {"TORCH_UCC_SHARED_COMM", "1"},
    {"TORCH_UCC_USE_ALLGATHERV", "0"},
    {"TORCH_UCC_ENABLE_HEALTH_CHECK", "0"},
    {"TORCH_UCC_ENABLE_COMMS_LOGGER", "0"},
};

std::vector<OpType> parse_blocking_wait(std::string op_list_string) {
  const static std::unordered_map<std::string, OpType> str2op = {
      {"allgather", OpType::ALLGATHER},
      {"allgather_base", OpType::_ALLGATHER_BASE},
      {"allreduce", OpType::ALLREDUCE},
      {"alltoall_base", OpType::ALLTOALL_BASE},
      {"broadcast", OpType::BROADCAST},
      {"gather", OpType::GATHER},
      {"reduce", OpType::REDUCE},
      {"reduce_scatter", OpType::REDUCE_SCATTER},
      {"scatter", OpType::SCATTER},
      {"send", OpType::SEND},
      {"recv", OpType::RECV},
  };
  auto op_list = parse_list(op_list_string);
  if (op_list == std::vector<std::string>{"none"}) {
    return {};
  }
  std::vector<OpType> result;
  if (op_list == std::vector<std::string>{"all"}) {
    for (auto entry : str2op) {
      result.push_back(entry.second);
    }
  } else {
    for (auto op_string : op_list) {
      result.push_back(str2op.at(op_string));
    }
  }
  return result;
}

} // namespace

void read_config() {
  // default configuration
  torch_ucc_config.blocking_wait.fill(false);
  torch_ucc_config.use_future = true;
  torch_ucc_config.shared_comm = false;
  torch_ucc_config.use_allgatherv = false;
  torch_ucc_config.enable_health_check = false;
  torch_ucc_config.enable_comms_logger = false;

  // read all torch_ucc env. variables and update the map
  char* env;
  for (auto& torch_ucc_env : torch_ucc_envs_map) {
    env = std::getenv(torch_ucc_env.first.c_str());
    if (env) {
      torch_ucc_envs_map[torch_ucc_env.first] = std::string(env);
    }
  }

  auto blocking_wait_str = torch_ucc_envs_map.at("TORCH_UCC_BLOCKING_WAIT");
  for (auto op : parse_blocking_wait(blocking_wait_str)) {
    torch_ucc_config.blocking_wait[(std::uint8_t)op] = true;
  }
  // barrier is always blocking
  torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BARRIER] = true;

  torch_ucc_config.use_future =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE"));
  torch_ucc_config.shared_comm =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM"));
  torch_ucc_config.use_allgatherv =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_ALLGATHERV"));
  torch_ucc_config.enable_health_check =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_HEALTH_CHECK"));
  torch_ucc_config.enable_comms_logger =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_COMMS_LOGGER"));
}

void check_device(c10::Device dev1, c10::Device dev2) {
  if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) {
    throw std::invalid_argument("ProcessGroupUCC multidevice is not supported");
  }
}

void check_tensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    throw std::invalid_argument(
        "ProcessGroupUCC takes 1 tensor. Got " +
        std::to_string(tensors.size()) + ". ");
  }
  if (!tensors[0].is_contiguous()) {
    throw std::invalid_argument(
        "ProcessGroupUCC input tensor has to be contiguous");
  }
  if (tensors[0].is_sparse()) {
    throw std::invalid_argument("ProcessGroupUCC input tensor has to be dense");
  }
  // TODO: check cuda case
}

ProcessGroupUCC::WorkUCC::~WorkUCC() {
#ifdef USE_CUDA
  if (fence && ep) {
    std::lock_guard<std::mutex> lock(ep->event_pool_mutex);
    ep->event_pool.push(std::move(fence));
  }
#endif
}

void ProcessGroupUCC::WorkUCC::setException() {
  if (exception() || !entry_) {
    return;
  }
  exception_ = entry_->eptr_;
}

void ProcessGroupUCC::WorkUCC::setAndThrowException() {
  setException();
  if (exception()) {
    std::rethrow_exception(exception());
  }
}

bool ProcessGroupUCC::WorkUCC::isCompleted() {
  if (!entry_) {
    return true;
  }
  setException();
  // status_ <= 0 to avoid listing all possible status codes.  The main thread
  // needs to be unblocked when UCC (in progress thread) returns success (== 0)
  // or any error code (< 0).
  return exception() || entry_->status_ <= 0;
}

bool ProcessGroupUCC::WorkUCC::isSuccess() const {
  if (!entry_) {
    return true;
  }
  return !exception() && entry_->status_ == 0;
}

bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
  if (torch_ucc_config.enable_comms_logger && logger_) {
    logger_->trace_generator->recordComms("wait", (uintptr_t)this, rank_);
  }
#ifdef USE_CUDA
  if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) {
    // block user stream
    setAndThrowException();
    fence->block(at::cuda::getCurrentCUDAStream());
    return true;
  }
#endif
  // wait for complete.  For blocking case, the main thread will be blocked in
  // this loop until the progress thread changes the status of this request.
  // If timeout occurs, UCC will return UCC_ERR_TIMEOUT as the status.  The
  // main thread will throw out the exception then. There is no "abort"
  // function in UCC currently.
  while (!isCompleted())
    ;
  setAndThrowException();
  // manually call profiling end callbacks if they are set,
  // since progress thread does not own WorkUCC
  if (Work::recordFunctionEndCallback_) {
    Work::recordFunctionEndCallback_();
    Work::recordFunctionEndCallback_ = nullptr;
  }
  return true;
}

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::getFuture() {
  return future_;
}

int ProcessGroupUCC::WorkUCC::sourceRank() const {
  if (opType_ != OpType::RECV && opType_ != OpType::RECVANYSOURCE) {
    // Throw an error
    return Work::sourceRank();
  }
  return sourceRank_;
}

std::vector<at::Tensor> ProcessGroupUCC::WorkUCC::result() {
  return *outputs_;
}

void ProcessGroupUCC::ProgressEntry::finalize(std::exception_ptr eptr) {
  ucc_status_t status = UCC_OK;

  if (request_ != nullptr) {
    status = request_->status;
    comm_->free_request(request_);
  }
  if (eptr) {
    eptr_ = eptr;
  } else {
    status_ = status;
  }
  if (future_) {
    if (eptr) {
      future_->setError(eptr);
    } else {
      future_->markCompleted(
          c10::IValue(data ? data->dst : std::vector<at::Tensor>()));
    }
  }
}

Comm::Comm(
    const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob_,
    c10::Device dev,
    bool is_health_check)
    : logger(logger_),
      oob(oob_),
      ucc_comm(oob, logger),
      finalize_phase(
          is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE),
      cuda_device_index(TORCH_UCC_DEVICE_NOT_SET) {
  if (dev.is_cuda()) {
    cuda_device_index = dev.index();
  }
  stop_progress_loop = false;
  collective_inprogress = false;
  progress_thread = std::thread(&Comm::progress_loop, this);
#ifdef _GNU_SOURCE
  pthread_setname_np(progress_thread.native_handle(), "ucc-progress");
#endif
}

Comm::~Comm() {
  std::unique_lock<std::mutex> lock(mutex);
  queue_consume_cv.wait(
      lock, [&] { return progress_queue.empty() && !collective_inprogress; });
  stop_progress_loop = true;
  lock.unlock();
  queue_produce_cv.notify_all();
  progress_thread.join();
}

std::shared_ptr<Comm> Comm::get_comm(
    uint32_t& id,
    c10::Device dev,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
    const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
    bool is_health_check) {
  static std::mutex m;
  static std::weak_ptr<Comm> comm;
  static uint32_t comm_id;

  std::lock_guard<std::mutex> lock(m);
  id = comm_id;

  std::string group_id = "group_id";
  if (is_health_check) {
    group_id = c10::str(dev.type()) + "/" + group_id;
  }

  std::vector<uint8_t> remote_comm_id;
  oob->store->deleteKey(group_id + std::to_string(0));
  if (oob->rank != 0) {
    std::vector<uint8_t> val = std::vector<uint8_t>(
        reinterpret_cast<uint8_t*>(&id),
        reinterpret_cast<uint8_t*>(&id) + sizeof(id));
    oob->store->set(group_id + std::to_string(oob->rank), val);
  } else {
    for (int i = 1; i < oob->size; i++) {
      remote_comm_id = oob->store->get(group_id + std::to_string(i));
      oob->store->deleteKey(group_id + std::to_string(i));
      // Find the highest id.
      id = std::max(id, *(reinterpret_cast<uint32_t*>(remote_comm_id.data())));
    }
    std::vector<uint8_t> val = std::vector<uint8_t>(
        reinterpret_cast<uint8_t*>(&id),
        reinterpret_cast<uint8_t*>(&id) + sizeof(id));
    oob->store->set(group_id + std::to_string(oob->rank), val);
  }
  remote_comm_id = oob->store->get(group_id + std::to_string(0));
  oob->comm_id = *(reinterpret_cast<uint32_t*>(remote_comm_id.data()));
  // Prepare comm_id (static variable) to the next id.
  comm_id = oob->comm_id + 1;

  if (torch_ucc_config.shared_comm) {
    std::shared_ptr<Comm> shared_comm = comm.lock();
    if (!shared_comm) {
      shared_comm = std::make_shared<Comm>(logger, oob, dev, is_health_check);
      comm = shared_comm;
    } else {
      if (dev.is_cuda() && !is_health_check) {
        if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
            (shared_comm->cuda_device_index != dev.index())) {
          TORCH_UCC_LOG_ERROR(
              is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT,
              "ucc communicator was initialized with different cuda device,"
              "multi device is not supported");
          throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
        }
        shared_comm->cuda_device_index = dev.index();
      }
    }
    return shared_comm;
  } else {
    return std::make_shared<Comm>(logger, oob, dev, is_health_check);
  }
}

void Comm::ucc_create_team(
    ucc_team_h& team,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
  ucc_status_t st;
  ucc_team_params_t team_params;
  team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE |
      UCC_TEAM_PARAM_FIELD_OOB;
  team_params.oob.allgather = oob_allgather;
  team_params.oob.req_test = oob_allgather_test;
  team_params.oob.req_free = oob_allgather_free;
  team_params.oob.coll_info = oob.get();
  team_params.oob.n_oob_eps = oob->size;
  team_params.oob.oob_ep = oob->rank;
  team_params.ep = oob->rank;
  team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG;
  TORCH_UCC_CHECK(
      ucc_team_create_post(&ucc_comm.context, 1, &team_params, &team),
      "failed to post team create");
  do {
    st = ucc_team_create_test(team);
    ucc_context_progress(ucc_comm.context);
  } while (st == UCC_INPROGRESS);
  TORCH_UCC_CHECK(st, "failed to create UCC team");
}

void Comm::ucc_destroy_team(ucc_team_h& team) {
  std::unique_lock<std::mutex> lock(mutex);
  queue_consume_cv.wait(
      lock, [&] { return progress_queue.empty() && !collective_inprogress; });

  ucc_status_t status;
  while (UCC_INPROGRESS == (status = ucc_team_destroy(team))) {
    if (UCC_OK != status) {
      TORCH_UCC_LOG_ERROR(
          finalize_phase,
          c10::str("ucc team destroy error: ", ucc_status_string(status)));
      break;
    }
  }

  lock.unlock();
}

void Comm::enqueue_collective(
    std::unique_ptr<ProcessGroupUCC::WorkData> data,
    c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
    ucc_coll_args_t& coll,
    ucc_team_h team) {
  ucc_coll_req_h request;
  TORCH_UCC_CHECK(
      ucc_collective_init(&coll, &request, team), "failed to init collective");
  TORCH_UCC_CHECK_REQUEST(
      request, ucc_collective_post(request), "failed to post collective");

  auto entry =
      std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
  entry->data = std::move(data);
  entry->future_ = work->getFuture();
  work->entry_ = entry;
  std::unique_lock<std::mutex> lock(mutex);
  progress_queue.push_back(entry);
  lock.unlock();
  queue_produce_cv.notify_one();
}

#ifdef USE_CUDA
void Comm::enqueue_cuda_collective(
    std::unique_ptr<ProcessGroupUCC::WorkData> data,
    c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
    ucc_coll_args_t& coll,
    ucc_team_h team,
    ucc_ee_h ee) {
  ucc_coll_req_h request;
  TORCH_UCC_CHECK(
      ucc_collective_init(&coll, &request, team),
      "failed to init cuda collective");
  ucc_ev_t comp_ev, *post_ev;
  comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE;
  comp_ev.ev_context = nullptr;
  comp_ev.ev_context_size = 0;
  comp_ev.req = request;
  TORCH_UCC_CHECK_REQUEST(
      request,
      ucc_collective_triggered_post(ee, &comp_ev),
      "failed to post triggered collective");
  ucc_status_t st = ucc_ee_get_event(ee, &post_ev);
  TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST);
  ucc_ee_ack_event(ee, post_ev);
  auto entry =
      std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
  entry->data = std::move(data);
  work->entry_ = entry;
  std::unique_lock<std::mutex> lock(mutex);
  progress_queue.push_back(entry);
  lock.unlock();
  queue_produce_cv.notify_one();
}
#endif

void Comm::progress_loop() {
  std::unique_lock<std::mutex> lock(mutex);
#ifdef USE_CUDA
  bool device_set = false;
#endif
  while (!stop_progress_loop) {
    if (progress_queue.empty()) {
      queue_produce_cv.wait(lock);
      continue;
    }
    collective_inprogress = true;
    auto work = progress_queue.front();
    progress_queue.pop_front();
    lock.unlock();
#ifdef USE_CUDA
    if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) {
      c10::cuda::set_device(cuda_device_index);
      CUcontext pctx = nullptr;
      at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
      if (C10_UNLIKELY(!pctx)) {
        at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(
            &pctx, cuda_device_index);
        at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
      }
      device_set = true;
    }
#endif
    std::exception_ptr eptr;
    try {
      while (work->request_->status > 0) {
        ucc_comm.progress();
      }
      if (work->request_->status < 0) {
        eptr = std::make_exception_ptr(
            std::runtime_error(ucc_status_string(work->request_->status)));
        std::string err_log = c10::str(
            "Failed to progress communication", // TODO: report exact op type or
                                                // id?
            ucc_status_string(work->request_->status));
        TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log);
      }
    } catch (...) {
      eptr = std::current_exception();
    }
    work->finalize(eptr);
    work = nullptr;
    collective_inprogress = false;
    queue_consume_cv.notify_one();
    lock.lock();
  }
}

ProcessGroupUCC::ProcessGroupUCC(
    const c10::intrusive_ptr<Store>& store,
    int rank,
    int size,
    std::chrono::duration<float> timeout)
    : Backend(rank, size), timeout_(timeout) {
  c10::call_once(torch_ucc_config.flag, read_config);
  oob = std::make_shared<torch_ucc_oob_coll_info_t>();
  oob->rank = rank;
  oob->size = size;
  oob->store = store;
  comm = nullptr;
  cuda_ee = nullptr;
  static uint32_t id = 0;
  uint32_t pg_id = id++;

  logger = c10::make_intrusive<ProcessGroupUCCLogger>(
      c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"),
      TORCH_UCC_INIT);
  TORCH_UCC_LOG_INFO(
      TORCH_UCC_INIT,
      c10::str(
          "Created ProcessGroupUCC with ",
          size,
          " ranks, with timeout ",
          timeout_.count(),
          " secs"));
  std::string envs = "";
  for (auto& torch_ucc_env : torch_ucc_envs_map) {
    envs += ("\n\t" + torch_ucc_env.first + "=" + torch_ucc_env.second);
  }
  TORCH_UCC_LOG_INFO(
      TORCH_UCC_INIT,
      c10::str(
          "Successfully read and set ProcessGroupUCC env. variables as followings",
          envs));

  if (torch_ucc_config.enable_health_check) {
    // Perform health check by initializing dummy communicators and destroying
    // them. This will help indicate any UCC/UCX-related issues prior to the
    // first collective. Run it in a separate thread and wait on CV to handle
    // timeouts so that if there are hangs, the main thread can still run
    // correctly.
    runHealthCheck();
  }
  if (torch_ucc_config.enable_comms_logger) {
    logger->initCommsTracer();
  }
}

ProcessGroupUCC::~ProcessGroupUCC() {
  if (torch_ucc_config.enable_comms_logger) {
    logger->flushComms(this->getRank(), this->getSize());
  }
  if (comm) {
    logger->setPhase(TORCH_UCC_FINALIZE);
    comm->ucc_destroy_team(team);
    TORCH_UCC_LOG_INFO(
        TORCH_UCC_FINALIZE, "Successfully destroyed UCC library");
    try {
      if (cuda_ee) {
        ucc_ee_destroy(cuda_ee);
        ucc_ee_destroy(cuda_ee_p2p[0]);
        ucc_ee_destroy(cuda_ee_p2p[1]);
      }
    } catch (std::exception& ex) {
      TORCH_UCC_LOG_INFO(
          TORCH_UCC_FINALIZE,
          c10::str(
              "(~ProcessGroupUCC) Caught error in Store Operation .. ",
              "[",
              ex.what(),
              "]"));
    }
    comm = nullptr;
  }
}

#ifdef USE_CUDA
// Return CUDA device with ordinal given by input rank.
c10::Device getCUDADeviceForRank(int rank) {
  TORCH_CHECK(rank >= 0, "Invalid rank ", rank);
  auto numGPUs = at::cuda::getNumGPUs();
  auto deviceIdx = static_cast<c10::DeviceIndex>(rank % numGPUs);
  return c10::Device(c10::DeviceType::CUDA, deviceIdx);
}
#endif

void ProcessGroupUCC::runHealthCheck() {
  // Run health check in a separate thread and wait on CV to handle timeouts.
  // This design allows us to handle hangs.

  // When size_ is 1, there is no need to do any communication at all.
  if (size_ == 1)
    return;

  struct HealthCheckData {
    std::mutex healthCheckMutex;
    std::condition_variable healthCheckCv;
    bool uccHealthCheckSuccess = false;
    std::exception_ptr healthCheckException;
  } healthCheckData;

  auto t = std::thread([&healthCheckData, this]() {
    std::list<c10::Device> devices{c10::kCPU};
#ifdef USE_CUDA
    c10::cuda::OptionalCUDAGuard gpuGuard;
    if (at::cuda::is_available()) {
      devices.emplace_front(getCUDADeviceForRank(rank_));
    }
#endif
    for (auto device : devices) {
      bool is_last_device = (device == devices.back());
      try {
        auto oob = std::make_shared<torch_ucc_oob_coll_info_t>();
        oob->rank = this->oob->rank;
        oob->size = this->oob->size;
        oob->store = this->oob->store;
        ucc_team_h team = nullptr;
        uint32_t comm_id;
#ifdef USE_CUDA
        if (device.is_cuda()) {
          gpuGuard.set_index(device.index());
        }
#endif
        auto comm = Comm::get_comm(comm_id, device, oob, logger, true);
        comm->ucc_create_team(team, oob);
        comm->ucc_destroy_team(team);
        TORCH_UCC_LOG_INFO(
            TORCH_UCC_HEALTH_CHECK,
            c10::str(
                "UCC library health check succeed for device ",
                c10::DeviceTypeName(device.type())));
        // Mark ucc health check as complete.
        if (is_last_device) {
          std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
          healthCheckData.uccHealthCheckSuccess = true;
        }

        comm = nullptr;
        oob = nullptr;
        // Notify main thread the health check is complete.
        if (is_last_device) {
          healthCheckData.healthCheckCv.notify_one();
        }
      } catch (const std::exception&) {
        // Populate exception ptr.
        healthCheckData.healthCheckException = std::current_exception();
        // Unblock waiting main thread which will report exception.
        healthCheckData.healthCheckCv.notify_one();
      } // Unknown exceptions will just cause the program to terminate.
    }
  });
  // We don't need to join the thread, just need to verify health check via the
  // CV. Hence we detach the thread here.
  t.detach(); // NOLINT
  TORCH_UCC_LOG_INFO(
      TORCH_UCC_HEALTH_CHECK,
      c10::str(
          "will wait up to ",
          timeout_.count(),
          " msec for UCC health check to complete."));
  std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
  healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() {
    return healthCheckData.uccHealthCheckSuccess;
  });

  if (healthCheckData.healthCheckException) {
    std::rethrow_exception(healthCheckData.healthCheckException);
  }
  // If there is no exception, the likely culprit is a timeout/hang
  TORCH_CHECK(
      healthCheckData.uccHealthCheckSuccess,
      "ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ",
      rank_);
}

void ProcessGroupUCC::set_timeout(ucc_coll_args_t& args) {
  args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
  args.flags |= UCC_COLL_ARGS_FLAG_TIMEOUT;
  args.timeout = timeout_.count();
}

#ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAEvent> ProcessGroupUCC::getPooledEvent() {
  std::unique_ptr<at::cuda::CUDAEvent> ev;
  std::lock_guard<std::mutex> lock(ep.event_pool_mutex);
  if (ep.event_pool.empty()) {
    ev = std::make_unique<at::cuda::CUDAEvent>();
  } else {
    ev = std::move(ep.event_pool.front());
    ep.event_pool.pop();
  }
  return ev;
}
#endif

template <typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> ProcessGroupUCC::collective_post(
    OpType opType,
    PreProcess preproc,
    PostProcess postproc,
    ucc_coll_args_t& coll,
    std::unique_ptr<ProcessGroupUCC::WorkData> data,
    c10::Device dev,
    std::vector<at::Tensor>& inputTensors,
    std::vector<at::Tensor>& outputTensors,
    const char* prof_title) {
  seq_++;
  set_timeout(coll);
  auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
      opType, seq_, prof_title, inputTensors, logger);

  if (opType == OpType::RECV) {
    work->sourceRank_ = coll.root;
  }

  RECORD_COMMS_TRACE(
      logger->trace_generator,
      work,
      opType,
      this->getRank(),
      this->getSize(),
      inputTensors,
      outputTensors);

  // Store references to outputs to be used by result
  work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputTensors);
  switch (dev.type()) {
    case c10::DeviceType::CPU: {
      if (torch_ucc_config.use_future) {
        work->future_ = c10::make_intrusive<at::ivalue::Future>(
            c10::ListType::create(c10::TensorType::get()));
      }
      preproc();
      comm->enqueue_collective(std::move(data), work, coll, team);
      postproc();
      return work;
    }
#ifdef USE_CUDA
    case c10::DeviceType::CUDA: {
      auto cuda_ev = getPooledEvent();
      at::cuda::CUDAStream* op_stream;
      ucc_ee_h* op_ee;
      if (opType == OpType::SEND) {
        op_stream = stream_p2p[0].get();
        op_ee = &cuda_ee_p2p[0];
      } else if (opType == OpType::RECV) {
        op_stream = stream_p2p[1].get();
        op_ee = &cuda_ee_p2p[1];
      } else {
        op_stream = stream.get();
        op_ee = &cuda_ee;
      }

      cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index()));
      cuda_ev->block(*op_stream);
      at::cuda::CUDAStreamGuard guard(*op_stream);
      preproc();
      comm->enqueue_cuda_collective(std::move(data), work, coll, team, *op_ee);
      postproc();
      cuda_ev->record(*op_stream);
      work->fence = std::move(cuda_ev);
      work->ep = &ep;
      if (torch_ucc_config.use_future) {
        c10::cuda::CUDAMultiStreamGuard streamGuard(*op_stream);
        std::vector<c10::Device> devList{dev};
        work->future_ = c10::make_intrusive<at::ivalue::Future>(
            c10::ListType::create(c10::TensorType::get()), devList);
        // Add a callback that runs profiling end callbacks
        if (work->recordFunctionEndCallback_) {
          work->future_->addCallback([work](at::ivalue::Future& /* unused */) {
            work->recordFunctionEndCallback_();
          });
        }

        work->future_->markCompleted(c10::IValue(outputTensors));
      }
      return work;
    }
#endif // #ifdef USE_CUDA
    default: {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str()));
      throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
    }
  }
}

c10::intrusive_ptr<Work> ProcessGroupUCC::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& /* unused */) {
  auto& tensor = inputTensors[0];
  check_device(tensor.device(), outputTensors[0][0].device());
  initComm(tensor.device());

  if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) {
    AllgathervWorkData* data = new AllgathervWorkData(size_);
    for (int i = 0; i < size_; i++) {
      data->recv_lengths[i] = tensor.element_size() * tensor.numel();
      data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr();
    }
    ucc_coll_args_t coll;
    coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
    coll.flags =
        UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
    coll.coll_type = UCC_COLL_TYPE_ALLGATHERV;
    coll.src.info.buffer = tensor.data_ptr();
    coll.src.info.count = tensor.element_size() * tensor.numel();
    coll.src.info.datatype = UCC_DT_UINT8;
    coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
    coll.dst.info_v.buffer = nullptr;
    coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
    coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
    coll.dst.info_v.datatype = UCC_DT_UINT8;
    coll.dst.info_v.mem_type =
        to_ucc_memType(outputTensors[0][0].device().type());
    SAVE_TENSORS(inputTensors, data->src);
    SAVE_TENSORS(outputTensors[0], data->dst);

    return collective_post(
        OpType::ALLGATHER,
        []() {},
        []() {},
        coll,
        std::unique_ptr<WorkData>(data),
        tensor.device(),
        inputTensors,
        outputTensors[0],
        "ucc:all_gather");
  } else {
    WorkData* data = new WorkData();
    std::vector<at::Tensor> flat_output(outputTensors.size());
    for (size_t i = 0; i < outputTensors.size(); i++) {
      TORCH_CHECK(
          outputTensors[i].size() == outputTensors.size() * size_,
          "Tensor output list is not valid for the number of participants");
      flat_output[i] = c10d::newLikeFlat(outputTensors, i);
    }
    SAVE_TENSORS(flat_output, data->flat);
    ucc_coll_args_t coll;
    coll.mask = 0;
    coll.flags = 0;
    coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
    coll.src.info.buffer = tensor.data_ptr();
    coll.src.info.count = tensor.numel();
    coll.src.info.datatype = to_ucc_dType(tensor);
    coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
    coll.dst.info.buffer = flat_output[0].data_ptr();
    coll.dst.info.count = flat_output[0].numel();
    coll.dst.info.datatype = to_ucc_dType(flat_output[0]);
    coll.dst.info.mem_type =
        to_ucc_memType(outputTensors[0][0].device().type());

    auto copy_from_flat = [&] {
      bool asyncCopy = false;
#ifdef USE_CUDA
      bool isCuda = outputTensors[0][0].device().is_cuda();
      ;
#endif
      for (size_t i = 0; i < outputTensors.size(); i++) {
        auto inumel = inputTensors[i].numel();
        for (size_t j = 0; j < outputTensors[i].size(); j++) {
          TORCH_CHECK(
              (outputTensors[i][j].numel() == inumel),
              "Tensor operand counts must be same");
#ifdef USE_CUDA
          if (isCuda) {
            c10::cuda::CUDACachingAllocator::recordStream(
                outputTensors[i][j].storage().data_ptr(), (*stream));
            asyncCopy = true;
          }
#endif
          outputTensors[i][j].copy_(flat_output[i][j], asyncCopy);
        }
      }
    };
    return collective_post(
        OpType::ALLGATHER,
        []() {},
        copy_from_flat,
        coll,
        std::unique_ptr<WorkData>(data),
        tensor.device(),
        inputTensors,
        outputTensors[0],
        "ucc:all_gather");
  }
}

c10::intrusive_ptr<Work> ProcessGroupUCC::_allgather_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const AllgatherOptions& opts) {
  check_tensor({outputTensor});
  check_tensor({inputTensor});
  initComm(outputTensor.device());

  WorkData* data = new WorkData();

  ucc_coll_args_t coll;
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
  coll.src.info.buffer = inputTensor.data_ptr();
  coll.src.info.count = inputTensor.numel();
  coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type());
  coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
  coll.dst.info.buffer = outputTensor.data_ptr();
  coll.dst.info.count = outputTensor.numel();
  coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type());
  coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());

  std::vector<at::Tensor> inputTensors = {inputTensor};
  std::vector<at::Tensor> outputTensors = {outputTensor};
  SAVE_TENSORS(inputTensors, data->src);
  SAVE_TENSORS(outputTensors, data->dst);

  return collective_post(
      OpType::_ALLGATHER_BASE,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      outputTensor.device(),
      inputTensors,
      outputTensors,
      "ucc:allgather_base");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());
  WorkData* data = new WorkData();

  ucc_coll_args_t coll;
  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
  coll.coll_type = UCC_COLL_TYPE_ALLREDUCE;
  coll.op = to_ucc_reduceOp(opts.reduceOp, tensor.scalar_type());
  coll.src.info.buffer = nullptr;
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = to_ucc_dType(tensor);
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.dst.info.buffer = tensor.data_ptr();
  coll.dst.info.count = tensor.numel();
  coll.dst.info.datatype = to_ucc_dType(tensor);
  coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
  SAVE_TENSORS(tensors, data->dst);
  return collective_post(
      OpType::ALLREDUCE,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      tensors,
      tensors,
      "ucc:all_reduce");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce_coalesced(
    std::vector<at::Tensor>& /* unused */,
    const AllreduceCoalescedOptions& /* unused */) {
  throw std::invalid_argument(
      "ProcessGroupUCC does not support allreduce_coalesced");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& /* unused */) {
  auto device = outputTensors[0].device();
  for (const auto r : c10::irange(outputTensors.size())) {
    TORCH_CHECK(
        device == outputTensors[r].device() &&
            device == inputTensors[r].device(),
        "Tensors must be on the same device")
  }

  initComm(device);
  ucc_coll_args_t coll;
  AlltoallWorkData* data;
  data = new AlltoallWorkData(size_);

  /* to avoid flatten the tensors, we use alltoallv to achieve Alltoall as
     follow.
      1. store addresses of each tensor directly in displacements, keep buffer
     to nullptr, i.e., 0
      2. convert datatype to UINT8, which is always 1 bytes, to avoid wrong size
     calculation in UCC layer
      3. post Alltoallv
  */
  for (const auto i : c10::irange(size_)) {
    data->send_lengths[i] =
        (uint64_t)(inputTensors[i].element_size() * inputTensors[i].numel());
    data->send_offsets[i] = (uint64_t)inputTensors[i].data_ptr();
    data->recv_lengths[i] =
        (uint64_t)(outputTensors[i].element_size() * outputTensors[i].numel());
    data->recv_offsets[i] = (uint64_t)outputTensors[i].data_ptr();
  }

  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags =
      UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
  coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
  coll.src.info_v.buffer = 0;
  coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
  coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
  coll.src.info_v.datatype = UCC_DT_UINT8;
  coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0].device().type());
  coll.dst.info_v.buffer = 0;
  coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
  coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
  coll.dst.info_v.datatype = UCC_DT_UINT8;
  coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0].device().type());

  SAVE_TENSORS(inputTensors, data->src);
  SAVE_TENSORS(outputTensors, data->dst);

  return collective_post(
      OpType::ALLTOALL,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      device,
      inputTensors,
      outputTensors,
      "ucc:alltoall");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& /* unused */) {
  check_device(inputTensor.device(), outputTensor.device());
  initComm(inputTensor.device());
  ucc_coll_args_t coll;
  AlltoallWorkData* data;

  if ((outputSplitSizes.size() == 0) && (inputSplitSizes.size() == 0)) {
    data = new AlltoallWorkData(0);
    TORCH_CHECK(
        (outputTensor.size(0) % size_ == 0) &&
            (inputTensor.size(0) % size_ == 0),
        "Tensor's dim 0 does not divide equally across group size");
    coll.mask = 0;
    coll.flags = 0;
    coll.coll_type = UCC_COLL_TYPE_ALLTOALL;
    coll.src.info.buffer = inputTensor.data_ptr();
    coll.src.info.count = inputTensor.element_size() * inputTensor.numel();
    coll.src.info.datatype = UCC_DT_UINT8;
    coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
    coll.dst.info.buffer = outputTensor.data_ptr();
    coll.dst.info.count = outputTensor.element_size() * outputTensor.numel();
    coll.dst.info.datatype = UCC_DT_UINT8;
    coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());
    coll.flags = 0;
  } else {
    data = new AlltoallWorkData(size_);
    c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
    c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
    computeLengthsAndOffsets(
        outputSplitSizes,
        outputTensor,
        &data->recv_lengths,
        &data->recv_offsets);
    computeLengthsAndOffsets(
        inputSplitSizes, inputTensor, &data->send_lengths, &data->send_offsets);
    coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
    coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
    coll.src.info_v.buffer = inputTensor.data_ptr();
    coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
    coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
    coll.src.info_v.datatype = to_ucc_dType(inputTensor);
    coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type());
    coll.dst.info_v.buffer = outputTensor.data_ptr();
    coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
    coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
    coll.dst.info_v.datatype = to_ucc_dType(outputTensor);
    coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type());
    coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER |
        UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | UCC_COLL_ARGS_FLAG_COUNT_64BIT |
        UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;

    if (torch_ucc_config.enable_comms_logger) {
      logger->trace_generator->recordOptionalInfo(
          outputSplitSizes, inputSplitSizes);
    }
  }
  std::vector<at::Tensor> inputTensors = {inputTensor};
  std::vector<at::Tensor> outputTensors = {outputTensor};
  SAVE_TENSORS(inputTensors, data->src);
  SAVE_TENSORS(outputTensors, data->dst);

  return collective_post(
      OpType::ALLTOALL_BASE,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      inputTensor.device(),
      inputTensors,
      outputTensors,
      "ucc:alltoall");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::barrier(const BarrierOptions& opts) {
  c10::Device device = c10::Device(c10::DeviceType::CPU);
#ifdef USE_CUDA
  auto numGPUs = c10::cuda::device_count();
  if (!opts.device_ids.empty()) {
    device = c10::Device(c10::DeviceType::CUDA, opts.device_ids.front());
  } else if (comm && comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) {
    device = c10::Device(c10::DeviceType::CUDA, comm->cuda_device_index);
  } else if (numGPUs > 0) {
    int8_t deviceIdx = static_cast<int8_t>(c10::cuda::current_device());
    // if current device is 0, likely the device is not set, use the best guess
    if (0 == (int)deviceIdx) {
      deviceIdx = static_cast<int8_t>(this->getRank() % numGPUs);
    }
    TORCH_UCC_LOG_INFO(
        TORCH_UCC_COLL_POST,
        c10::str(
            "post barrier before specifying any GPU while there are ",
            numGPUs,
            " GPUs available. ",
            "Not clear if GPU barrier is required, using GPU ",
            (int)deviceIdx,
            " to perform barrier. ",
            "Specify device_ids option in barrier() to force ",
            "use of a particular device"));
    device = c10::Device(c10::DeviceType::CUDA, deviceIdx);
  }
#endif
  initComm(device);

  ucc_coll_args_t coll;
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_BARRIER;
  auto dummy_tensor = std::vector<at::Tensor>();
  return collective_post(
      OpType::BARRIER,
      []() {},
      []() {},
      coll,
      nullptr,
      device,
      dummy_tensor,
      dummy_tensor,
      "ucc:barrier");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());
  WorkData* data = new WorkData();

  ucc_coll_args_t coll;
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_BCAST;
  coll.src.info.buffer = tensor.data_ptr();
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = to_ucc_dType(tensor);
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.root = opts.rootRank;
  SAVE_TENSORS(tensors, data->dst);

  if (torch_ucc_config.enable_comms_logger) {
    logger->trace_generator->recordOptionalInfo(opts.rootRank);
  }

  return collective_post(
      OpType::BROADCAST,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      tensors,
      tensors,
      "ucc:broadcast");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  std::vector<at::Tensor> outputs;
  auto& input = inputTensors[0];
  initComm(input.device());

  AllgathervWorkData* data = new AllgathervWorkData(size_);
  ucc_coll_args_t coll;
  coll.root = opts.rootRank;
  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags =
      UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
  coll.coll_type = UCC_COLL_TYPE_GATHERV;

  /* for non-root ranks, only src is valid */
  coll.src.info.buffer = input.data_ptr();
  coll.src.info.count = (uint64_t)(input.element_size() * input.numel());
  coll.src.info.datatype = UCC_DT_UINT8;
  coll.src.info.mem_type = to_ucc_memType(input.device().type());

  if (getRank() == opts.rootRank) {
    if (outputTensors.size() != 1) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST,
          c10::str(
              "gather requires a single-element output list containing a list with ",
              getSize(),
              " tensors."));
    } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST,
          c10::str(
              "Incorrect output list size ",
              outputTensors[0].size(),
              ". Output list size should be ",
              getSize(),
              ", same as size of the process group."));
    }
    outputs = outputTensors[0];

    for (int i = 0; i < size_; i++) {
      data->recv_lengths[i] =
          (uint64_t)(outputs[i].element_size() * outputs[i].numel());
      data->recv_offsets[i] = (uint64_t)outputs[i].data_ptr();
    }
    /* use gatherv and store non-contiguous addresses in displacements to avoid
     * flatten outputTensors */
    coll.dst.info_v.buffer = nullptr;
    coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
    coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
    coll.dst.info_v.datatype = UCC_DT_UINT8;
    coll.dst.info_v.mem_type = to_ucc_memType(outputs[0].device().type());

    SAVE_TENSORS(outputs, data->dst);
  } else {
    // for non-root ranks, outputTensors should be an empty list
    if (outputTensors.size() != 0) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST, "requires empty output on non-root");
    }
    outputs = {};
    // append a empty tensor to the list to be used by future mark
    outputs.emplace_back();
  }

  SAVE_TENSORS(inputTensors, data->src);

  return collective_post(
      OpType::GATHER,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      input.device(),
      inputTensors,
      outputs,
      "ucc:gather");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());
  WorkData* data = new WorkData();

  ucc_coll_args_t coll;
  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
  coll.coll_type = UCC_COLL_TYPE_REDUCE;
  coll.op = ucc_op_map.at(opts.reduceOp);
  coll.root = opts.rootRank;
  coll.src.info.buffer = tensor.data_ptr();
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.dst.info.buffer = tensor.data_ptr();
  coll.dst.info.count = tensor.numel();
  coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
  coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
  SAVE_TENSORS(tensors, data->dst);
  return collective_post(
      OpType::REDUCE,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      tensors,
      tensors,
      "ucc:reduce");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  TORCH_CHECK(
      (outputTensors.size() == inputTensors.size()),
      "Tensor input/output list for reduce_scatter must have same size");
  check_tensor(outputTensors);
  check_device(inputTensors[0][0].device(), outputTensors[0].device());
  initComm(inputTensors[0][0].device());
  auto data = std::make_unique<WorkData>();
  std::vector<at::Tensor> flat_input(inputTensors.size());
  for (size_t i = 0; i < inputTensors.size(); i++) {
    TORCH_CHECK(
        inputTensors[i].size() == inputTensors.size() * size_,
        "Tensor input list is not valid for the number of participants");
    flat_input[i] = c10d::newLikeFlat(inputTensors, i);
  }
  SAVE_TENSORS(flat_input, data->flat);
  check_tensor(flat_input);
  ucc_coll_args_t coll;
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER;
  coll.op = to_ucc_reduceOp(opts.reduceOp, flat_input[0].scalar_type());

  coll.src.info.buffer = flat_input[0].data_ptr();
  coll.src.info.count = flat_input[0].numel();
  coll.src.info.datatype = to_ucc_dType(flat_input[0]);
  coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type());
  coll.dst.info.buffer = outputTensors[0].data_ptr();
  coll.dst.info.count = outputTensors[0].numel();
  coll.dst.info.datatype = to_ucc_dType(outputTensors[0]);
  coll.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type());

  SAVE_TENSORS(inputTensors[0], data->src);
  SAVE_TENSORS(outputTensors, data->dst);

  auto copy_to_flat = [&] {
    bool asyncCopy = false;
    auto isize = inputTensors.size();
#ifdef USE_CUDA
    bool isCuda = inputTensors[0][0].device().is_cuda();
#endif
    for (size_t i = 0; i < isize; i++) {
      auto onumel = outputTensors[i].numel();
      for (size_t j = 0; j < inputTensors[i].size(); j++) {
        TORCH_CHECK(
            (inputTensors[i][j].numel() == onumel),
            "Tensor operand counts must be same");
#ifdef USE_CUDA
        if (isCuda) {
          c10::cuda::CUDACachingAllocator::recordStream(
              inputTensors[i][j].storage().data_ptr(), (*stream));
          asyncCopy = true;
        }
#endif
        flat_input[i][j].copy_(inputTensors[i][j], asyncCopy);
      }
    }
  };

  return collective_post(
      OpType::REDUCE_SCATTER,
      copy_to_flat,
      []() {},
      coll,
      std::move(data),
      inputTensors[0][0].device(),
      inputTensors[0],
      outputTensors,
      "ucc:reduce_scatter");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  auto& tensor = outputTensors[0];
  initComm(tensor.device());

  ScattervWorkData* data = new ScattervWorkData(size_);
  ucc_coll_args_t coll;
  coll.root = opts.rootRank;
  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags =
      UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
  coll.coll_type = UCC_COLL_TYPE_SCATTERV;

  if (getRank() == opts.rootRank) {
    /* src is only valid at non-root rank */
    if (inputTensors.size() != 1) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST,
          c10::str(
              "gather requires a single-element output list containing a list with ",
              getSize(),
              " tensors."));
    } else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST,
          c10::str(
              "Incorrect output list size ",
              inputTensors[0].size(),
              ". Output list size should be ",
              getSize(),
              ", same as size of the process group."));
    }

    for (int i = 0; i < size_; i++) {
      data->send_lengths[i] = (uint64_t)tensor.element_size() * tensor.numel();
      data->send_offsets[i] = (uint64_t)inputTensors[0][i].data_ptr();
    }
    /* use scatter and store non-contiguous addresses in displacements to avoid
     * flatten inputTensors */
    coll.src.info_v.buffer = nullptr;
    coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
    coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
    coll.src.info_v.datatype = UCC_DT_UINT8;
    coll.src.info_v.mem_type =
        to_ucc_memType(inputTensors[0][0].device().type());

    SAVE_TENSORS(inputTensors[0], data->src);
  } else {
    // for non-root ranks, inputTensors should be an empty list
    if (inputTensors.size() != 0) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST, "requires empty output on non-root");
    }
  }

  coll.dst.info.buffer = tensor.data_ptr();
  coll.dst.info.count = (uint64_t)tensor.element_size() * tensor.numel();
  coll.dst.info.datatype = UCC_DT_UINT8;
  coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
  SAVE_TENSORS(outputTensors, data->dst);

  return collective_post(
      OpType::SCATTER,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      (getRank() == opts.rootRank) ? inputTensors[0] : outputTensors,
      outputTensors,
      "ucc:scatter");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());

  WorkData* data = new WorkData();
  ucc_coll_args_t coll;
  coll.tag = tag;
  coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_BCAST;
  coll.src.info.buffer = tensor.data_ptr();
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = to_ucc_dType(tensor);
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.root = getRank();

  coll.active_set.size = 2;
  coll.active_set.start = getRank();
  coll.active_set.stride = dstRank - getRank();
  SAVE_TENSORS(tensors, data->dst);

  return collective_post(
      OpType::SEND,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      tensors,
      tensors,
      "ucc:send");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());

  WorkData* data = new WorkData();
  ucc_coll_args_t coll;
  coll.tag = tag;
  coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_BCAST;
  coll.src.info.buffer = tensor.data_ptr();
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = to_ucc_dType(tensor);
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.root = srcRank;

  coll.active_set.size = 2;
  coll.active_set.start = srcRank;
  coll.active_set.stride = getRank() - srcRank;
  SAVE_TENSORS(tensors, data->dst);

  return collective_post(
      OpType::RECV,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      tensors,
      tensors,
      "ucc:recv");
}

void ProcessGroupUCC::setSequenceNumberForGroup() {}

uint64_t ProcessGroupUCC::getSequenceNumberForGroup() {
  return seq_;
}

c10::intrusive_ptr<Backend> ProcessGroupUCC::createProcessGroupUCC(
    const c10::intrusive_ptr<::c10d::Store>& store,
    int rank,
    int size,
    const std::chrono::duration<float>& timeout) {
  return c10::make_intrusive<ProcessGroupUCC>(store, rank, size, timeout);
}

void ProcessGroupUCC::initComm(c10::Device dev) {
  if (!comm) {
#ifdef USE_CUDA
    if (dev.is_cuda()) {
      c10::cuda::set_device(dev.index());
    }
#endif
    comm = Comm::get_comm(comm_id, dev, oob, logger);
    TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library");
    comm->ucc_create_team(team, oob);
    TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library");
    logger->setPhase(TORCH_UCC_READY);
  } else {
    if (dev.is_cuda()) {
      if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
          (comm->cuda_device_index != dev.index())) {
        TORCH_UCC_LOG_ERROR(
            TORCH_UCC_INIT,
            "ucc communicator was initialized with different cuda device,"
            "multi device is not supported");
        throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
      }
      comm->cuda_device_index = dev.index();
    }
  }
#ifdef USE_CUDA
  // Create UCC execution engine.
  if (!cuda_ee && dev.is_cuda()) {
    stream = std::make_unique<at::cuda::CUDAStream>(
        at::cuda::getStreamFromPool(true, dev.index()));
    ucc_ee_params_t params;
    params.ee_type = UCC_EE_CUDA_STREAM;
    params.ee_context = (void*)stream->stream();
    params.ee_context_size = sizeof(cudaStream_t);
    TORCH_UCC_CHECK(
        ucc_ee_create(team, &params, &cuda_ee),
        "failed to create UCC execution engine");
    for (int i = 0; i < 2; i++) {
      stream_p2p[i] = std::make_unique<at::cuda::CUDAStream>(
          at::cuda::getStreamFromPool(true, dev.index()));
      ucc_ee_params_t params;
      params.ee_type = UCC_EE_CUDA_STREAM;
      params.ee_context = (void*)stream_p2p[i]->stream();
      params.ee_context_size = sizeof(cudaStream_t);
      TORCH_UCC_CHECK(
          ucc_ee_create(team, &params, &cuda_ee_p2p[i]),
          "failed to create UCC P2P execution engine");
    }
  }
#endif
}

} // namespace c10d

#endif // USE_C10D_UCC
