#include <torch/csrc/profiler/standalone/nvtx_observer.h>

#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>

namespace torch::profiler::impl {

// Profiler state used while the NVTX observer is active on a thread.
// In addition to identifying the active profiler as NVTX, it remembers which
// op (and which of that op's outputs) produced each tensor it has been told
// about, so that the producer of an input tensor can be looked up later.
struct NVTXThreadLocalState : ProfilerStateBase {
  explicit NVTXThreadLocalState(const ProfilerConfig& config)
      : ProfilerStateBase(config) {
    // `report_input_shapes` is the only optional feature that is meaningful
    // for NVTX ranges; any config asking for one of the others is rejected.
    TORCH_CHECK(!config.profile_memory);
    TORCH_CHECK(!config.with_stack);
    TORCH_CHECK(!config.with_flops);
    TORCH_CHECK(!config.with_modules);
  }
  ~NVTXThreadLocalState() override = default;

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::NVTX;
  }

  // Intentionally a no-op: memory profiling is rejected by the constructor.
  void reportMemoryUsage(void*, int64_t, size_t, size_t, c10::Device) override {
  }

  // Returns the non-global profiler state downcast to the NVTX state, or
  // nullptr when no profiler state is set. Debug builds additionally verify
  // that a non-null state really is the NVTX one before the downcast.
  static NVTXThreadLocalState* getTLS() {
    auto* const state = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        state == nullptr ||
        state->profilerType() == ActiveProfilerType::NVTX);
    return static_cast<NVTXThreadLocalState*>(state);
  }

  // Looks up the {op id, output index} recorded for the producer of `tensor`.
  std::pair<at::RecordFunctionHandle, int> getOpIdFromInput(
      const at::Tensor& tensor);

  // Records (or overwrites) `tensor` as output number `output_nr` of the op
  // identified by `op_id`.
  void setProducerTensorMap(
      at::TensorImpl* tensor,
      at::RecordFunctionHandle op_id,
      int output_nr) {
    producer_tensor_map_[static_cast<void*>(tensor)] =
        std::make_pair(op_id, output_nr);
  }

 protected:
  // Output tensor address -> {unique id of the producing op, index of the
  // tensor among that op's outputs}.
  // The keys are really at::TensorImpl*, but they are stored as void* to make
  // it explicit that the pointer only serves as a lookup key.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unordered_map<void*, std::pair<at::RecordFunctionHandle, int>>
      producer_tensor_map_;
};

// Looks up which op produced `tensor`.
//
// Returns the {op handle, output index} pair stored in producer_tensor_map_
// under the tensor's TensorImpl address, or {0, -1} when the tensor is
// undefined or its address is not present in the map.
std::pair<at::RecordFunctionHandle, int> NVTXThreadLocalState::getOpIdFromInput(
    const at::Tensor& tensor) {
  if (tensor.defined()) {
    // One find() does the whole lookup: the key is hashed once, and unlike
    // operator[] it can never insert into the map.
    const auto entry = producer_tensor_map_.find(
        static_cast<void*>(tensor.unsafeGetTensorImpl()));
    if (entry != producer_tensor_map_.end()) {
      return entry->second;
    }
  }
  // Sentinel meaning "no known producer".
  return std::pair<at::RecordFunctionHandle, int>(0, -1);
}

// Collects the producer {op id, output index} pair of every tensor element
// of `list`, in order. Non-tensor elements contribute nothing.
// Requires the NVTX profiler state to be set on the current thread.
static std::list<std::pair<at::RecordFunctionHandle, int>> flattenOpIdList(
    const c10::List<c10::IValue>& list) {
  auto* const nvtx_state = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(nvtx_state, "Expected profiler state set");

  std::list<std::pair<at::RecordFunctionHandle, int>> producer_ids;
  for (const c10::IValue& element : list) {
    if (!element.isTensor()) {
      continue;
    }
    producer_ids.push_back(nvtx_state->getOpIdFromInput(element.toTensor()));
  }
  return producer_ids;
}

// Builds the list of producer {op id, output index} pairs for the inputs of
// `fn`:
//   - a tensor input contributes the pair looked up for that tensor;
//   - a list input contributes one pair per tensor element, or a single
//     {0, -1} placeholder when it yields none;
//   - any other input contributes a single {0, -1} placeholder.
// Requires the NVTX profiler state to be set on the current thread.
static std::list<std::pair<at::RecordFunctionHandle, int>> getInputTensorOpIds(
    const at::RecordFunction& fn) {
  const std::pair<at::RecordFunctionHandle, int> no_producer(0, -1);
  std::list<std::pair<at::RecordFunctionHandle, int>> producer_ids;

  auto* const nvtx_state = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(nvtx_state, "Expected profiler state set");

  for (const c10::IValue& arg : fn.inputs()) {
    if (arg.isTensor()) {
      producer_ids.push_back(nvtx_state->getOpIdFromInput(arg.toTensor()));
    } else if (arg.isList()) {
      std::list<std::pair<at::RecordFunctionHandle, int>> element_ids =
          flattenOpIdList(arg.toList());
      if (element_ids.empty()) {
        producer_ids.emplace_back(no_producer);
      } else {
        // Move the per-element ids onto the end of the result.
        producer_ids.splice(producer_ids.end(), element_ids);
      }
    } else {
      producer_ids.emplace_back(no_producer);
    }
  }
  return producer_ids;
}

// Registers every defined tensor output of `fn` in the NVTX state's producer
// map, keyed by the tensor's TensorImpl address and tagged with fn's handle
// and the output's position. Positions count all outputs, not only tensors.
// Requires the NVTX profiler state to be set on the current thread.
static void updateOutputTensorTracker(const at::RecordFunction& fn) {
  auto* const nvtx_state = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(nvtx_state, "Expected profiler state set");

  int next_index = 0;
  for (const c10::IValue& output : fn.outputs()) {
    // Claim this output's index up front so that skipped outputs still
    // advance the counter.
    const int output_index = next_index++;
    if (!output.isTensor()) {
      continue;
    }
    const at::Tensor& produced = output.toTensor();
    if (!produced.defined()) {
      continue;
    }
    nvtx_state->setProducerTensorMap(
        produced.unsafeGetTensorImpl(), fn.handle(), output_index);
  }
}

// RecordFunction start callback: opens an NVTX range describing `fn`.
//
// The range label is built by getNvtxStr() from fn's name, sequence number
// and handle. When `report_input_shapes` is true the label additionally
// carries the (flattened) input sizes and the producer op ids of the inputs;
// otherwise both are passed as empty containers.
//
// Does nothing when no NVTX profiler state is set. Always returns nullptr,
// i.e. no observer context is created.
template <bool report_input_shapes>
std::unique_ptr<at::ObserverContext> enterNVTX(const at::RecordFunction& fn) {
  if (NVTXThreadLocalState::getTLS() == nullptr) {
    return nullptr;
  }

  std::list<std::pair<at::RecordFunctionHandle, int>> input_op_ids;
  std::vector<std::vector<int64_t>> input_shapes;
  // Only gather input details in the instantiation that reports them; the
  // other instantiation skips the walk over fn's inputs entirely instead of
  // computing a list that would be thrown away.
  if constexpr (report_input_shapes) {
    input_op_ids = getInputTensorOpIds(fn);
    input_shapes = torch::profiler::impl::inputSizes(fn, true);
  }

  // The temporary label string lives until the end of this full expression,
  // which covers the rangePush() call.
  torch::profiler::impl::cudaStubs()->rangePush(
      torch::profiler::impl::getNvtxStr(
          fn.name(), fn.seqNr(), input_shapes, fn.handle(), input_op_ids)
          .c_str());
  return nullptr;
}

void pushNVTXCallbacks(
    const ProfilerConfig& config,
    const std::unordered_set<at::RecordScope>& scopes) {
  TORCH_CHECK(
      torch::profiler::impl::cudaStubs()->enabled(),
      "Can't use NVTX profiler - PyTorch was compiled without CUDA");

  c10::ThreadLocalDebugInfo::_push(
      c10::DebugInfoKind::PROFILER_STATE,
      std::make_shared<NVTXThreadLocalState>(config));

  auto state_ptr = NVTXThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");

  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          state_ptr->config().report_input_shapes
              ? &enterNVTX</*report_input_shapes=*/true>
              : &enterNVTX</*report_input_shapes=*/false>,
          [](const at::RecordFunction& fn, at::ObserverContext* ctx) {
            torch::profiler::impl::cudaStubs()->rangePop();
            updateOutputTensorTracker(fn);
          })
          .needsInputs(config.report_input_shapes)
          .needsOutputs(config.report_input_shapes)
          .needsIds(true)
          .scopes(scopes));
  state_ptr->setCallbackHandle(handle);
}

} // namespace torch::profiler::impl
