#include <torch/csrc/autograd/profiler_python.h>

#include <atomic>
#include <cstdint>
#include <deque>
#include <limits>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include <Python.h>
#include <frameobject.h>

#include <ATen/core/TensorBase.h>
#include <c10/macros/Macros.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/containers.h>
#include <torch/csrc/profiler/orchestration/python_tracer.h>
#include <torch/csrc/profiler/util.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_strings.h>
#include <optional>

namespace py = pybind11;

namespace torch::profiler::impl {
namespace {
enum CallType { PyCall = 0, PyModuleCall, PyCCall, PyOptimizerCall };
static constexpr size_t CallTypeSize = 4;
using no_ephemeral_t = std::tuple<>;
static constexpr uint64_t NoTID = std::numeric_limits<uint64_t>::max();

// ============================================================================
// == Miscellaneous structs and utils =========================================
// ============================================================================
struct CodeLocation {
  CodeLocation() = default;
  explicit CodeLocation(PyFrameObject* frame)
      : line_number_{PyFrame_GetLineNumber(frame)} {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    filename_ = THPUtils_unpackStringView(code->co_filename).data();
    name_ = THPUtils_unpackStringView(code->co_name).data();
  }

  bool operator==(const CodeLocation& other) const {
    return filename_ == other.filename_ && name_ == other.name_ &&
        line_number_ == other.line_number_;
  }

  const char* filename_{nullptr};
  const char* name_{nullptr};
  int line_number_{0};
};

template <CallType C>
PyCodeObject* getCode();

template <>
PyCodeObject* getCode<CallType::PyModuleCall>() {
  static auto module_call_code = []() {
    pybind11::gil_scoped_acquire gil;
    auto res = py::module::import("torch.nn")
                   .attr("Module")
                   .attr("__call__")
                   .attr("__code__")
                   .ptr();
    TORCH_INTERNAL_ASSERT(PyCode_Check(res));
    return (PyCodeObject*)res;
  }();
  return module_call_code;
};

template <>
PyCodeObject* getCode<CallType::PyOptimizerCall>() {
  static auto optimizer_step_code = []() {
    pybind11::gil_scoped_acquire gil;
    auto res = py::module::import("torch.optim")
                   .attr("Optimizer")
                   .attr("_optimizer_step_code")
                   .attr("__code__")
                   .ptr();
    TORCH_INTERNAL_ASSERT(PyCode_Check(res));
    return (PyCodeObject*)res;
  }();
  return optimizer_step_code;
};

} // namespace
} // namespace torch::profiler::impl

template <>
struct std::hash<torch::profiler::impl::CodeLocation> {
  size_t operator()(const torch::profiler::impl::CodeLocation& x) {
    return c10::get_hash(x.filename_, x.name_, x.line_number_);
  }
};

namespace torch::profiler::impl {
namespace {
// ============================================================================
// == CallTypeHelper: Tools for generic programming on specializations. =======
// ============================================================================
template <template <CallType> class ClassT>
class CallTypeHelper final {
 private:
  static_assert(
      CallType::PyCall == 0,
      "CallTypeHelper uses integer math which depends on a zero start.");
  static constexpr size_t End = CallTypeSize;

  template <size_t... I>
  static constexpr std::tuple<ClassT<(CallType)I>...> make_tuple_impl(
      std::index_sequence<I...>);

  template <size_t C, typename T, typename FunctorT, typename... Args>
  static void map(T& t, FunctorT& f, Args&&... args) {
    f(std::get<C>(t), args...);
    if constexpr (C + 1 < End) {
      map<C + 1>(t, f, std::forward<Args>(args)...);
    }
  }

 public:
  using tuple_type = decltype(make_tuple_impl(std::make_index_sequence<End>{}));

  template <typename FunctorT, typename... Args>
  static void map(tuple_type& t, FunctorT& f, Args&&... args) {
    map<0>(t, f, std::forward<Args>(args)...);
  }
};

// ============================================================================
// == Event type definitions. =================================================
// ============================================================================
// When we are tracing a Python program, the general procedure is to record
// every time we enter or exit a function and later replay these events during
// post processing. Thus, during the profiling phase we want to do the MINIMAL
// amount of work to capture all of the information that we need; otherwise we
// will distort the profile. (While we don't wish to be terribly inefficient
// during post processing, we are willing to do extra fixup work in post if it
// reduces overhead in the profiling phase.)
//
// When the tracer first enters a frame, it constructs a CallKey for that
// location. The contents of the key vary by context. For a python function
// the key is the (PyCodeObject*, int) pair that defines the bytecode of the
// function. For an `nn.Module` the key is a (non-owning) pointer to `self`.
// For a bound C function it is a (non-owning) pointer to the bound function.
// A CallKey should be small, inexpensive, and POD.
//
// We then collect a CallKey<CallType::PyCall> for the calling frame for better
// source tracking. This pair is a `Callsite`, and serves as a first level key
// during tracing. We lookup the Callsite in a thread local cache which maps
// Callsite to a unique integer `TraceKey`. On a cache hit, we simply store the
// TraceKey and return. On a cache miss, we use a global value cache to store
// whatever fields we need from the two CallKeys, generate a new TraceKey, and
// update the local cache.
//
// During post processing we:
//   1) Determine the type represented by a TraceKey by checking which
//      sub-cache it appears in in the thread local cache.
//   2) Look up the pair of CallKeys from the thread local cache.
//   3) Look up the expanded values of each CallKey from the global value cache.
//
// To add a new event type to the cache:
//   1) Add an entry to the `CallType` enum.
//   2) Add a specialization of Config which defined key_t, ephemeral_t and
//      cache_t.
//   3) Add a specialization of ValueCache::store and ValueCache::load.
//
// -------------------------
// -- Ephemeral arguments --
// -------------------------
// The value cache mechanism assumes that `key_t` is enough to specify the
// correct value. However it may not be possible to materialize a value using
// only an instance of `key_t`. As a result, the cache also accepts "ephemeral"
// inputs which can be used to populate the value cache. Ephemeral inputs come
// with two caveats:
//  1) They are NOT safe to save, and cannot be used after `ValueCache::store`.
//  2) They should be used to access data that is not expect to change from
//     call to call, such as the name of a function.

template <CallType>
struct Config;

template <>
struct Config<CallType::PyCall> {
  using key_t = CodeLocation;
  using ephemeral_t = no_ephemeral_t;
  using cache_t = ska::flat_hash_map<key_t, PyFrameState>;
  static constexpr EventType event_type = EventType::PyCall;
};

template <typename Key, typename Cls, typename ParameterInfo>
struct ExtendedPyCallConfig {
  using key_t = Key;
  using cls_t = Cls;
  using ephemeral_t = PyFrameObject*;

  struct ClsAndParameters {
    cls_t cls_;
    std::vector<ParameterInfo> parameters_;
  };

  struct Cache {
    // `nn.Module.forward` or `optim.Optimizer._optimizer_step_code`
    std::optional<CodeLocation> location_;
    ska::flat_hash_map<key_t, ClsAndParameters> cls_and_parameters_;
    ska::flat_hash_map<cls_t, at::StringView> cls_names_;
  };
  using cache_t = Cache;

  static constexpr EventType event_type = EventType::PyCall;
};

template <>
struct Config<CallType::PyModuleCall> : ExtendedPyCallConfig<
                                            PyModuleSelf,
                                            PyModuleCls,
                                            NNModuleInfo::ParameterInfo> {};

template <>
struct Config<CallType::PyOptimizerCall> : ExtendedPyCallConfig<
                                               PyOptimizerSelf,
                                               PyOptimizerCls,
                                               OptimizerInfo::ParameterInfo> {};

template <>
struct Config<CallType::PyCCall> {
  using key_t = PyMethod;
  using ephemeral_t = PyObject*;
  using cache_t = ska::flat_hash_map<key_t, at::StringView>;
  static constexpr EventType event_type = EventType::PyCCall;
};

// ============================================================================
// == Callsite & ValueCache: Storage during profiling =========================
// ============================================================================
template <CallType C>
class Callsite {
 public:
  static constexpr CallType call_type = C;
  using key_t = typename Config<C>::key_t;

  static_assert(
      std::is_trivially_copyable_v<key_t>,
      "Key should be trivial, as it is passed by value.");

  template <typename U>
  Callsite(U value, PyFrameObject* f_back) : value_(value), caller_(f_back) {}

  bool operator==(const Callsite<C>& other) const {
    return value_ == other.value_ && caller_ == other.caller_;
  }

  key_t value_;
  Config<CallType::PyCall>::key_t caller_;
};

// ============================================================================
// == Type specific store and load implementations. ===========================
// ============================================================================
using PyCallKey = Config<CallType::PyCall>::key_t;
using PyModuleCallKey = Config<CallType::PyModuleCall>::key_t;
using PyCCallKey = Config<CallType::PyCCall>::key_t;
using PyOptimizerCallKey = Config<CallType::PyOptimizerCall>::key_t;

class ValueCache {
 public:
  ValueCache() = default;
  ValueCache(const ValueCache&) = delete;

  template <CallType C>
  void store(const typename Config<C>::key_t&, typename Config<C>::ephemeral_t);

  template <CallType C>
  auto load(const Callsite<C>& callsite, size_t python_tid) const {
    auto caller = load<CallType::PyCall>(callsite.caller_);
    TORCH_INTERNAL_ASSERT(!caller.module_info_.has_value());
    return ExtraFields<Config<C>::event_type>{
        /*end_time_ns=*/std::numeric_limits<c10::time_t>::min(),
        python_tid,
        caller.frame_state_,
        load<C>(callsite.value_)};
  }

  std::optional<TensorMetadata> recordIfTensor(py::handle p);
  std::vector<std::pair<std::string, TensorMetadata>> unpackTensorMap(
      const py::dict& tensor_map);
  void trimPrefixes();

 private:
  template <CallType C>
  typename ExtraFields<Config<C>::event_type>::args_t load(
      const typename Config<C>::key_t&) const;

  template <CallType C>
  using State = typename Config<C>::cache_t;

  CallTypeHelper<State>::tuple_type state_;
};

template <CallType C>
typename Config<C>::cls_t set_class(
    ValueCache* value_cache,
    typename Config<C>::cache_t& cache,
    const typename Config<C>::key_t& key,
    const typename Config<C>::ephemeral_t& frame) {
  if (C10_UNLIKELY(!cache.location_.has_value())) {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    TORCH_INTERNAL_ASSERT(code.get() == getCode<C>());
    cache.location_ = PyCallKey(frame);
    value_cache->store<CallType::PyCall>(*cache.location_, no_ephemeral_t());
  }

  auto cls_handle = py::handle((PyObject*)key).attr("__class__");
  auto cls = typename Config<C>::cls_t(cls_handle.ptr());
  if (cache.cls_names_.find(cls) == cache.cls_names_.end()) {
    cache.cls_names_[cls] =
        at::StringView(py::str(cls_handle.attr("__name__")));
  }
  return cls;
}

TensorMetadata toTensorMetadata(PyObject* self) {
  TORCH_INTERNAL_ASSERT(THPVariable_CheckExact(self));
  const auto& t = THPVariable_Unpack(self);
  RawTensorMetadata m{t};
  return TensorMetadata{
      m,
      t.sizes().vec(),
      m.layout_ == at::kStrided ? t.strides().vec() : std::vector<int64_t>()};
}

std::optional<TensorMetadata> ValueCache::recordIfTensor(py::handle p) {
  return THPVariable_CheckExact(p.ptr())
      ? std::optional<TensorMetadata>{toTensorMetadata(p.ptr())}
      : std::nullopt;
}

std::vector<std::pair<std::string, TensorMetadata>> ValueCache::unpackTensorMap(
    const py::dict& tensor_map) {
  std::vector<std::pair<std::string, TensorMetadata>> out;
  for (auto& it : tensor_map) {
    auto* value = it.second.ptr();
    if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(value)) {
      out.emplace_back(
          py::cast<std::string>(it.first), toTensorMetadata(value));
    }
  }
  return out;
}

template <>
void ValueCache::store<CallType::PyCall>(const PyCallKey& key, no_ephemeral_t) {
  auto& locations = std::get<CallType::PyCall>(state_);
  if (C10_UNLIKELY(locations.find(key) == locations.end())) {
    locations[key] = {
        key.line_number_,
        at::StringView(key.filename_),
        at::StringView(key.name_)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyCall>(
    const PyCallKey& key) const {
  return {std::get<CallType::PyCall>(state_).at(key), std::nullopt};
}

template <>
void ValueCache::store<CallType::PyModuleCall>(
    const PyModuleCallKey& key,
    Config<CallType::PyModuleCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyModuleCall>(state_);
  if (C10_UNLIKELY(
          cache.cls_and_parameters_.find(key) ==
          cache.cls_and_parameters_.end())) {
    auto cls = set_class<CallType::PyModuleCall>(this, cache, key, frame);

    py::dict params = py::handle((PyObject*)key).attr("_parameters");
    std::vector<NNModuleInfo::ParameterInfo> params_;
    for (auto& it : params) {
      auto* p = it.second.ptr();
      if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(p)) {
        params_.push_back(
            {it.first.cast<std::string>(),
             toTensorMetadata(p),
             recordIfTensor(py::getattr(it.second, "grad", py::none()))});
      }
    }
    cache.cls_and_parameters_[key] = {cls, std::move(params_)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyModuleCall>(
    const PyModuleCallKey& key) const {
  auto& cache = std::get<CallType::PyModuleCall>(state_);
  TORCH_INTERNAL_ASSERT(cache.location_.has_value());
  const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
  const auto& cls = cls_and_parameters.cls_;
  NNModuleInfo info{
      key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
  return {
      /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
      /*module_info_=*/std::move(info),
      /*optimizer_info_=*/std::nullopt};
}

template <>
void ValueCache::store<CallType::PyOptimizerCall>(
    const PyOptimizerCallKey& key,
    Config<CallType::PyOptimizerCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  if (C10_UNLIKELY(
          cache.cls_and_parameters_.find(key) ==
          cache.cls_and_parameters_.end())) {
    auto cls = set_class<CallType::PyOptimizerCall>(this, cache, key, frame);
    const py::handle self{(PyObject*)key};
    std::vector<OptimizerInfo::ParameterInfo> params;

    for (const auto& i : (py::list)self.attr("param_groups")) {
      for (auto& param : py::cast<py::dict>(i).attr("get")("params")) {
        if (THPVariable_CheckExact(param.ptr())) {
          // While `self.state` is permitted to store data in an arbitrary way,
          // all generic optimizers (SGD, Adam, etc) use param as the key since
          // the state in question is tied to particular parameters. We can
          // relax this assumption if the need arises.
          params.push_back(
              {toTensorMetadata(param.ptr()),
               recordIfTensor(py::getattr(param, "grad", py::none())),
               unpackTensorMap(py::cast<py::dict>(self.attr("state"))
                                   .attr("get")(param, py::dict()))});
        }
      }
    }

    cache.cls_and_parameters_[key] = {cls, std::move(params)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<
    CallType::PyOptimizerCall>(const PyOptimizerCallKey& key) const {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
  auto cls = cls_and_parameters.cls_;
  OptimizerInfo info{
      key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
  return {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
      /*module_info_=*/std::nullopt,
      /*optimizer_info_=*/std::move(info)};
}

template <>
void ValueCache::store<CallType::PyCCall>(
    const PyCCallKey& key,
    Config<CallType::PyCCall>::ephemeral_t arg) {
  auto& names = std::get<CallType::PyCCall>(state_);
  if (C10_UNLIKELY(names.find(key) == names.end())) {
    names[key] = at::StringView(py::repr(arg));
  }
}

template <>
ExtraFields<EventType::PyCCall>::args_t ValueCache::load<CallType::PyCCall>(
    const PyCCallKey& key) const {
  return std::get<CallType::PyCCall>(state_).at(key);
}

// TODO: Use re2.
void ValueCache::trimPrefixes() {
  static const auto prefixes = []() {
    pybind11::gil_scoped_acquire gil;
    return py::module::import("torch.profiler.python_tracer")
        .attr("_prefix_regex")()
        .cast<std::vector<std::string>>();
  }();

  for (auto& it : std::get<CallType::PyCall>(state_)) {
    std::string filename = it.second.filename_.str();
    for (const auto& p : prefixes) {
      if (filename.compare(0, p.size(), p) == 0) {
        filename.erase(0, p.size());
        it.second.filename_ = at::StringView(filename);
        break;
      }
    }
  }
}

// ============================================================================
// == TraceKey cache ==========================================================
// ============================================================================
using python_tracer::TraceKey;

TraceKey nextKey() {
  static std::atomic<uint64_t> key{0};
  return TraceKey{++key};
}

template <CallType C>
struct TraceKeyCacheState {
  struct Hash {
    size_t operator()(const Callsite<C>& key) {
      return c10::get_hash(key.value_, key.caller_);
    }
  };

  TraceKey intern(
      Callsite<C> callsite,
      typename Config<C>::ephemeral_t ephemeral,
      ValueCache& value_cache) {
    auto it = state_.find(callsite);
    if (C10_UNLIKELY(it == state_.end())) {
      value_cache.store<C>(callsite.value_, ephemeral);
      value_cache.store<CallType::PyCall>(callsite.caller_, no_ephemeral_t());
      it = state_.insert({callsite, nextKey()}).first;
    }
    return it->second;
  }

  auto lookup(Callsite<C>& callsite, ValueCache& value_cache) const {
    return std::make_pair(
        value_cache.load<C>(callsite.value_),
        value_cache.load<CallType::PyCall>(callsite.caller_));
  }

  ska::flat_hash_map<Callsite<C>, TraceKey, Hash> state_;
};

// ============================================================================
// == Core CPython data types =================================================
// ============================================================================
// PyObject that allows different threads to record events without colliding.
// It is passed as the second argument when enabling tracing via
// `PyEval_SetProfile`.
struct ThreadLocalResults;
struct TraceContext {
  PyObject_HEAD;
  ThreadLocalResults* thread_local_results_;
};

// CPython boilerplate to define `TraceContext` as a proper python object.
static PyTypeObject TraceContextType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "TraceContext", /* tp_name */
    sizeof(TraceContext), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0,
    /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash  */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    "Python tracer TLS", /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    PyType_GenericNew, /* tp_new */
    nullptr /* tp_free */
};

class gil_and_restore_thread {
 public:
  gil_and_restore_thread()
      : gil_(), initial_thread_state_{PyThreadState_Get()} {}
  ~gil_and_restore_thread() {
    PyThreadState_Swap(initial_thread_state_);

    // `gil_scoped_acquire` is a bit fragile in on-demand mode:
    // https://github.com/pytorch/pytorch/pull/91684#issuecomment-1413154458
    if (!Py_IsInitialized()) {
      gil_.disarm();
    }
  }

  PyThreadState* initial_thread_state() const {
    return initial_thread_state_;
  }

 private:
  pybind11::gil_scoped_acquire gil_;
  PyThreadState* initial_thread_state_;
};

// ============================================================================
// == Thread local cache ======================================================
// ============================================================================
class PythonTracer;
struct ThreadLocalResults {
  ThreadLocalResults(
      PyThreadState* thread_state,
      ValueCache* value_cache,
      PythonTracer* active_tracer)
      : thread_state_{thread_state},
        ctx_{(TraceContext*)TraceContextType.tp_alloc(&TraceContextType, 0)},
        value_cache_{value_cache},
        active_tracer_{active_tracer} {
    ctx_->thread_local_results_ = this;
  }

  ThreadLocalResults() = delete;
  ThreadLocalResults(const ThreadLocalResults&) = delete;
  ThreadLocalResults(ThreadLocalResults&&) = delete;
  ThreadLocalResults& operator=(const ThreadLocalResults&) = delete;
  ThreadLocalResults& operator=(const ThreadLocalResults&&) = delete;

  ~ThreadLocalResults() {
    Py_DECREF((PyObject*)ctx_);
  }

  template <CallType C, EventType E, typename Ephemeral, typename... Args>
  TraceKey intern(Ephemeral ephemeral, Args... args) {
    static_assert(
        Config<C>::event_type == E,
        "ThreadLocalResults.intern called from the wrong typed context.");
    auto callsite = Callsite<C>(std::forward<Args>(args)...);
    return std::get<C>(trace_keys_).intern(callsite, ephemeral, *value_cache_);
  }

  static constexpr size_t BLOCK_SIZE = 1024;

  PyThreadState* thread_state_;
  TraceContext* ctx_;
  ValueCache* value_cache_;
  PythonTracer* active_tracer_;
  CallTypeHelper<TraceKeyCacheState>::tuple_type trace_keys_;
  AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> exit_times_;
  AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> c_exit_times_;
};

// ============================================================================
// == Tracing implementation ==================================================
// ============================================================================
class PythonTracer final : public python_tracer::PythonTracerBase {
 public:
  PythonTracer(torch::profiler::impl::RecordQueue* queue);
  // NOLINTNEXTLINE(bugprone-exception-escape)
  ~PythonTracer() override;

  static int pyProfileFn(
      PyObject* obj,
      PyFrameObject* frame,
      int what,
      PyObject* arg);

  void stop() override;
  void restart() override;
  std::vector<std::shared_ptr<Result>> getEvents(
      std::function<c10::time_t(c10::approx_time_t)> time_converter,
      std::vector<python_tracer::CompressedEvent>& enters,
      c10::time_t end_time_ns) override;

  struct StartFrame {
    TraceKey trace_key_;
    c10::approx_time_t start_time{};
  };

 private:
  void recordPyCall(
      ThreadLocalResults& tls,
      PyFrameObject* frame,
      bool is_startup_frame);

  void recordCCall(
      ThreadLocalResults& tls,
      PyFrameObject* frame,
      PyObject* arg);

  const std::vector<PyThreadState*> interpreterThreads() const;

  std::atomic<bool> active_lock_{false};
  bool active_{false};

  torch::profiler::impl::RecordQueue* queue_;
  PyInterpreterState* interpreter_{nullptr};
  PyCodeObject* module_call_code_;
  PyCodeObject* optimizer_hook_;

  std::vector<StartFrame> start_frames_;
  std::deque<ThreadLocalResults> thread_local_results_;
  ValueCache value_cache_;
};

const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
  pybind11::gil_scoped_acquire gil;
  std::vector<PyThreadState*> out;
  if (SOFT_ASSERT(interpreter_)) {
    auto* thread_state = PyInterpreterState_ThreadHead(interpreter_);
    while (thread_state != nullptr) {
      out.push_back(thread_state);
      thread_state = PyThreadState_Next(thread_state);
    }
  }
  return out;
}

PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
    : queue_(queue),

      module_call_code_(getCode<CallType::PyModuleCall>()),
      optimizer_hook_(getCode<CallType::PyOptimizerCall>()) {
  TORCH_CHECK(queue_ != nullptr);

  bool expected{false};
  active_ = active_lock_.compare_exchange_strong(expected, true);
  if (!active_) {
    TORCH_WARN(
        "There is already an active Python tracer. "
        "Refusing to register profile functions.");
    return;
  }

  gil_and_restore_thread gil;
  interpreter_ = PyInterpreterState_Get();

  if (!gil.initial_thread_state()) {
    TORCH_WARN("PyThreadState_Get returned NULL");
    return;
  }

  // Register the tracer in each thread.
  for (const auto thread_state : interpreterThreads()) {
    PyThreadState_Swap(thread_state);

    thread_local_results_.emplace_back(thread_state, &value_cache_, this);
    auto* ctx = thread_local_results_.back().ctx_;

    // When we begin profiling there are already frames on the Python
    // interpreter stack. To ensure a complete trace, we must push calls
    // to all the prior frames onto our event stack. (We stop at depth=128)

    std::vector<THPFrameObjectPtr> current_stack;
    auto frame = PyEval_GetFrame();
    Py_XINCREF(frame);

    size_t depth = 0; // Make sure we can't infinite loop.
    while (frame != nullptr) {
      current_stack.emplace_back(frame);
      if (++depth == 128) {
        break;
      }

      // NB: `PyFrame_GetBack` returns a strong reference.
      frame = PyFrame_GetBack(frame);
    }

    for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
      recordPyCall(thread_local_results_.back(), it->get(), true);
      auto frame_refcount = Py_REFCNT(it->get());

      // We hold one reference in `current_stack`, and the interpreter holds
      // another.
      TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount);
    }

    // Note:
    //   This profile will not compose with other CPython profilers, and
    //   cannot be round tripped via `sys.settrace(sys.gettrace())`
    PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
  }
};

void PythonTracer::stop() {
  gil_and_restore_thread gil;
  if (active_) {
    for (const auto thread_state : interpreterThreads()) {
      if (thread_state->c_profilefunc == &PythonTracer::pyProfileFn) {
        PyThreadState_Swap(thread_state);
        PyEval_SetProfile(nullptr, nullptr);
      }
    }

    auto lock_returned = active_lock_.compare_exchange_strong(active_, false);
    active_ = false;
    SOFT_ASSERT(lock_returned, "Failed to return python tracer lock.");
  }
}

void PythonTracer::restart() {
  gil_and_restore_thread gil;
  active_ = active_lock_.compare_exchange_strong(active_, true);
  if (!active_) {
    TORCH_WARN(
        "There is already an active Python tracer. "
        "Refusing to register profile functions.");
    return;
  }
  int cur_thread = 0;
  for (const auto thread_state : interpreterThreads()) {
    if (thread_state->c_profilefunc == nullptr) {
      auto* ctx = thread_local_results_[cur_thread].ctx_;
      PyThreadState_Swap(thread_state);
      PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
    }
  }
}

// NOLINTNEXTLINE(bugprone-exception-escape)
PythonTracer::~PythonTracer() {
  if (active_) {
    TORCH_WARN("`PythonTracer::stop()` was not called.");
    stop();
  }
}

void PythonTracer::recordPyCall(
    ThreadLocalResults& tls,
    PyFrameObject* frame,
    bool is_startup_frame) {
  static constexpr auto E = EventType::PyCall;
  const auto key = [&]() -> TraceKey {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    if (code.get() == module_call_code_) {
      // By default, CPython stores locals in a "fast" format, with an array
      // of names and an array of values. Consequently, frame->f_locals is
      // NULL since the interpreter has no need to populate it.
      //
      // If these arrays were part of the public API then we could very
      // quickly access `self`. Unfortunately they are not, and moreover are
      // not stable across versions. As a result, we are forced to call
      // `PyFrame_FastToLocals` which forces the interpreter to materialize
      // the full dict of locals.
      auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
      auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
      Py_INCREF(self.get());
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      TORCH_INTERNAL_ASSERT(back != nullptr);
      return tls.intern<CallType::PyModuleCall, E>(
          frame, self.get(), back.get());
    } else if (code.get() == optimizer_hook_) {
      auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
      auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
      Py_INCREF(self.get());
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      TORCH_INTERNAL_ASSERT(back != nullptr);
      return tls.intern<CallType::PyOptimizerCall, E>(
          frame, self.get(), back.get());
    } else {
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      auto f_back = (back.get() != nullptr) ? back.get() : frame;
      return tls.intern<CallType::PyCall, E>(no_ephemeral_t(), frame, f_back);
    }
  }();
  const auto time = c10::getApproximateTime();
  is_startup_frame ? start_frames_.push_back({key, time})
                   : queue_->getSubqueue()->emplace_py_call(key, time);
}

void PythonTracer::recordCCall(
    ThreadLocalResults& tls,
    PyFrameObject* frame,
    PyObject* arg) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyCFunction_Check(arg));
  auto fn = reinterpret_cast<PyCFunctionObject*>(arg);

  // NB: For C calls a new frame is not created, so we use `frame` rather than
  //     `frame->f_back`.
  auto key = tls.intern<CallType::PyCCall, EventType::PyCCall>(
      arg, (void*)(fn->m_ml), frame);
  queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime());
}

// ============================================================================
// == Post processing =========================================================
// ============================================================================
struct Exit {
  bool operator>(const Exit& other) const {
    return t_ > other.t_;
  }

  c10::time_t t_;
  size_t python_tid_;
};

class PostProcess {
 public:
  PostProcess(
      std::function<c10::time_t(c10::approx_time_t)> time_converter,
      std::deque<ThreadLocalResults>& tls,
      const ValueCache& value_cache,
      c10::time_t end_time_ns)
      : end_time_{end_time_ns}, time_converter_{std::move(time_converter)} {
    for (size_t python_tid : c10::irange(tls.size())) {
      CallTypeHelper<TraceKeyCacheState>::map(
          tls[python_tid].trace_keys_, *this, value_cache, python_tid);

      addExits<EventType::PyCall>(tls[python_tid].exit_times_, python_tid);
      addExits<EventType::PyCCall>(tls[python_tid].c_exit_times_, python_tid);
    }
  }

  void set_start_frames(
      const std::vector<PythonTracer::StartFrame>& start_frames,
      std::vector<python_tracer::CompressedEvent>& enters) {
    for (const auto& frame : start_frames) {
      enters.push_back(
          {frame.trace_key_,
           NoTID, // Allows us to detect unhandled start frames
           {},
           time_converter_(frame.start_time)});
    }
  }

  template <CallType C>
  void operator()(
      const TraceKeyCacheState<C>& trace_cache,
      const ValueCache& value_cache,
      size_t python_tid) {
    for (const auto& it : trace_cache.state_) {
      const auto inserted = get_state<Config<C>::event_type>().fields_.insert(
          {it.second, value_cache.load(it.first, python_tid)});
      TORCH_INTERNAL_ASSERT(inserted.second, "Duplicate key: ", it.second);
    }
  }

  template <EventType E, size_t N>
  void addExits(
      AppendOnlyList<c10::approx_time_t, N>& exits,
      size_t python_tid) {
    for (const auto i : exits) {
      get_state<E>().exits_.push({time_converter_(i), python_tid});
    }
  }

  std::vector<std::shared_ptr<Result>> run(
      std::vector<python_tracer::CompressedEvent>& enters) {
    std::stable_sort(
        enters.begin(), enters.end(), [](const auto a, const auto b) {
          return a.enter_t_ < b.enter_t_;
        });
    std::vector<std::shared_ptr<Result>> out;
    populate<EventType::PyCall>(enters, out);
    populate<EventType::PyCCall>(enters, out);
    return out;
  }

 private:
  template <EventType E>
  void populate(
      std::vector<python_tracer::CompressedEvent>& enters,
      std::vector<std::shared_ptr<Result>>& out) {
    using stack_t = std::vector<std::shared_ptr<Result>>;
    const auto initial_size = out.size();
    auto pop = [](stack_t& stack, c10::time_t t) {
      TORCH_INTERNAL_ASSERT(!stack.empty(), "Python replay stack is empty.");
      std::get<ExtraFields<E>>(stack.back()->extra_fields_).end_time_ns_ = t;
      stack.pop_back();
    };

    ska::flat_hash_map<size_t, stack_t> stacks;
    auto& state = get_state<E>();
    for (const auto& enter : enters) {
      auto fields_it = state.fields_.find(enter.key_);
      if (fields_it != state.fields_.end()) {
        while (!state.exits_.empty() &&
               state.exits_.top().t_ < enter.enter_t_) {
          auto& exit = state.exits_.top();
          pop(stacks[exit.python_tid_], exit.t_);
          state.exits_.pop();
        }
        out.push_back(Result::create(
            enter.enter_t_,
            enter.system_tid_,
            enter.kineto_info_,
            fields_it->second));

        stacks[fields_it->second.python_tid_].push_back(out.back());
      }
    }

    // Handle events which were still running when profiling ended.
    for (auto& i : stacks) {
      while (!i.second.empty()) {
        pop(i.second, end_time_);
      }
    }

    // Assign system TIDs to start events based on the system TID of the next
    // observed event with the same Python TID.
    ska::flat_hash_map<size_t, std::pair<size_t, kineto::DeviceAndResource>>
        tid_map;
    auto it = out.rbegin();
    for (C10_UNUSED auto _ : c10::irange(initial_size, out.size())) {
      const auto python_tid =
          std::get<ExtraFields<E>>((*it)->extra_fields_).python_tid_;
      if ((*it)->start_tid_ == NoTID && SOFT_ASSERT(E == EventType::PyCall)) {
        const auto& tid_info =
            tid_map.insert({python_tid, {NoTID, kineto::DeviceAndResource()}})
                .first->second;
        (*it)->start_tid_ = tid_info.first;
        (*it)->kineto_info_ = tid_info.second;
      }
      tid_map[python_tid] = {(*it)->start_tid_, (*it)->kineto_info_};
      ++it;
    }
  }

  template <EventType E>
  struct State {
    ska::flat_hash_map<TraceKey, ExtraFields<E>> fields_;
    std::priority_queue<Exit, std::vector<Exit>, std::greater<>> exits_;
  };

  template <EventType E>
  auto& get_state() {
    return std::get < E == EventType::PyCall ? 0 : 1 > (state_);
  }

  c10::time_t end_time_;
  std::function<c10::time_t(c10::approx_time_t)> time_converter_;
  std::tuple<State<EventType::PyCall>, State<EventType::PyCCall>> state_;
};

struct PythonIDVisitor {
  void operator()(ExtraFields<EventType::PyCall>& py_call) {
    py_call.id_ = ++current_python_id_;
    if (py_call.module_.has_value()) {
      auto& m = py_call.module_;
      auto& module_ids = module_ids_[m->cls_];
      m->id_ = module_ids.insert({m->self_, module_ids.size()}).first->second;
    }
  }

  void operator()(ExtraFields<EventType::PyCCall>& py_call) {
    py_call.id_ = ++current_python_id_;
  }

  template <typename T>
  void operator()(T&) {}

  size_t current_python_id_{0};
  ska::flat_hash_map<PyModuleCls, ska::flat_hash_map<PyModuleSelf, size_t>>
      module_ids_;
};

std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
    std::function<c10::time_t(c10::approx_time_t)> time_converter,
    std::vector<python_tracer::CompressedEvent>& enters,
    c10::time_t end_time_ns) {
  value_cache_.trimPrefixes();
  PostProcess post_process(
      std::move(time_converter),
      thread_local_results_,
      value_cache_,
      end_time_ns);
  post_process.set_start_frames(start_frames_, enters);
  auto out = post_process.run(enters);

  std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
    return a->start_time_ns_ < b->start_time_ns_;
  });

  PythonIDVisitor id_visitor;
  for (auto& i : out) {
    std::visit(id_visitor, i->extra_fields_);
  }

  return out;
}

// ============================================================================
// == API =====================================================================
// ============================================================================
int PythonTracer::pyProfileFn(
    PyObject* obj,
    PyFrameObject* frame,
    int what,
    PyObject* arg) {
  auto& local_results =
      *reinterpret_cast<TraceContext*>(obj)->thread_local_results_;
  switch (what) {
    case PyTrace_CALL:
      local_results.active_tracer_->recordPyCall(local_results, frame, false);
      break;

    case PyTrace_C_CALL:
      local_results.active_tracer_->recordCCall(local_results, frame, arg);
      break;

    case PyTrace_EXCEPTION:
    case PyTrace_RETURN:
      local_results.exit_times_.emplace_back(c10::getApproximateTime());
      break;

    case PyTrace_C_EXCEPTION:
    case PyTrace_C_RETURN:
      local_results.c_exit_times_.emplace_back(c10::getApproximateTime());
      break;
  }
  return 0;
}

std::unique_ptr<python_tracer::PythonTracerBase> getTracer(
    torch::profiler::impl::RecordQueue* queue) {
  return std::make_unique<PythonTracer>(queue);
}
} // namespace
} // namespace torch::profiler::impl

namespace torch::autograd::profiler::python_tracer {

void init() {
  pybind11::gil_scoped_acquire gil;
  TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0);
  torch::profiler::impl::python_tracer::registerTracer(
      &torch::profiler::impl::getTracer);
}
} // namespace torch::autograd::profiler::python_tracer
