#include <torch/csrc/autograd/profiler_legacy.h>

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <ATen/code_template.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

#include <fstream>
#include <mutex>
#include <string>
#include <vector>

#include <ATen/record_function.h>
#include <c10/core/Allocator.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <c10/util/irange.h>

#include <iostream>

namespace torch::autograd::profiler {

// We decompose the profiler logic into the following components:
//
// ThreadLocalDebugInfo:
//
// ThreadLocalDebugInfo is a thread local mapping from slots into
// the debug information structs.
// ThreadLocalDebugInfo is automatically propagated across thread
// boundaries, including the cases of:
//  - launching async jobs with at::launch
//  - executing JIT continuations
//  - moving from the forward threads into autograd (backward) threads
//
// Entries in ThreadLocalDebugInfo are managed by DebugInfoGuard
// which can be used to add or overwrite an entry in the thread local
// mapping. A corresponding entry is removed when the guard is destroyed,
// potentially revealing the previously set value for the same slot.
//
// For the async tasks, slots previously set in the main thread before
// launching of an async task are shared and visible in the async task.
//
// On the other hand, any adding or overwriting of the mapping by the
// async task is not visible to the main thread and any modification
// (including removal of the entries) in the main thread is not visible
// to the async task if it happens after launching the task.
//
// We use ThreadLocalDebugInfo (slot PROFILER_STATE) to store profiler config,
// as well as a list of events that happen during profiling.
// An instance of ThreadLocalDebugInfo is created each time we enter
// profiler (i.e. enter profiling context manager/call enableProfiler) and
// uniquely identifies a profiling run.
//
// We automatically propagate ThreadLocalDebugInfo into async tasks,
// as well as across JIT continuations and autograd thread, so all
// the operations that happen between profiling start and end
// (not necessarily within the same thread) are recorded.
// Unless the profiling slot is overwritten as in the case of nested
// profiling ranges (in this case events for the subrange are handled
// by the nested profiler)
//
// When we exit a profiling range (either by exiting profiling context
// manager or by calling disableProfiler), we remove the previously set
// profiling entry for the given thread local mapping, and consolidate
// events in the profiling result
//
//
// ThreadLocalState:
//
// ThreadLocalState takes a 'snapshot' of thread local variables
// using provided getters. It is used together with ThreadLocalStateGuard
// to transfer the snapshot across thread boundary and set the thread local
// values as in the parent task.
//
// Profiler uses ThreadLocalState to propagate profiler's thread local state.
// ThreadLocalState also automatically propagates profiler callbacks.
//
//
// at::RecordFunction and observers
//
// Profiler uses observers mechanism to add a pair of thread local callbacks
// that are executed on a number of predetermined ranges, including:
//  - c10/ATen ops
//  - TorchScript functions/methods
//  - user defined named ranges (see `record_function` python context manager)
//
// Profiler sets up a pair of callbacks that record profiling events and save
// them into the thread local profiler struct (ThreadLocalDebugInfo,
// PROFILER_STATE slot)
//
//
// Thus, the overall logic is:
//
// enableProfiler:
//  - checks that profiler is not enabled (otherwise throws)
//  - pushes new ThreadLocalDebugInfo (slot PROFILER_STATE) as the profiler
//    config for the current thread
//  - pushes profiling callbacks for the current thread
//
// disableProfiler:
//  - pops PROFILER_STATE slot from the current ThreadLocalDebugInfo and
//    consolidates events
//  - removes profiling callbacks
//
// ThreadLocalState:
//  - propagates ThreadLocalDebugInfo across threads
//  - propagates profiler callbacks across threads
//
// Profiler callbacks:
//  - get the current profiling state (PROFILER_STATE slot in
//    ThreadLocalDebugInfo)
//  - save profiling events into the profiling state
//

namespace {
using torch::profiler::impl::ActiveProfilerType;
using torch::profiler::impl::ProfilerStateBase;

// Profiler state used by the legacy profiler.
//
// On top of what ProfilerStateBase provides (config, callback handle), it
// keeps one RangeEventList per thread id plus an optional collection of event
// lists supplied through setOrAddRemoteProfiledEvents(). state_mutex_ is held
// whenever event_lists_map_ or remoteProfiledEvents_ is accessed.
struct ProfilerLegacyThreadLocalState : public ProfilerStateBase {
  // Starts with no remote events recorded.
  explicit ProfilerLegacyThreadLocalState(
      const torch::profiler::impl::ProfilerConfig& config)
      : ProfilerStateBase(config), remoteProfiledEvents_{std::nullopt} {}
  ~ProfilerLegacyThreadLocalState() override = default;

  // Returns the non-global profiler state cast to this type. The result may
  // be nullptr (the debug assert explicitly allows it); debug builds also
  // assert that a non-null state reports ActiveProfilerType::LEGACY.
  static ProfilerLegacyThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::LEGACY);
    return static_cast<ProfilerLegacyThreadLocalState*>(tls);
  }

  // Collects the events of every per-thread list, followed by any remote
  // event lists, into one thread_event_lists value.
  thread_event_lists consolidate();

  // Records a Mark event called `name` (or emits an NVTX mark when the
  // profiler runs in NVTX mode). No-op when the profiler is disabled.
  void mark(std::string name, bool include_cuda = true);

  // Appends one list of remotely profiled events; safe to call from several
  // threads because it takes state_mutex_.
  void setOrAddRemoteProfiledEvents(
      std::vector<LegacyEvent>&& remoteProfiledEvents);

  // Records the start of the range described by `fn` on the current thread's
  // event list (or pushes an NVTX range in NVTX mode).
  void pushRange(
      const at::RecordFunction& fn,
      const bool record_cuda,
      std::vector<std::vector<int64_t>>&& shapes = {});

  // Records the end of the range described by `fn`. The event is stored on
  // the list of fn.threadId(), not necessarily the calling thread's list.
  void popRange(const at::RecordFunction& fn, const bool record_cuda);

  // Records a MemoryAlloc event when memory profiling is enabled in the
  // config. Only alloc_size and device are used.
  void reportMemoryUsage(
      void* /* unused */,
      int64_t alloc_size,
      size_t /* total_allocated, unused for legacy */,
      size_t /* total_reserved, unused for legacy */,
      c10::Device device) override;

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::LEGACY;
  }

  // Zeroes the stored callback handle without calling removeCallback().
  // disableProfilerLegacy() uses this when it is asked not to clean up TLS
  // state. NOTE(review): presumably a zero handle keeps the base class from
  // removing the callback later — confirm against ProfilerStateBase.
  void leakHandle() {
    handle_ = 0;
  }

 protected:
  // Returns the event list for `thread_id`, defaulting to the current
  // thread's id, creating the list on first use.
  RangeEventList& getEventList(
      std::optional<uint64_t> thread_id = std::nullopt);

  // Guards event_lists_map_ and remoteProfiledEvents_.
  std::mutex state_mutex_;
  // Thread id -> events recorded for that thread.
  std::unordered_map<uint64_t, std::shared_ptr<RangeEventList>>
      event_lists_map_;

  // Event lists added via setOrAddRemoteProfiledEvents(); empty optional until
  // the first call.
  std::optional<std::vector<std::vector<LegacyEvent>>> remoteProfiledEvents_;
};

// Builds the final profiling result: one consolidated list per local thread,
// then every remote event list (moved out of remoteProfiledEvents_).
// state_mutex_ is held for the whole operation.
thread_event_lists ProfilerLegacyThreadLocalState::consolidate() {
  std::lock_guard<std::mutex> lock(state_mutex_);
  thread_event_lists merged;
  for (auto& entry : event_lists_map_) {
    merged.emplace_back(entry.second->consolidate());
  }
  // Remote events, when any were registered, go after the local ones.
  if (remoteProfiledEvents_.has_value()) {
    auto& remote_lists = *remoteProfiledEvents_;
    merged.insert(
        merged.end(),
        std::make_move_iterator(remote_lists.begin()),
        std::make_move_iterator(remote_lists.end()));
  }
  return merged;
}

void ProfilerLegacyThreadLocalState::mark(std::string name, bool include_cuda) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->mark(name.c_str());
  } else {
    LegacyEvent evt(
        EventKind::Mark,
        at::StringView(std::move(name)),
        at::RecordFunction::currentThreadId(),
        include_cuda &&
            config_.state == torch::profiler::impl::ProfilerState::CUDA);
    evt.setNodeId(at::RecordFunction::getDefaultNodeId());
    getEventList().record(std::move(evt));
  }
}

// Appends one list of remotely profiled events to remoteProfiledEvents_,
// creating the outer container on first use.
//
// The incoming vector is moved, never copied. A named rvalue-reference
// parameter is an lvalue, so it must be passed through std::move explicitly.
// Assigning from a braced list would also copy, because initializer_list
// elements are const.
void ProfilerLegacyThreadLocalState::setOrAddRemoteProfiledEvents(
    std::vector<LegacyEvent>&& remoteProfiledEvents) {
  // Lock to serialize access from multiple callback threads.
  std::lock_guard<std::mutex> guard(state_mutex_);
  if (!remoteProfiledEvents_) {
    remoteProfiledEvents_.emplace();
  }
  remoteProfiledEvents_->emplace_back(std::move(remoteProfiledEvents));
}

// Records the beginning of the range described by `fn`.
//  - disabled profiler: nothing happens
//  - NVTX mode: an NVTX range is pushed, labelled from the function name,
//    sequence number and input shapes
//  - otherwise: a PushRange event is stored on the current thread's event
//    list, optionally enriched with flops and a callstack
void ProfilerLegacyThreadLocalState::pushRange(
    const at::RecordFunction& fn,
    const bool record_cuda,
    std::vector<std::vector<int64_t>>&& shapes) {
  if (config_.disabled()) {
    return;
  }

  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    const auto nvtx_label =
        torch::profiler::impl::getNvtxStr(fn.name(), fn.seqNr(), shapes);
    torch::profiler::impl::cudaStubs()->rangePush(nvtx_label.c_str());
    return;
  }

  LegacyEvent push_event(
      EventKind::PushRange,
      at::StringView(std::string(fn.name())),
      at::RecordFunction::currentThreadId(),
      record_cuda,
      fn.handle(),
      std::move(shapes),
      at::RecordFunction::getDefaultNodeId(),
      fn.isAsync());
  push_event.setSequenceNr(fn.seqNr());
  push_event.setFwdThreadId(fn.forwardThreadId());
  push_event.setScope(static_cast<uint8_t>(fn.scope()));

  if (config_.with_flops) {
    // Flops are computed from the extra args that were just stored.
    push_event.setExtraArgs(torch::profiler::impl::saveExtraArgs(fn));
    push_event.setFlops(torch::profiler::impl::computeFlops(
        std::string(fn.name()), push_event.extraArgs()));
  }

// TODO: will unify the two macros BUILD_LITE_INTERPRETER and C10_MOBILE soon.
#if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE
  // backward nodes source range corresponds to the forward node
  // TODO: consider using C++ stack trace
  if (config_.with_stack &&
      fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
    // Prefer the JIT callstack; use the Python one when the former is empty.
    auto callstack =
        torch::profiler::impl::prepareCallstack(jit::currentCallstack());
    if (callstack.empty()) {
      callstack = torch::profiler::impl::prepareCallstack(
          jit::tracer::pythonCallstack());
    }
    push_event.setStack(callstackStr(callstack));
  }
#endif

  getEventList().record(std::move(push_event));
}

// Records the end of the range described by `fn`.
//  - disabled profiler: nothing happens
//  - NVTX mode: the current NVTX range is popped
//  - otherwise: a PopRange event is recorded
void ProfilerLegacyThreadLocalState::popRange(
    const at::RecordFunction& fn,
    const bool record_cuda) {
  if (config_.disabled()) {
    return;
  }

  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->rangePop();
    return;
  }

  // RecordFunction (and therefore popRange) may run on a different thread
  // than the matching pushRange. By convention the pop event is stored on the
  // list of the thread that pushed (fn.threadId()), while the event itself
  // carries the id of the thread that is executing the pop.
  LegacyEvent pop_event(
      EventKind::PopRange,
      at::StringView(""),
      at::RecordFunction::currentThreadId(),
      record_cuda,
      fn.handle());
  pop_event.setNodeId(at::RecordFunction::getDefaultNodeId());
  getEventList(fn.threadId()).record(std::move(pop_event));
}

void ProfilerLegacyThreadLocalState::reportMemoryUsage(
    void* /* unused */,
    int64_t alloc_size,
    size_t /* total_allocated, unused for legacy */,
    size_t /* total_reserved, unused for legacy */,
    c10::Device device) {
  if (config_.profile_memory && !config_.disabled()) {
    uint64_t thread_id = at::RecordFunction::currentThreadId();
    LegacyEvent evt(
        EventKind::MemoryAlloc,
        at::StringView(""),
        thread_id,
        config_.state == torch::profiler::impl::ProfilerState::CUDA);
    evt.updateMemoryStats(alloc_size, device);
    getEventList(thread_id).record(std::move(evt));
  }
}

// Returns the event list associated with `thread_id`, or with the calling
// thread when no id is given. A list is created on first request. The map is
// accessed under state_mutex_; the returned reference refers to an object
// owned by a shared_ptr stored in the map.
RangeEventList& ProfilerLegacyThreadLocalState::getEventList(
    std::optional<uint64_t> thread_id) {
  const uint64_t tid = thread_id.has_value()
      ? thread_id.value()
      : at::RecordFunction::currentThreadId();

  std::lock_guard<std::mutex> lock(state_mutex_);
  auto pos = event_lists_map_.find(tid);
  if (pos == event_lists_map_.end()) {
    pos = event_lists_map_.emplace(tid, std::make_shared<RangeEventList>())
              .first;
  }
  return *pos->second;
}

// Position of each LegacyEvent field inside the IValue list produced by
// LegacyEvent::toIValue() and consumed by LegacyEvent::fromIValue().
// The numbers are spelled out because they define that list layout.
enum EventIValueIdx {
  KIND = 0,
  NAME = 1,
  THREAD_ID = 2,
  HANDLE = 3,
  NODE_ID = 4,
  CPU_MEM_USAGE = 5,
  CPU_NS = 6,
  CUDA_RECORDED = 7,
  CUDA_MEM_USAGE = 8,
  CUDA_DEVICE = 9,
  CUDA_US = 10,
  SHAPES = 11,
  NUM_EVENT_IVALUE_IDX // must be last in list
};

// Operator names for which the profiling callbacks turn CUDA recording off
// even when the profiler runs in CUDA mode. Entries are kept in alphabetical
// order; the set is only used for membership lookups.
const std::unordered_set<std::string> disable_cuda_profiling = {
    "aten::_unsafe_view",
    "aten::as_strided",
    "aten::empty",
    "aten::empty_like",
    "aten::empty_strided",
    "aten::expand",
    "aten::resize_",
    "aten::size",
    "aten::slice",
    "aten::squeeze",
    "aten::stride",
    "aten::t",
    "aten::transpose",
    "aten::unsqueeze",
    "aten::view"};

void pushProfilingCallbacksLegacy() {
  auto registration_state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set");
  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          [](const at::RecordFunction& fn)
              -> std::unique_ptr<at::ObserverContext> {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return nullptr;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }

            if (state_ptr->config().report_input_shapes) {
              auto sizes = torch::profiler::impl::inputSizes(fn);
              state_ptr->pushRange(fn, record_cuda, std::move(sizes));
            } else {
              state_ptr->pushRange(fn, record_cuda);
            }

            return nullptr;
          },
          [](const at::RecordFunction& fn, at::ObserverContext*) {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }
            state_ptr->popRange(fn, record_cuda);
          })
          .needsInputs(registration_state_ptr->config().report_input_shapes)
          .needsIds(true));
  registration_state_ptr->setCallbackHandle(handle);
}

} // namespace

// Starts a legacy profiling run on the calling thread.
//
// Fails (TORCH_CHECK) when NVTX mode is requested but the CUDA stubs are not
// enabled, when the KINETO state is requested, or when a legacy profiler state
// is already present on this thread. Otherwise it pushes a fresh profiler
// state into the PROFILER_STATE debug-info slot, installs the profiling
// callbacks and records the "__start_profile" mark.
void enableProfilerLegacy(
    const torch::profiler::impl::ProfilerConfig& new_config) {
  const auto requested_state = new_config.state;
  TORCH_CHECK(
      requested_state != torch::profiler::impl::ProfilerState::NVTX ||
          torch::profiler::impl::cudaStubs()->enabled(),
      "Can't use NVTX profiler - PyTorch was compiled without CUDA");

  TORCH_CHECK(requested_state != torch::profiler::impl::ProfilerState::KINETO);

  auto* existing_state = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_CHECK(!existing_state, "Profiler is already enabled on this thread");

  auto fresh_state =
      std::make_shared<ProfilerLegacyThreadLocalState>(new_config);
  c10::ThreadLocalDebugInfo::_push(
      c10::DebugInfoKind::PROFILER_STATE, fresh_state);

  pushProfilingCallbacksLegacy();

  // The start mark is recorded without CUDA timing.
  fresh_state->mark("__start_profile", false);
}

// Stops the legacy profiling run of the calling thread and returns the
// recorded events.
//
// `profilerDisableOptions` (both flags default to true when absent):
//  - cleanupTLSState: pop the profiler state from the PROFILER_STATE slot and
//    remove the callbacks; when false the state is only peeked and the
//    callback handle is leaked instead
//  - consolidate: gather and return the events; when false (or in NVTX mode)
//    an empty result is returned
//
// Fails (TORCH_CHECK) when no enabled profiler state is found.
thread_event_lists disableProfilerLegacy(
    std::optional<ProfilerDisableOptions> profilerDisableOptions) {
  const bool cleanup_tls = profilerDisableOptions.has_value()
      ? profilerDisableOptions->cleanupTLSState
      : true;
  const bool want_events = profilerDisableOptions.has_value()
      ? profilerDisableOptions->consolidate
      : true;

  // all the DebugInfoBase objects are scope based and supposed to use
  // DebugInfoGuard
  std::shared_ptr<c10::DebugInfoBase> debug_info;
  if (cleanup_tls) {
    debug_info =
        c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
  } else {
    debug_info =
        c10::ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PROFILER_STATE);
  }

  auto* legacy_state =
      static_cast<ProfilerLegacyThreadLocalState*>(debug_info.get());
  TORCH_CHECK(
      legacy_state && !legacy_state->config().disabled(),
      "Can't disable profiler when it's not running");

  if (cleanup_tls) {
    legacy_state->removeCallback();
  } else {
    legacy_state->leakHandle();
  }

  const bool nvtx_mode = legacy_state->config().state ==
      torch::profiler::impl::ProfilerState::NVTX;
  if (!want_events || nvtx_mode) {
    return thread_event_lists();
  }

  legacy_state->mark("__stop_profile", false);
  // Note that this will erase the underlying events.
  return legacy_state->consolidate();
}

// Hands a list of events over to the legacy profiler state of the calling
// thread, which keeps it alongside its own events. Fails (TORCH_CHECK) when
// no legacy profiler state is set.
void addEventList(std::vector<LegacyEvent>&& profiledEvents) {
  auto* legacy_state = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_CHECK(legacy_state, "Profiler must be enabled.");
  legacy_state->setOrAddRemoteProfiledEvents(std::move(profiledEvents));
}

// Timestamps this event. With `record_cuda` the CUDA stubs fill in the
// device, the CUDA event and the CPU time; otherwise only the CPU time is
// taken from c10::getTime().
void LegacyEvent::record(bool record_cuda) {
  if (!record_cuda) {
    cpu_ns_ = c10::getTime();
  } else {
    torch::profiler::impl::cudaStubs()->record(&device_, &cuda_event, &cpu_ns_);
  }
}

// Rebuilds a LegacyEvent from the IValue list produced by toIValue().
//
// `eventIValue` must be a generic list with at least NUM_EVENT_IVALUE_IDX
// elements laid out according to EventIValueIdx; the SHAPES slot must be a
// list of integer lists. Violations trip TORCH_INTERNAL_ASSERT. The returned
// event is always flagged as remote.
/* static */ LegacyEvent LegacyEvent::fromIValue(
    const at::IValue& eventIValue) {
  TORCH_INTERNAL_ASSERT(
      eventIValue.isList(),
      "Expected IValue to contain type c10::impl::GenericList");
  auto ivalues = eventIValue.toList();
  TORCH_INTERNAL_ASSERT(
      ivalues.size() >= NUM_EVENT_IVALUE_IDX,
      "Expected at least ",
      NUM_EVENT_IVALUE_IDX,
      " elements to reconstruct LegacyEvent.");

  // Reconstruct input shapes from ivalues.
  const auto& shapeListIValue = ivalues.get(EventIValueIdx::SHAPES);
  TORCH_INTERNAL_ASSERT(
      shapeListIValue.isList(),
      "Expected profiler shapes IValue to contain type c10::impl::GenericList.");

  auto shapeList = shapeListIValue.toList();
  std::vector<std::vector<int64_t>> shapes;
  shapes.reserve(shapeList.size());
  for (const auto i : c10::irange(shapeList.size())) {
    std::vector<int64_t> s;
    const auto& shapeIValue = shapeList.get(i);
    TORCH_INTERNAL_ASSERT(
        shapeIValue.isList(),
        "Expected each profiler shape element to contain shapes of type c10::impl::GenericList.");
    auto curShapesList = shapeIValue.toList();
    s.reserve(curShapesList.size());
    for (const auto j : c10::irange(curShapesList.size())) {
      s.emplace_back(curShapesList.get(j).toInt());
    }
    // `s` is not used again, so hand its buffer over instead of copying it.
    shapes.emplace_back(std::move(s));
  }

  LegacyEvent evt(
      static_cast<EventKind>(
          ivalues.get(EventIValueIdx::KIND).toInt()), // EventKind
      at::StringView(ivalues.get(EventIValueIdx::NAME).toStringRef()), // name
      ivalues.get(EventIValueIdx::THREAD_ID).toInt(), // thread_id
      // toIValue() stores the handle as a double, hence toDouble() here.
      static_cast<at::RecordFunctionHandle>(
          ivalues.get(EventIValueIdx::HANDLE).toDouble()), // handle
      std::move(shapes), // input shapes
      ivalues.get(EventIValueIdx::NODE_ID).toInt(), // node id
      true, // is remote
      ivalues.get(EventIValueIdx::CPU_MEM_USAGE).toInt(), // cpu_mem_usage
      ivalues.get(EventIValueIdx::CPU_NS).toInt(), // cpu_ns
      ivalues.get(EventIValueIdx::CUDA_RECORDED).toBool(), // was cuda recorded
      ivalues.get(EventIValueIdx::CUDA_MEM_USAGE).toInt(), // cuda memory usage
      c10::DeviceIndex(
          ivalues.get(EventIValueIdx::CUDA_DEVICE).toInt()), // device
      static_cast<double>(
          ivalues.get(EventIValueIdx::CUDA_US).toInt()) // cuda_us
  );
  return evt;
}

// Serializes this event into a generic IValue list. Elements are appended in
// EventIValueIdx order, which is the layout fromIValue() reads back.
at::IValue LegacyEvent::toIValue() const {
  c10::impl::GenericList fields(at::AnyType::get());
  fields.reserve(NUM_EVENT_IVALUE_IDX);

  // Generic event information.
  fields.emplace_back(static_cast<int64_t>(kind_));
  fields.emplace_back(std::string(name_.str()));
  fields.emplace_back(static_cast<int64_t>(thread_id_));
  fields.emplace_back(static_cast<double>(handle_));
  fields.emplace_back(node_id_);
  fields.emplace_back(cpu_memory_usage_);
  fields.emplace_back(cpu_ns_);

  // CUDA event information
  const bool cuda_recorded = hasCuda();
  fields.emplace_back(cuda_recorded);
  fields.emplace_back(static_cast<int64_t>(cuda_memory_usage_));
  fields.emplace_back(device_);
  fields.emplace_back(cuda_us_);

  // Shapes: a list holding one integer list per recorded input shape.
  c10::impl::GenericList all_shapes(
      at::ListType::create(at::IntType::get()));
  all_shapes.reserve(shapes_.size());
  for (const auto& dims : shapes_) {
    c10::impl::GenericList dim_list(at::IntType::get());
    dim_list.reserve(dims.size());
    for (const auto& dim : dims) {
      dim_list.emplace_back(dim);
    }
    all_shapes.emplace_back(dim_list);
  }
  fields.emplace_back(all_shapes);

  return at::IValue(fields);
}

// Returns the CUDA time elapsed from this event to `e`.
//
// Both events must have CUDA data and share a device (TORCH_CHECK). When both
// are remote the result is the difference of their stored cuda_us_ values;
// otherwise the CUDA stubs compute it from the two CUDA events.
double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const {
  TORCH_CHECK(e.hasCuda() && hasCuda(), "Events were not recorded for CUDA");
  TORCH_CHECK(
      e.device() == device(),
      c10::str(
          "Events are not on the same device: ", e.device(), " vs ", device()));

  const bool both_remote = isRemote() && e.isRemote();
  if (!both_remote) {
    return torch::profiler::impl::cudaStubs()->elapsed(
        &cuda_event, &e.cuda_event);
  }

  // validate that cuda_us_ has been set properly.
  TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0);
  return static_cast<double>(e.cuda_us_ - cuda_us_);
}

// JSON object template for one event emitted by writeProfilerEventsToStream().
// The ${name}, ${ts}, ${dur} and ${tid} placeholders are filled in there;
// "ph", "pid" and "args" are fixed.
// NOTE(review): the "ph": "X" layout looks like the Chrome trace-event
// format — confirm before documenting it as such.
static const at::jit::CodeTemplate event_template(R"(
{
  "name": "${name}",
  "ph": "X",
  "ts": ${ts},
  "dur": ${dur},
  "tid": ${tid},
  "pid": "CPU Functions",
  "args": {}
})");

void writeProfilerEventsToStream(
    std::ostream& out,
    const std::vector<LegacyEvent*>& events) {
  TORCH_CHECK(out, "Could not open file");
  LegacyEvent* profiler_start = nullptr;
  for (LegacyEvent* e : events) {
    if (0 == strcmp(e->name(), "__start_profile")) {
      profiler_start = e;
      break;
    }
  }
  TORCH_CHECK(profiler_start, "Could not find __start_profile mark");

  struct PairHash {
    size_t operator()(
        std::pair<at::RecordFunctionHandle, int> p) const noexcept {
      return std::hash<at::RecordFunctionHandle>()(p.first) ^
          std::hash<int64_t>()(p.second);
    }
  };
  std::unordered_map<
      std::pair<at::RecordFunctionHandle, int64_t>,
      LegacyEvent*,
      PairHash>
      events_map;
  out << "[\n";
  bool first = true;
  for (LegacyEvent* evt : events) {
    if (evt->kindStr() == "push") {
      events_map[std::make_pair(evt->handle(), evt->nodeId())] = evt;
    } else if (evt->kindStr() == "pop") {
      if (!first) {
        out << ",\n";
      }
      first = false;
      auto it = events_map.find(std::make_pair(evt->handle(), evt->nodeId()));
      TORCH_CHECK(it != events_map.end(), "Unmatched pop event");
      LegacyEvent* evt_start = it->second;
      events_map.erase(it);

      at::jit::TemplateEnv env;
      env.s("name", evt_start->name());
      env.d("ts", profiler_start->cpuElapsedUs(*evt_start));
      env.d("dur", evt_start->cpuElapsedUs(*evt));
      env.d("tid", evt_start->threadId());
      out << event_template.format(env);
    }
  }
  out << "]\n";
}

// Profiles into a caller-provided stream: stores a reference to `out` and
// starts the legacy profiler via init(). The stream is written to in the
// destructor, so it presumably must outlive this object — confirm at call
// sites.
RecordProfile::RecordProfile(std::ostream& out) : out_(out) {
  init();
}

// Profiles into a file: opens `filename` as an owned std::ofstream, binds
// out_ to it and starts the legacy profiler via init().
// NOTE(review): out_(*file_) relies on file_ being declared before out_ in
// the class so that it is initialized first — verify in the header.
RecordProfile::RecordProfile(const std::string& filename)
    : file_(std::make_unique<std::ofstream>(filename)), out_(*file_) {
  init();
}

void RecordProfile::init() {
  enableProfilerLegacy(torch::profiler::impl::ProfilerConfig(
      torch::profiler::impl::ProfilerState::CPU));
}

// Stops the profiler and hands every recorded event to processEvents().
// Nothing may escape a destructor, so all exceptions are caught and logged.
RecordProfile::~RecordProfile() {
  try {
    thread_event_lists per_thread_events = disableProfilerLegacy();

    // Flatten the per-thread lists into one sequence of pointers. The lists
    // stay alive in `per_thread_events` until processing is done.
    std::vector<LegacyEvent*> flat_events;
    for (auto& thread_events : per_thread_events) {
      for (auto& event : thread_events) {
        flat_events.push_back(&event);
      }
    }
    processEvents(flat_events);
  } catch (const std::exception& e) {
    LOG(ERROR) << e.what() << '\n';
  } catch (...) {
    LOG(ERROR) << "Unknown error" << '\n';
  }
}

// Writes the collected events to the output stream bound at construction by
// forwarding to writeProfilerEventsToStream().
void RecordProfile::processEvents(const std::vector<LegacyEvent*>& events) {
  writeProfilerEventsToStream(out_, events);
}

} // namespace torch::autograd::profiler
