#ifdef USE_C10D_UCC

#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>

#include <sys/stat.h>
#include <cstdlib>
#include <ctime>
#include <fstream>

namespace c10d {

void ProcessGroupUCCLogger::initCommsTracer() {
  trace_generator = std::make_shared<CommTraceLogger>();
  initialized_CommTraceLogger = true;
}

void ProcessGroupUCCLogger::flushComms(int rank, int world_size) {
  if (!initialized_CommTraceLogger ||
      trace_generator->getCommsTrace().empty()) {
    return;
  }

  std::string dirname = c10::str("ProcessGroupUCC_trace_np", world_size);
  time_t now_ = time(0);
  std::tm* ltm = localtime(&now_);
  if (ltm) {
    dirname += c10::str(
        "_", (1 + ltm->tm_mon), "_", ltm->tm_mday, "_", (1900 + ltm->tm_year));
  }

  std::string fullpath = "/tmp/" + dirname;
  char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR");
  if (user_path) {
    fullpath = user_path;
  }
  std::string trace_filename = c10::str(fullpath, "/rank", rank, ".json");
  std::ofstream _outfile;
  if (!_outfile.is_open()) {
    if (!mkdir(fullpath.c_str(), 0777)) {
      LOG(INFO) << getLogPrefix() << "[INFO] failed to mkdir " << fullpath;
    } else if (errno != EEXIST) {
      return;
    }
    _outfile.open(trace_filename, std::ofstream::out | std::ofstream::trunc);
  }
  // flush the traced comms
  if (_outfile.is_open()) {
    _outfile << "[" << c10::Join(",", trace_generator->getCommsTrace())
             << "\n]";
    _outfile.flush();
    _outfile.close();
  }
}

/* unused */
void CommTraceLogger::setCurBlock(const std::string& name) {
  curBlocks_.push_back(
      c10::str("\"", name, "\"")); // add quote marks for JSON format
}

/* unused */
void CommTraceLogger::popBlock() {
  // TODO: remove specific name
  curBlocks_.pop_back();
}

void CommTraceLogger::recordOptionalInfo(int root) {
  curRoot_ = root;
}

void CommTraceLogger::recordOptionalInfo(
    const std::vector<int64_t>& outputSplitSizes,
    const std::vector<int64_t>& inputSplitSizes) {
  curOutSplitSizes_ = outputSplitSizes;
  curInSplitSizes_ = inputSplitSizes;
}

void CommTraceLogger::recordComms(
    const std::string& commName,
    const uintptr_t workReq,
    const int rank,
    const int world_size,
    const std::vector<at::Tensor>& inputTensors,
    const std::vector<at::Tensor>& outputTensors) {
  auto inNelems = (!inputTensors.empty()) ? inputTensors[0].numel() : 0;
  auto outNelems = (!outputTensors.empty()) ? outputTensors[0].numel() : 0;
  auto dtype =
      (!outputTensors.empty()) ? outputTensors[0].scalar_type() : at::kByte;
  auto devType = (!outputTensors.empty()) ? outputTensors[0].device().type()
                                          : c10::DeviceType::CPU;
  auto now = std::chrono::system_clock::now();
  static auto startTS = now;
  int64_t time_since_begin =
      std::chrono::duration_cast<std::chrono::nanoseconds>(now - startTS)
          .count();

  // TODO: get markers from torch profiler if enabled

  // common fields for all operations
  std::string cur_trace_ = c10::str(
      "\n\t\t\"markers\": [",
      curBlocks_,
      "]",
      ",\n\t\t\"startTime_ns\": ",
      time_since_begin,
      ",\n\t\t\"comms\": \"",
      commName,
      "\"",
      ",\n\t\t\"req\": ",
      workReq,
      ",\n\t\t\"seqnum\": ",
      seqnum,
      ",\n\t\t\"world_size\": ",
      world_size);

  if (inNelems > 0 || outNelems > 0) {
    // for most collectives - append msg sizes, data type, device type
    cur_trace_ = c10::str(
        cur_trace_,
        ",\n\t\t\"in_msg_size\": ",
        inNelems,
        ",\n\t\t\"out_msg_size\": ",
        outNelems,
        ",\n\t\t\"dtype\": \"",
        at::toString(dtype),
        "\",\n\t\t\"devType\": \"",
        c10::DeviceTypeName(devType),
        "\"");
  }
  if (curRoot_ != -1) {
    // append root rank if applicable, e.g., broadcast, gather, scatter
    cur_trace_ = c10::str(cur_trace_, ",\n\t\t\"root\": ", curRoot_);
  }
  if (!curInSplitSizes_.empty() || !curOutSplitSizes_.empty()) {
    // append input and output splits if applicable, e.g., ALLTOALL_BASE
    cur_trace_ = c10::str(
        cur_trace_,
        ",\n\t\t\"in_split\": [",
        c10::Join(",", curInSplitSizes_),
        "]"
        ",\n\t\t\"out_split\": [",
        c10::Join(",", curOutSplitSizes_),
        "]");
  }
  comms_trace_.push_back(c10::str("\n\t{", cur_trace_, "\n\t}"));

  // record the trace to kineto trace if applicable
  RECORD_PARAM_COMMS(
      static_cast<int64_t>(seqnum), // seq
      std::make_tuple("0", ""), // pg_name tuple
      rank,
      commName.c_str(),
      inNelems,
      outNelems,
      dtype,
      curInSplitSizes_,
      curOutSplitSizes_,
      -1,
      -1,
      world_size);

  ++seqnum;

  // reset optional field
  curRoot_ = -1;
  curInSplitSizes_ = {};
  curOutSplitSizes_ = {};
}

} // namespace c10d

#endif // USE_C10D_UCC
