#include <torch/csrc/jit/runtime/interpreter.h>

#include <ATen/Parallel.h>
#include <ATen/core/ivalue.h>
#include <ATen/record_function.h>
#include <c10/core/thread_pool.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/interpreter/code_impl.h>
#include <torch/csrc/jit/runtime/interpreter/frame.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/script_profile.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
#include <string>

#ifdef USE_RPC
#include <torch/csrc/distributed/autograd/context/container.h>
using torch::distributed::autograd::DistAutogradContainer;
#endif

#include <exception>
#include <memory>
#include <mutex>
#include <ostream>
#include <stdexcept>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// When set, the interpreter's exception handler rethrows the exception it
// caught (or stores it on the pending future) instead of wrapping it in a
// newly formatted error.
C10_DEFINE_bool(
    torch_jit_enable_rethrow_caught_exception,
    false,
    "enable rethrowing caught exception");

// Opt-in flag for pre-expanding node stacks and caching the expanded stacks
// (see the flag description below).
C10_DEFINE_bool(
    torch_jit_enable_expanded_stacks,
    false,
    "When true we will attempt to pre-expand node stacks and cache expanded stacks.");

namespace torch::jit {

using CodeImpl = interpreter::CodeImpl;

// Before we translate to interpreter instructions, we do
// some preprocessing of the graph to turn it into a form that is closer
// to what the instructions will look like.
// In particular we:
// *  Compute whether an input to a node is the last use, so we can issue MOVE
//    rather than LOAD instructions.
// *  Drop nodes are inserted for any node that is unused to create a dummy use
//    that will cause the interpreter to free the node.
//    A drop node just pops its input off the stack to ensure the interpreter
//    releases references to nodes that are never used. Drop nodes are also
//    inserted when the last use of a node is in some conditionally run control
//    flow (e.g. one side of an If) and the interpreter must free the node only
//    after the control flow has reconverged.
// Outputs are:
// * graph - the post processed copy of g
// * move_flags[n] - a list of booleans, one for each input,
//   indicating whether this is the last use of the value. The interpreter
//   should generate a move rather than a copy in this case.

// Builds the TensorType that describes `t` under the current execution
// context:
//  * an undefined tensor maps to the generic tensor type marked as undefined;
//  * a defined tensor maps to the type created from it, with requires_grad
//    forced to false whenever grad mode is disabled.
TensorTypePtr tensorTypeInCurrentExecutionContext(const at::Tensor& t) {
  if (t.defined()) {
    auto tensor_type = TensorType::create(t);
    if (at::GradMode::is_enabled()) {
      return tensor_type;
    }
    return tensor_type->withRequiresGrad(false);
  }
  return TensorType::get()->withUndefined();
}

namespace {
// Returns the id of the current distributed autograd context when the build
// has RPC support (USE_RPC); otherwise returns 0.
inline int64_t getDistAutogradContextId() {
  int64_t context_id = 0;
#ifdef USE_RPC
  context_id = DistAutogradContainer::currentContextId();
#endif
  return context_id;
}
} // namespace

// Per-thread pointer to the interpreter state that is currently executing on
// this thread (nullptr when no guard is active).
thread_local InterpreterStateImpl* tls_int_state_ptr_ = nullptr;

// Scope guard: publishes `state` as the thread's current interpreter for the
// guard's lifetime and puts the previously published pointer back on
// destruction.
struct TLSCurrentInterpreterGuard {
  TLSCurrentInterpreterGuard(InterpreterStateImpl* state)
      : prev_state_(std::exchange(tls_int_state_ptr_, state)) {}

  ~TLSCurrentInterpreterGuard() {
    tls_int_state_ptr_ = prev_state_;
  }

 private:
  // Value of tls_int_state_ptr_ observed when the guard was constructed.
  InterpreterStateImpl* prev_state_;
};

// The state of an InterpreterState that is used to execute a Code
struct InterpreterStateImpl : c10::intrusive_ptr_target {
  // Creates the state for running `code`: stores the task launcher and pushes
  // the initial frame for `code` with a base pointer of 0.
  InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher)
      : taskLauncher_(std::move(taskLauncher)) {
    enterFrame(code, 0);
  }

 private:
  using Frame = torch::jit::interpreter::Frame;
  // Mutex-protected set of warning indices that have already been seen.
  struct WarnedNodes {
   public:
    // Records `idx` as warned. Returns true only when `idx` was not present
    // before the call, i.e. the insertion actually took place.
    bool insert(int32_t idx) {
      const std::lock_guard<std::mutex> guard(mutex_);
      const auto result = warned_nodes_.insert(idx);
      return result.second;
    }

   private:
    std::mutex mutex_;
    std::unordered_set<int32_t> warned_nodes_;
  };

  WarnedNodes warned_nodes_;

  // if we need to suspend, where do we reset the stack?
  // answer: to where it was when we were called, not
  // including any inputs to this function
  int64_t stack_start_ = -1;
  c10::intrusive_ptr<Future> future_;
  TaskLauncher taskLauncher_;

  // this holds all the tensors for this interpreter run.
  // we don't bother minimizing the size of this vector, since the extra
  // memory used by the pointers in this will be small.
  // instead we are very aggressive about releasing tensors when they become
  // dead to make sure memory management happens efficiently. We optimize for
  // the case where derivatives are run with retain_graph=False; in the case
  // where it is true, the interpreter and this array get copied. If this
  // ever becomes a bottleneck then we _should_ consider minimizing the total
  // number of registers.
  std::vector<IValue> registers;

  // A stack of objects that have been __enter__'d.
  std::vector<IValue> entered_objects;

  std::vector<Frame> frames;

  // Returns an intrusive_ptr to this object. The reference count is bumped
  // first, and the pointer is then built with reclaim() from the raw `this`,
  // so the returned handle accounts for the reference added here.
  c10::intrusive_ptr<InterpreterStateImpl> intrusive_from_this() {
    c10::raw::intrusive_ptr::incref(this);
    return c10::intrusive_ptr<InterpreterStateImpl>::reclaim(this);
  }

  // Pushes a fresh frame (pc 0, no id yet) for `code` whose stack base is
  // `base_pointer`, and grows the register file by the number of registers
  // that `code` requires.
  void enterFrame(const Code& code, size_t base_pointer) {
    const auto& impl = code.pImpl;
    frames.emplace_back(Frame{impl, 0, base_pointer, std::nullopt});
    registers.resize(registers.size() + impl->register_size_);
  }

  // Pops the innermost frame after releasing the registers that were
  // reserved for it by enterFrame.
  void leaveFrame() {
    const auto frame_registers = frames.back().function->register_size_;
    registers.resize(registers.size() - frame_registers);
    frames.pop_back();
  }

  // Invokes `f` on `stack`. If `f` hands back interpreter Code, a new frame
  // is entered for it (based at the start of its inputs on the stack) and
  // function recording is started for that frame.
  //
  // When `next` is true the program counter of the calling frame is advanced:
  // that is the frame below the top if a new frame was pushed, otherwise the
  // top frame itself.
  void callFunction(
      Function& f,
      Stack& stack,
      std::optional<size_t> bailOut = std::nullopt,
      bool next = true) {
    auto on_new_frame = [&](const Code& code) {
      enterFrame(code, stack.size() - code.num_inputs());
      checkAndStartRecordFunction(frames.back(), stack);
    };
    const bool pushed_frame = f.call(stack, bailOut, on_new_frame);
    if (!next) {
      return;
    }
    auto caller = frames.rbegin();
    if (pushed_frame) {
      ++caller;
    }
    ++caller->pc;
  }

  // Register indices count backwards from the end of the register list, so
  // they always address the registers of the function that is currently
  // executing, no matter how many caller frames sit below it.
  IValue& reg(size_t reg) {
    return registers[registers.size() - reg];
  }

  void dump(std::ostream& out, const Stack& stack) const {
    out << "Stack:\n";
    for (const auto& val : stack) {
      out << val;
      out << "\n";
    }
  }

  // Remembers the stack size at construction so that callAssert() can later
  // hand the initial and current sizes to the frame function's
  // assert_stack_size check. The check is compiled in only when NDEBUG is not
  // defined; otherwise callAssert() does nothing.
  class StackSizeDidntChangeGuard {
   public:
    StackSizeDidntChangeGuard(
        const Frame& frame,
        const torch::jit::Stack& stack,
        const Instruction& inst)
        : frame_(frame),
          stack_(stack),
          instX_(inst.X),
          initialSize_(stack.size()) {
      // The members are only read in debug builds; reference them here so
      // that release builds do not warn about unused fields.
      (void)frame_;
      (void)stack_;
      (void)instX_;
      (void)initialSize_;
    }

    // The guard holds references to its arguments: neither copyable nor
    // movable.
    StackSizeDidntChangeGuard(const StackSizeDidntChangeGuard&) = delete;
    StackSizeDidntChangeGuard& operator=(const StackSizeDidntChangeGuard&) =
        delete;
    StackSizeDidntChangeGuard(StackSizeDidntChangeGuard&&) = delete;
    StackSizeDidntChangeGuard& operator=(StackSizeDidntChangeGuard&&) = delete;

    void callAssert() const {
#ifndef NDEBUG
      frame_.function->assert_stack_size(instX_, initialSize_, stack_.size());
#endif
    }

   private:
    const Frame& frame_;
    const torch::jit::Stack& stack_;
    std::uint32_t instX_;
    std::size_t initialSize_;
  };

  // Empty placeholder type; `instGuard()` in runTemplate returns it when
  // profiling is disabled.
  struct C10_UNUSED DoNothing {};

// Computed goto is used for instruction dispatch on GCC and Clang.
#if defined(__GNUC__) || defined(__clang__)
#define JIT_USE_COMPUTED_GOTO
#endif
// Primitives for making interpreter internal state transitions.
// We maintain two local variables as the internal interpreter state:
// `frame` will be the current frame that the interpreter operators on.
// `inst` will the current instruction pointed to by program counter.
//
// Instruction blocks should be always declared through `INST` macro and
// the instruction body should always start with a `instGuard()` declaration.
// Also blocks should be ended properly with either `INST_NEXT` (for going
// to the next instruction), or `INST_DISPATCH` (for jumping to a computed
// position using `instFetch`).
//
// With computed goto, `case INST(OP):` expands to `case OP: label_OP:` so each
// instruction body is both a switch case and a goto target, and INST_DISPATCH
// jumps straight to the label of the already-fetched `inst`. Without it,
// INST_DISPATCH breaks out of the switch and the enclosing loop dispatches
// again.
#if defined(JIT_USE_COMPUTED_GOTO)
#define INST(NAME) \
  NAME:            \
  label_##NAME
#define INST_DISPATCH goto* dispatch_table[inst.op]
#else
#define INST(NAME) NAME
#define INST_DISPATCH break
#endif
// Fetches the instruction that follows the current one, then dispatches it.
#define INST_NEXT      \
  inst = instFetch(1); \
  INST_DISPATCH

  // Main interpreter loop: executes instructions of the innermost frame until
  // the outermost frame returns, execution suspends, or an exception escapes.
  //
  // EnableProfiling selects at compile time whether `instGuard()` creates a
  // profiling::InstructionSpan per instruction or an empty DoNothing.
  //
  // Returns true when WAIT hits a future that is not completed yet (the stack
  // contents owned by this run are moved into a callback on that future).
  // Returns false when the outermost frame executes RET, or when an exception
  // was caught and recorded on `future_`. If an exception is caught and there
  // is no `future_`, it is rethrown (either as-is or wrapped by handleError).
  template <bool EnableProfiling>
  bool runTemplate(Stack& stack) {
    // if we have never run before, then we might have to return the
    // stack when we suspend, record where it starts so we return the right
    // stack
    if (stack_start_ == -1) {
      TORCH_INTERNAL_ASSERT(stack.size() >= frames.back().function->n_inputs);
      stack_start_ = stack.size() - frames.back().function->n_inputs;
    } else {
      // during restarts, all of the stack is always our own, so we leave
      // nothing
      stack_start_ = 0;
    }

    // Publish this state as the thread's current interpreter for the duration
    // of the run.
    TLSCurrentInterpreterGuard g(this);
    if (frames.back().pc == 0 && stack_start_ == 0) {
      checkAndStartRecordFunction(frames.back(), stack);
    }

    // One label address per opcode, generated from FORALL_OPCODES; indexed by
    // `inst.op` in INST_DISPATCH.
#if defined(JIT_USE_COMPUTED_GOTO)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
    static void* dispatch_table[] = {
#define DISPATCH_TABLE_ENTRY(op, _) &&label_##op,
        FORALL_OPCODES(DISPATCH_TABLE_ENTRY)
#undef DISPATCH_TABLE_ENTRY
    };
#endif

    try {
      while (true) {
        // Re-bound on every loop iteration: instructions that push or pop
        // frames `continue` back here to pick up the new innermost frame.
        Frame& frame = frames.back();

        // Advances the frame's pc by `x` and returns the instruction found at
        // the new pc.
        auto instFetch = [&](auto x) {
          return frame.function->instructions_[frame.pc += x];
        };

        // Per-instruction guard object: a profiling span when profiling is
        // enabled, an empty placeholder otherwise.
        auto instGuard = [&] {
          if constexpr (!EnableProfiling) {
            return DoNothing{};
          } else {
            return profiling::InstructionSpan{
                *frame.function->instructions_source()[frame.pc]};
          }
        };

        Instruction inst = instFetch(0);

        auto stackSizeAssertGuard = [&] {
          return StackSizeDidntChangeGuard{frame, stack, inst};
        };

        switch (inst.op) {
          case INST(ENTER): {
            auto _ = instGuard();
            const auto& obj = peek(stack, 0, 1);
            TORCH_INTERNAL_ASSERT(obj.isObject());
            entered_objects.push_back(obj);
          }
            INST_NEXT;
          case INST(EXIT): {
            auto _ = instGuard();
            auto obj = entered_objects.back().toObject();
            auto& f = obj->type()->getMethod("__exit__");
            push(stack, std::move(obj));
            entered_objects.pop_back();
            // __exit__ is called with the object plus three None arguments.
            push(stack, IValue());
            push(stack, IValue());
            push(stack, IValue());
            callFunction(f, stack);
            // callFunction already advanced the caller's pc and may have
            // pushed a frame; restart the loop to re-read frames.back().
            continue;
          }
          case INST(OP): {
            auto _ = instGuard();
            auto stackSizeGuard = stackSizeAssertGuard();
            frame.function->operator_table_[inst.X](stack);
            stackSizeGuard.callAssert();
          }
            INST_NEXT;
          case INST(OPN): {
            auto _ = instGuard();
            // The operator receives inst.N as an extra value on top of the
            // stack.
            stack.emplace_back(inst.N);
            auto stackSizeGuard = stackSizeAssertGuard();
            frame.function->operator_table_[inst.X](stack);
            stackSizeGuard.callAssert();
          }
            INST_NEXT;
          case INST(LOAD): {
            auto _ = instGuard();
            stack.emplace_back(reg(inst.X));
          }
            INST_NEXT;
          case INST(MOVE): {
            auto _ = instGuard();
            stack.emplace_back(std::move(reg(inst.X)));
          }
            INST_NEXT;
          case INST(STORE): {
            auto _ = instGuard();
            reg(inst.X) = pop(stack);
          }
            INST_NEXT;
          case INST(STOREN): {
            auto _ = instGuard();
            TORCH_INTERNAL_ASSERT(stack.size() >= inst.N);
            // Pops N values: the top of the stack goes to register X + N - 1,
            // the deepest of the N values goes to register X.
            for (size_t i = inst.N; i > 0; --i) {
              reg(inst.X + i - 1) = pop(stack);
            }
          }
            INST_NEXT;
          case INST(DROP): {
            auto _ = instGuard();
            stack.pop_back();
          }
            INST_NEXT;
          case INST(DROPR): {
            auto _ = instGuard();
            // Overwrite the register with an empty IValue to release whatever
            // it held.
            reg(inst.X) = IValue();
          }
            INST_NEXT;
          case INST(LOADC): {
            auto _ = instGuard();
            stack.emplace_back(frame.function->constant_table_[inst.X]);
          }
            INST_NEXT;
          case INST(GET_ATTR): {
            auto _ = instGuard();
            // Replaces the object on top of the stack with its slot inst.X.
            const auto& userObj = stack.back().toObjectRef();
            stack.back() = userObj.getSlot(inst.X);
          }
            INST_NEXT;
          case INST(SET_ATTR): {
            auto _ = instGuard();
            // Stack on entry: ..., object, value. Both are consumed.
            auto v = pop(stack);
            auto& userObj = stack.back().toObjectRef();
            userObj.setSlot(inst.X, std::move(v));
            stack.pop_back();
          }
            INST_NEXT;
          case INST(JF): {
            auto _ = instGuard();
            // Jump-if-false: a true condition falls through to the next
            // instruction, a false one jumps by the offset inst.X.
            if (pop(stack).toBool()) {
              inst = instFetch(1);
            } else {
              inst = instFetch(inst.X);
            }
          }
            INST_DISPATCH;
          case INST(JMP): {
            auto _ = instGuard();
            inst = instFetch(inst.X);
          }
            INST_DISPATCH;
          case INST(LOOP): {
            auto _ = instGuard();
            // stack: iteration_count, max_iter, cond, loop_carried_deps...
            auto fr = stack.end() - (inst.N + 1);
            int64_t trip_count = fr[0].toInt();
            int64_t max_trip_count = fr[1].toInt();
            bool cond = fr[2].toBool();
            if (trip_count < max_trip_count && cond) {
              // Run the body: the cond slot is overwritten with the current
              // trip count and the stored counter is incremented.
              fr[2] = trip_count;
              fr[0] = trip_count + 1;
              inst = instFetch(1);
            } else {
              // Loop finished: shift the loop-carried values down over the
              // three bookkeeping slots, drop the leftovers, and jump by
              // inst.X.
              size_t n_loop_carried = inst.N - 2;
              for (const auto i : c10::irange(n_loop_carried)) {
                fr[i] = std::move(fr[i + 3]);
              }
              drop(stack, 3); // iteration_count, max_iter, cond
              inst = instFetch(inst.X);
            }
          }
            INST_DISPATCH;
          case INST(CALL): {
            auto _ = instGuard();
            Function* fn = frame.function->function_table_[inst.X];
            callFunction(*fn, stack);
            continue;
          }
          case INST(INTERFACE_CALL): {
            auto _ = instGuard();
            // note the hash table lookup to find the function
            // this can be more optimized if necessary, caching parts
            // of the hashing computation or storing the offset when
            // the object is turned into an interface

            // consider passing
            // `frames.back().function->remaining_bailout_depth_` into
            // `get_executor().getPlanFor()` to propagate caller's depth
            // restrictions onto children while this strategy has a potential to
            // reduce the number of compilations for too dynamic callers we
            // might miss opportunities where a caller is dynamic but a callee
            // gets stable arguments
            Function& function =
                peek(stack, 0, inst.N)
                    .toObject()
                    ->type()
                    ->getMethod(
                        frame.function->constant_table_[inst.X].toStringRef());
            callFunction(function, stack);
            continue;
          }
          case INST(RET): {
            // Returning from a nested frame: pop it and resume the caller.
            if (frames.size() > 1) {
              leaveFrame();
              continue;
            }
            // Returning from the outermost frame: publish the outputs on the
            // future if one was requested (a single value, or a tuple of the
            // last n_outputs stack values).
            if (future_) {
              auto num_outputs = frames.back().function->n_outputs;
              if (num_outputs == 1) {
                future_->markCompleted(stack.back());
              } else {
                future_->markCompleted(
                    c10::ivalue::Tuple::create(jit::last(stack, num_outputs)));
              }
            }
            // destroy the last frame and call RecordFunction's end callbacks
            leaveFrame();
            return false;
          }
          case INST(WAIT): {
            auto _ = instGuard();
            auto future = stack.back().toFuture();
            if (!future->completed()) {
              getOrCreateFuture();

              // callback needs to be a struct rather than a lambda so that
              // we can move the stack to the other thread
              struct Callback {
                Callback(
                    c10::intrusive_ptr<InterpreterStateImpl> state,
                    Stack stack)
                    : stateImpl_(std::move(state)),
                      state_(stateImpl_),
                      stack_(std::move(stack)),
                      dist_autograd_context_id_(getDistAutogradContextId()) {
                  state_ = InterpreterState(stateImpl_);
                }
                // Invoked when the awaited future completes: hands a
                // continuation of this interpreter to the task launcher.
                void operator()(c10::ivalue::Future& /* unused */) {
                  stateImpl_->taskLauncher_(InterpreterContinuation(
                      state_,
                      std::move(stack_),
                      dist_autograd_context_id_,
                      std::move(tls_state_)));
                }

               private:
                c10::intrusive_ptr<InterpreterStateImpl> stateImpl_;
                InterpreterState state_;
                Stack stack_;
                int64_t dist_autograd_context_id_;
                // preserve the original ThreadLocalState
                at::ThreadLocalState tls_state_;
              };

              // we are suspending, so we need to reset the stack to where we
              // started if it started empty, except for the inputs we can avoid
              // a true copy by swapping, which leaves the original stack empty.
              Stack copied;
              if (stack_start_ == 0) {
                copied.swap(stack);
              } else {
                copied.insert(
                    copied.begin(),
                    std::make_move_iterator(stack.begin() + stack_start_),
                    std::make_move_iterator(stack.end()));
                stack.resize(stack_start_);
              }
              // save pc into the frame so we continue here when restored
              future->addCallback(
                  Callback(intrusive_from_this(), std::move(copied)));

              // Suspended; the pc still points at this WAIT instruction.
              return true;
            }
            // The future is complete: replace it on the stack by its value.
            stack.pop_back();
            stack.emplace_back(future->value());
          }
            INST_NEXT;
          case INST(PROFILE_OP): {
            auto _ = instGuard();
            // The frame id is generated on first use and passed to the
            // profiling callback as the top stack value.
            auto& frame_id_ref = frame.id;
            if (!frame_id_ref.has_value()) {
              frame_id_ref = Frame::genId();
            }
            const auto& callback =
                frame.function->profile_function_table_[inst.X];
            push(stack, c10::IValue{static_cast<int64_t>(*frame_id_ref)});
            callback(stack);
          }
            INST_NEXT;
          case INST(FAIL_GUARD): {
            auto _ = instGuard();
            // patch FAIL_GUARD back to GUARD
            GRAPH_DEBUG(
                "Bailout ", inst.X, " triggered via bailout_requests_!");
            frame.function->instructions_[frame.pc].op = GUARD;
            push(stack, false);
          }
            INST_NEXT;
          case INST(TYPECHECK): {
            auto _ = instGuard();
            unsigned num_inputs = inst.N, i = 0;
            TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
            // Check every input's shape against profiled (expected) shape.
            for (i = 0; i < num_inputs; i++) {
              auto& input = peek(stack, i, num_inputs);
              auto& t = input.toTensor();
              const TypePtr& expected = frame.function->type_table_[inst.X + i];
              auto* expected_type = expected->castRaw<TensorType>();
              if (t.defined() && !expected_type->matchTensor(t)) {
                push(stack, false);
                break;
              }
            }
            // The loop ran to completion, so no input mismatched.
            if (i == num_inputs) {
              push(stack, true);
            }
          }
            INST_NEXT;
          case INST(GUARD): {
            auto _ = instGuard();
            if (!stack.back().isTensor()) {
              // stack.back() is an Uninitialized IValue and this is a guard
              // on a block output. Uninitialized IValues are never used
              // so it's safe to pass this guard check
              push(stack, true);
            } else {
              auto& t = stack.back().toTensor();
              const TypePtr& expected = frame.function->type_table_[inst.X];
              auto* expected_type = expected->castRaw<TensorType>();
              if (t.defined() &&
                  !frames.back().symbols2dims.bindSymbolicShapes(
                      t.sizes(), expected_type->symbolic_sizes())) {
                push(stack, false);
              } else {
                push(stack, expected_type->matchTensor(t));
              }
            }
          }
            INST_NEXT;
          case INST(TAIL_CALL): {
            auto _ = instGuard();
            GRAPH_DEBUG("running TAIL_CALL for ", inst.X);
            frame.function->function_table_[inst.X]->ensure_defined();
            size_t remaining_bailout_depth =
                frame.function->remaining_bailout_depth_ > 0
                ? frame.function->remaining_bailout_depth_ - 1
                : 0;
            auto& f = *frame.function->function_table_[inst.X];
            size_t num_inputs = f.num_inputs();
            size_t base_pointer = frame.base_pointer;
            TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs);
            // Move the callee's inputs down to the current frame's base,
            // discard everything above them, and replace the current frame
            // with the callee's.
            size_t inputs_start = stack.size() - num_inputs;
            for (const auto i : c10::irange(num_inputs)) {
              stack.at(base_pointer + i) =
                  std::move(stack.at(inputs_start + i));
            }
            stack.resize(base_pointer + num_inputs);
            leaveFrame();

            // next == false: the current frame is gone, so there is no
            // caller pc to advance here.
            callFunction(f, stack, remaining_bailout_depth, false);
            continue;
          }
          case INST(LIST_UNPACK): {
            auto _ = instGuard();
            listUnpack(stack, inst.X);
          }
            INST_NEXT;
          case INST(TUPLE_CONSTRUCT): {
            auto _ = instGuard();
            tupleConstruct(stack, inst.X);
          }
            INST_NEXT;
          case INST(TUPLE_SLICE): {
            auto _ = instGuard();
            tupleSlice(stack, inst.X, inst.X + inst.N);
          }
            INST_NEXT;
          case INST(NAMED_TUPLE_CONSTRUCT): {
            auto _ = instGuard();
            namedTupleConstruct(
                stack,
                frame.function->type_table_[inst.X]->expect<TupleType>(),
                inst.N);
          }
            INST_NEXT;
          case INST(LIST_CONSTRUCT): {
            auto _ = instGuard();
            const auto& type =
                frame.function->type_table_[inst.X]->expectRef<ListType>();
            listConstruct(stack, type, inst.N);
          }
            INST_NEXT;
          case INST(DICT_CONSTRUCT): {
            auto _ = instGuard();
            const auto& type =
                frame.function->type_table_[inst.X]->expectRef<DictType>();
            dictConstruct(stack, type, inst.N);
          }
            INST_NEXT;
          case INST(CREATE_OBJECT): {
            auto _ = instGuard();
            auto type =
                frame.function->type_table_[inst.X]->expect<ClassType>();
            createObject(stack, type);
          }
            INST_NEXT;
          case INST(ISINSTANCE): {
            auto _ = instGuard();
            // The candidate types are the N consecutive type-table entries
            // starting at index X.
            at::ArrayRef<TypePtr> types(
                &frame.function->type_table_[inst.X],
                &frame.function->type_table_[inst.X] + inst.N);
            isinstance(stack, types);
          }
            INST_NEXT;
          case INST(TUPLE_INDEX): {
            auto _ = instGuard();
            tupleIndex(stack);
          }
            INST_NEXT;
          case INST(RAISE_EXCEPTION): {
            auto _ = instGuard();
            raiseExceptionWithMessage(stack);
          }
            INST_NEXT;
          case INST(UNCHECKED_CAST): {
            auto _ = instGuard();
            noop(stack);
          }
            INST_NEXT;
          case INST(__IS__): {
            auto _ = instGuard();
            is(stack);
          }
            INST_NEXT;
          case INST(UN_INITIALIZED): {
            auto _ = instGuard();
            unInitialized(stack);
          }
            INST_NEXT;
          case INST(__ISNOT__): {
            auto _ = instGuard();
            isNot(stack);
          }
            INST_NEXT;
          case INST(FORMAT): {
            auto _ = instGuard();
            format(stack, inst.X);
          }
            INST_NEXT;
          case INST(DEVICE): {
            auto _ = instGuard();
            device(stack);
          }
            INST_NEXT;
          case INST(DTYPE): {
            auto _ = instGuard();
            TORCH_INTERNAL_ASSERT(!stack.empty());
            dtype(stack);
          }
            INST_NEXT;
          case INST(DIM): {
            auto _ = instGuard();
            TORCH_INTERNAL_ASSERT(!stack.empty());
            dim(stack);
          }
            INST_NEXT;
          case INST(__NOT__): {
            auto _ = instGuard();
            _not(stack);
          }
            INST_NEXT;
          case INST(DICT_INDEX): {
            auto _ = instGuard();
            dictIndex(stack);
          }
            INST_NEXT;
          case INST(TO_LIST): {
            auto _ = instGuard();
            toList(stack);
          }
            INST_NEXT;
          case INST(NUM_TO_TENSOR): {
            auto _ = instGuard();
            numToTensorScalar(stack);
          }
            INST_NEXT;
          case INST(IS_CUDA): {
            auto _ = instGuard();
            isCuda(stack);
          }
            INST_NEXT;
          case INST(FORK): {
            auto _ = instGuard();
            // Move inputs to a separate stack
            auto& forked_fn =
                toGraphFunction(*frame.function->function_table_[inst.X]);
            InterpreterState forked_interpreter(
                forked_fn.get_executor().getPlanFor(stack).code, taskLauncher_);
            InterpreterContinuation continuation(
                forked_interpreter,
                Stack(stack.end() - inst.N, stack.end()),
                getDistAutogradContextId());
            // The N inputs on this stack are replaced by the forked
            // interpreter's future before the continuation is launched.
            drop(stack, inst.N);
            push(stack, forked_interpreter.getFuture());
            taskLauncher_(std::move(continuation));
          }
            INST_NEXT;
          case INST(AWAITABLE): {
            auto _ = instGuard();
            auto fn_ptr = frame.function->function_table_[inst.X];
            auto& fn = toGraphFunction(*fn_ptr);
            auto num_outputs = fn.graph()->outputs().size();
            // The Await's element type is the single output type, or a tuple
            // of all output types.
            TypePtr out_type;
            if (num_outputs == 1) {
              out_type = fn.graph()->outputs()[0]->type();
            } else {
              std::vector<TypePtr> out_types;
              for (const auto& o : fn.graph()->outputs()) {
                out_types.push_back(o->type());
              }
              out_type = TupleType::create(out_types);
            }
            auto args = std::vector<IValue>(stack.end() - inst.N, stack.end());
            auto aw = c10::make_intrusive<c10::ivalue::Await>(out_type);
            aw->setArgs(std::move(args));
            // The stored function captures a reference to the arguments kept
            // inside the Await itself, and runs `fn` in a fresh interpreter
            // when invoked.
            aw->setFn(
                [&args = aw->args(),
                 fn_ptr,
                 taskLauncher = taskLauncher_]() -> IValue {
                  auto& fn = toGraphFunction(*fn_ptr);
                  auto n_out = fn.graph()->outputs().size();
                  torch::jit::Stack s;
                  for (const auto& arg : args) {
                    s.push_back(arg);
                  }
                  InterpreterState await_interpreter(
                      fn.get_executor().getPlanFor(s).code, taskLauncher);
                  await_interpreter.run(s);
                  if (n_out == 1) {
                    return s.back();
                  }
                  return c10::ivalue::Tuple::create(jit::last(s, n_out));
                });
            drop(stack, inst.N);
            push(stack, std::move(aw));
          }
            INST_NEXT;
          case INST(WARN): {
            auto _ = instGuard();
            // Keeps track of which WARN instruction has been executed before,
            // we only want to execute each WARN once to match default Python
            // warning behavior.
            bool need_warn = true;
            if (inst.X != -1) {
              need_warn = warned_nodes_.insert(inst.X);
            }

            Node* node =
                frames.back().function->instructions_source_.at(frame.pc);
            auto range = node->sourceRange().source();
            if (range->filename()) {
              // A source file name is available: drop the top stack value,
              // then read the message now on top and report it with an
              // explicit source location.
              drop(stack, 1);
              const auto& msg = stack.back().toStringRef();
              if (need_warn) {
                auto line = range->starting_line_no() +
                    range->lineno_for_offset(node->sourceRange().start());
                c10::SourceLocation location{
                    "", range->filename()->c_str(), uint32_t(line)};
                // Sends the warning to the warning handler with the
                // "verbatim" flag. This flag ensures the warning handler
                // will print the exception as configured.
                c10::warn(c10::Warning(
                    c10::UserWarning(), location, msg, /*verbatim=*/true));
              }
              stack.pop_back();
            } else {
              if (need_warn) {
                TORCH_WARN(stack.back().toStringRef());
              }
              stack.pop_back();
            }
          }
            INST_NEXT;
        }
      }
    } catch (std::exception& e) {
      // Best-effort cleanup: call __exit__ (with three None arguments) on
      // every object that is still entered, innermost first. Exceptions
      // thrown by these calls are swallowed.
      for (auto it = entered_objects.rbegin(), end = entered_objects.rend();
           it != end;
           ++it) {
        auto& f = it->toObject()->type()->getMethod("__exit__");
        Stack stack;
        push(stack, *it);
        push(stack, IValue());
        push(stack, IValue());
        push(stack, IValue());
        try {
          f.run(stack);
        } catch (std::exception& _) {
          // TODO(T98048876): Handle `_` correctly.
        }
      }
      // With the rethrow flag set, the original exception is propagated
      // untouched: through the future if there is one, otherwise by
      // rethrowing.
      if (FLAGS_torch_jit_enable_rethrow_caught_exception) {
        if (future_) {
          future_->setError(std::current_exception());
          return false;
        }
        throw;
      }
      auto* jit_exception = dynamic_cast<JITException*>(&e);
      // Janky af.  See https://github.com/pytorch/pytorch/issues/54612
      auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);

      std::optional<std::string> python_class_name;
      if (jit_exception) {
        python_class_name = jit_exception->getPythonClassName();
      }
      // handleError either throws a wrapped exception or, when a future
      // exists, records the error on it and returns.
      handleError(
          e, (bool)jit_exception, not_implemented_error, python_class_name);
      return false;
    }
  }

#undef INST_NEXT
#undef INST_DISPATCH
#undef INST
#undef JIT_USE_COMPUTED_GOTO

  // Entry point into the interpreter loop: selects the runTemplate
  // instantiation matching whether script profiling is currently ongoing,
  // forwards `stack` to it, and returns whatever that instantiation returns.
  bool runImpl(Stack& stack) {
    const bool profiling_active = profiling::isProfilingOngoing();
    return profiling_active
        ? runTemplate</*EnableProfiling*/ true>(stack)
        : runTemplate</*EnableProfiling*/ false>(stack);
  }

  void formatStackTrace(std::ostream& out) {
    format_stack_trace(out, callstack());
  }

  // Converts an exception that escaped the interpreter loop into the error
  // this interpreter reports.
  //
  // The report consists of a fixed banner, the interpreter call stack (via
  // formatStackTrace) and "<class name>: <message>", where the class name is
  // `python_class_name` when provided and "RuntimeError" otherwise.
  //
  // How the report is delivered:
  //   - future_ is set: the report is stored on the future as a
  //     Future::FutureError and this function returns normally.
  //   - is_jit_exception: throws a JITException carrying the report, the
  //     python class name and the original e.what().
  //   - not_implemented_error is non-null: throws c10::NotImplementedError
  //     with the report plus the original backtrace and caller.
  //   - otherwise: throws std::runtime_error; e.what() is appended to the
  //     report first when C++ stack traces are enabled.
  void handleError(
      const std::exception& e,
      bool is_jit_exception,
      c10::NotImplementedError* not_implemented_error,
      std::optional<std::string> python_class_name) {
    ExceptionMessage msg(e);
    const std::string class_name = python_class_name.value_or("RuntimeError");

    std::ostringstream report;
    report
        << "The following operation failed in the TorchScript interpreter.\n";
    formatStackTrace(report);
    report << class_name << ": " << msg << "\n";

    // With a future in play the error is recorded there instead of thrown.
    if (future_) {
      future_->setError(
          std::make_exception_ptr(Future::FutureError(report.str())));
      return;
    }

    if (is_jit_exception) {
      // save the original exception's message when creating a new JITException
      throw JITException(report.str(), python_class_name, e.what());
    }

    if (not_implemented_error) {
      throw c10::NotImplementedError(
          report.str(),
          not_implemented_error->backtrace(),
          not_implemented_error->caller());
    }

    if (get_cpp_stacktraces_enabled()) {
      report << e.what() << "\n";
    }
    throw std::runtime_error(report.str());
  }

  // Starts a TORCHSCRIPT_FUNCTION RecordFunction for `frame` if the frame
  // does not have one yet and step callbacks are available for that scope.
  // When the RecordFunction asks for inputs, the last `n_inputs` entries of
  // `stack` are passed along with the function name; otherwise only the name
  // is reported. The started RecordFunction is stored on the frame.
  static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
    // Nothing to do when this frame is already being recorded.
    if (frame.record_function) {
      return;
    }

    auto step_callbacks =
        at::getStepCallbacksUnlessEmpty(at::RecordScope::TORCHSCRIPT_FUNCTION);
    // The common case is expected to be "no callbacks".
    if (C10_LIKELY(!step_callbacks.has_value())) {
      return;
    }

    auto recorder =
        std::make_unique<at::RecordFunction>(std::move(*step_callbacks));
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(recorder->isActive());
    if (!recorder->needsInputs()) {
      recorder->before(frame.function->function_name_);
    } else {
      recorder->before(
          frame.function->function_name_,
          last(stack, frame.function->n_inputs));
    }
    frame.record_function = std::move(recorder);
  }

 public:
  // One way to avoid the overhead of forming strings would be to return
  // a vector of frame.function, i.e. CodeImpl*.
  // This is not exactly clean, as it would expose internal details of the
  // interpreter. But this way we would hold onto the graph/node and Function,
  // and we could create the module hierarchy string for each event in the
  // autograd profiler at the end, when consolidating events.
  // At the moment the overhead does not seem exorbitantly large.
  // Another option would be to return a vector of
  // (string, InlinedCallstackPtrs), where the string would contain the
  // function name and the typename of self.
  // Format of the returned vector of strings:
  // For each frame, the corresponding module name, type and function name
  // are in following format:
  // <module-instance-name>(module type)::<function-name>
  // Special keys for module-instance-name:
  //   - TOP: for top level module
  //   - SELF: When method/function of the frame is associated with
  //           previous frame's module instance
  //   - INSTANCE_NAME_UNKNOWN: instance name cannot be figured out
  //   - CALL_FUNCTION: call to free function
  // Builds one "<instance>(<type>)::<function>" string per interpreter frame,
  // plus one string per inlined call recorded on the node each frame is
  // currently executing, ordered from the outermost frame inwards.
  std::vector<std::string> moduleHierarchy() const {
    std::vector<std::string> hierarchy;
    // Instance-name prefix for the entry of the frame about to be visited.
    // The outermost frame is labelled TOP; every later frame receives the
    // prefix derived from the node the previous frame stopped on (which may
    // be empty).
    std::string instance_prefix("TOP");
    const size_t num_frames = frames.size();
    for (const auto frame_idx : c10::irange(num_frames)) {
      const Frame& frame = frames[frame_idx];
      const auto& code = frame.function;

      // Unqualified class name of the graph's first input when that input
      // is a class instance; stays empty otherwise.
      std::string self_type_name;
      const auto& graph = code->graph_;
      if (graph && !graph->inputs().empty()) {
        const auto& self_class =
            graph->inputs()[0]->type()->cast<c10::ClassType>();
        if (self_class) {
          const std::string qualified = self_class->name()->qualifiedName();
          self_type_name = qualified.substr(qualified.find_last_of('.') + 1);
        }
      }
      instance_prefix += "(";
      instance_prefix += self_type_name;
      instance_prefix += ")::";
      instance_prefix += code->function_name_;
      hierarchy.push_back(std::move(instance_prefix));

      // CALL nodes have already advanced the pc, so every frame except the
      // innermost one steps back by one to report the call node.
      const bool is_innermost = frame_idx + 1 >= num_frames;
      const size_t pc = is_innermost ? frame.pc : frame.pc - 1;
      Node* node = code->instructions_source_[pc];

      // Each inlined call recorded on the node contributes its own entry.
      if (const auto inlined_stack = node->callstack()) {
        for (const auto& inlined : (*inlined_stack)->vec()) {
          const auto& callee_name = std::get<0>(inlined)->name();
          const auto& module_info = std::get<2>(inlined);
          std::string entry;
          if (module_info.has_value()) {
            entry = utils::get_module_info(module_info.value());
            entry += "::";
          } else {
            // This is likely a call to free function, not associated with
            // any class
            entry = "::";
          }
          entry += callee_name;
          hierarchy.emplace_back(std::move(entry));
        }
      }

      // Work out the prefix for the next frame. If this node is a
      // CallMethod, the following frame contains the op being executed but
      // not the object instance name, so it is derived here from the call's
      // receiver. A CallFunction is labelled with the callee's name. Any
      // other node kind leaves the prefix empty.
      std::string next_prefix;
      const auto kind = node->kind();
      if (kind == prim::CallMethod) {
        auto* receiver = node->input(0);
        if (receiver->node()->kind() == prim::GetAttr) {
          next_prefix = receiver->node()->s(attr::name);
        } else if (
            !node->owningGraph()->inputs().empty() &&
            receiver == node->owningGraph()->inputs()[0]) {
          next_prefix = "SELF";
        } else {
          next_prefix = "INSTANCE_NAME_UNKNOWN";
        }
      } else if (kind == prim::CallFunction) {
        auto* callee_constant = node->input(0)->node();
        const auto callee_type =
            callee_constant->output()->type()->expect<FunctionType>();
        next_prefix = "CALL_FUNCTION::";
        next_prefix += callee_type->function()->name();
      }
      instance_prefix = std::move(next_prefix);
    }
    return hierarchy;
  }

  std::vector<StackEntry> callstack() const {
    std::vector<StackEntry> entries;
    for (const auto i : c10::irange(frames.size())) {
      const Frame& frame = frames[i];
      std::string previous_fn_name = frame.function->function_name_;
      size_t pc = frame.pc;
      // CALL nodes have already advanced the pc, so
      // undo that to report the call node
      if (i + 1 < frames.size()) {
        --pc;
      }

      Node* node = frame.function->instructions_source_[pc];
      if (node->callstack()) {
        for (const auto& p : (*node->callstack())->vec()) {
          entries.emplace_back(StackEntry{previous_fn_name, std::get<1>(p)});
          previous_fn_name = std::get<0>(p)->name();
        }
      }
      entries.emplace_back(StackEntry{previous_fn_name, node->sourceRange()});
    }
    return entries;
  }

  // Returns future_, creating it on first use. A newly created Future is
  // typed with the return type of the function in the first frame.
  c10::intrusive_ptr<Future> getOrCreateFuture() {
    if (future_) {
      return future_;
    }
    const auto& entry_code = frames.front().function;
    future_ = c10::make_intrusive<Future>(entry_code->return_type_);
    return future_;
  }

  // Runs the interpreter with a future attached and returns that future.
  // The future is created up front (if it does not exist yet) so that
  // future_ is already set while runImpl executes.
  c10::intrusive_ptr<Future> runAsync(Stack& stack) {
    getOrCreateFuture();
    // The boolean result of runImpl is deliberately not inspected here.
    static_cast<void>(runImpl(stack));
    return future_;
  }

  // Runs the interpreter on `stack`. When runImpl reports true, blocks on
  // future_ and pushes its value onto `stack`: a single output is pushed as
  // is, multiple outputs are unpacked from the resulting tuple.
  void run(Stack& stack) {
    // By the time the continuation completes the frame will be gone, so this
    // must be done before calling runImpl().
    TORCH_INTERNAL_ASSERT(!frames.empty());
    const auto num_outputs = frames.front().function->n_outputs;

    // NOTE(review): a true result presumably means execution was suspended
    // and will finish through future_ -- confirm against runTemplate.
    const bool must_wait = runImpl(stack);
    if (!must_wait) {
      return;
    }

    future_->wait();
    if (num_outputs == 1) {
      push(stack, future_->value());
      return;
    }

    auto outputs = future_->value().toTuple();
    for (const IValue& output : outputs->elements()) {
      push(stack, output);
    }
  }
};

// Returns the call stack of the interpreter referenced by
// tls_int_state_ptr_, in the reverse of the order produced by
// InterpreterStateImpl::callstack(). Returns an empty vector when no
// interpreter pointer is set.
std::vector<StackEntry> currentCallstack() {
  if (!tls_int_state_ptr_) {
    return {};
  }
  auto entries = tls_int_state_ptr_->callstack();
  std::reverse(entries.begin(), entries.end());
  return entries;
}

// Returns the module hierarchy of the interpreter referenced by
// tls_int_state_ptr_, or an empty vector when no interpreter pointer is set.
std::vector<std::string> currentModuleHierarchy() {
  if (!tls_int_state_ptr_) {
    return {};
  }
  return tls_int_state_ptr_->moduleHierarchy();
}

// Streams a Code object: first its graph followed by a newline, then the
// dump produced by the implementation's dump().
std::ostream& operator<<(std::ostream& out, const Code& code) {
  const auto& impl = code.pImpl;
  out << *impl->graph_;
  out << "\n";
  impl->dump(out);
  return out;
}

// Builds a Code backed by a newly allocated CodeImpl. `graph` and
// `remaining_bailout_depth` are passed through unchanged and `function_name`
// is moved into the implementation; pImpl is initialized from the new
// object.
Code::Code(
    const std::shared_ptr<Graph>& graph,
    std::string function_name,
    size_t remaining_bailout_depth)
    : pImpl(new CodeImpl(
          graph,
          std::move(function_name),
          remaining_bailout_depth)) {}

// Wraps an already constructed implementation object: pImpl is initialized
// directly from the raw `codeImpl` pointer.
Code::Code(CodeImpl* codeImpl) : pImpl(codeImpl) {}

// Builds a MobileCode by allocating an interpreter::MobileCodeImpl and
// handing it to the Code(CodeImpl*) constructor. All arguments are forwarded
// to MobileCodeImpl in declaration order; `function_name` is moved.
MobileCode::MobileCode(
    const std::shared_ptr<Graph>& graph,
    std::string function_name,
    bool emit_default_input_instructions,
    bool support_default_args_before_out,
    bool emit_promoted_ops,
    size_t remaining_bailout_depth)
    : Code(new interpreter::MobileCodeImpl(
          graph,
          std::move(function_name),
          emit_default_input_instructions,
          support_default_args_before_out,
          emit_promoted_ops,
          remaining_bailout_depth)) {}

// Forwards to the implementation's grad_executors() and returns the same
// reference.
const std::vector<GraphExecutor*>& Code::grad_executors() {
  const auto& executors = pImpl->grad_executors();
  return executors;
}

// Forwards to the implementation's diff_graph_op_executors() and returns the
// same reference.
const std::vector<GraphExecutor*>& Code::diff_graph_op_executors() {
  const auto& executors = pImpl->diff_graph_op_executors();
  return executors;
}

// The bailout count is reported as the number of entries in the
// implementation's type table.
size_t Code::num_bailouts() const {
  const auto& types = pImpl->type_table_;
  return types.size();
}

// Forwards the bailout request for `index` to the implementation.
void Code::request_bailout(size_t index) {
  auto& impl = *pImpl;
  impl.request_bailout(index);
}

// Returns the implementation's n_inputs.
size_t Code::num_inputs() const {
  const auto& impl = *pImpl;
  return impl.n_inputs;
}

// Returns the implementation's n_outputs.
size_t Code::num_outputs() const {
  const auto& impl = *pImpl;
  return impl.n_outputs;
}

// Forwards to the implementation's constant_table() and returns the same
// reference.
const std::vector<c10::IValue>& Code::constant_table() const {
  const auto& constants = pImpl->constant_table();
  return constants;
}

// Forwards to the implementation's instructions() and returns the same
// reference.
const std::vector<Instruction>& Code::instructions() const {
  const auto& insts = pImpl->instructions();
  return insts;
}

// Forwards to the implementation's op_to_num_specified_args() and returns
// the same reference.
const std::unordered_map<std::string, size_t>& Code::op_to_num_specified_args()
    const {
  const auto& specified_args = pImpl->op_to_num_specified_args();
  return specified_args;
}

// Forwards to the implementation's instructions_source() and returns the
// same reference.
const std::vector<Node*>& Code::instructions_source() const {
  const auto& sources = pImpl->instructions_source();
  return sources;
}

// Exposes the implementation's type_table_ by reference.
const std::vector<TypePtr>& Code::type_table() const {
  const auto& types = pImpl->type_table_;
  return types;
}

// Returns the implementation's register_size_.
size_t Code::register_size() const {
  const auto& impl = *pImpl;
  return impl.register_size_;
}

// Returns the graph held by the implementation's preprocess_ member.
std::shared_ptr<Graph> Code::graph() const {
  const auto& preprocessed = pImpl->preprocess_;
  return preprocessed.graph;
}

// Creates the interpreter state for `code` by constructing an
// InterpreterStateImpl from `code` and `taskLauncher` (moved) and storing it
// in pImpl.
InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher)
    : pImpl(c10::make_intrusive<InterpreterStateImpl>(
          code,
          std::move(taskLauncher))) {}

// Runs the underlying InterpreterStateImpl on `stack`. pImpl is downcast
// with static_cast before the call.
void InterpreterState::run(Stack& stack) {
  auto* impl = static_cast<InterpreterStateImpl*>(pImpl.get());
  impl->run(stack);
}

// Forwards to InterpreterStateImpl::runAsync on the downcast pImpl and
// returns the future it produces.
c10::intrusive_ptr<Future> InterpreterState::runAsync(Stack& stack) {
  auto* impl = static_cast<InterpreterStateImpl*>(pImpl.get());
  return impl->runAsync(stack);
}

// Returns the implementation's future via
// InterpreterStateImpl::getOrCreateFuture, which creates it if it does not
// exist yet.
c10::intrusive_ptr<Future> InterpreterState::getFuture() {
  auto* impl = static_cast<InterpreterStateImpl*>(pImpl.get());
  return impl->getOrCreateFuture();
}

// Adopts an existing implementation pointer: `pImpl_` is moved into pImpl.
InterpreterState::InterpreterState(
    c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
    : pImpl(std::move(pImpl_)) {}

// Resumes the interpreter by calling state.runAsync(stack). When a
// thread-local state snapshot was captured (tls_state_), it is installed
// through a ThreadLocalStateGuard for the duration of the call. In RPC
// builds the distributed autograd context id is switched to
// dist_autograd_context_id_ beforehand and set back to the previous id
// afterwards.
void InterpreterContinuation::operator()() {
#ifdef USE_RPC
  const auto saved_context_id = DistAutogradContainer::currentContextId();
  DistAutogradContainer::forceCurrentContextId(dist_autograd_context_id_);
#endif
  {
    // The nested scope makes the guard (if any) go away right after
    // runAsync returns, before the RPC context id is set back below.
    std::optional<at::ThreadLocalStateGuard> tls_guard;
    if (tls_state_.has_value()) {
      tls_guard.emplace(*tls_state_);
    }
    state.runAsync(stack);
  }
#ifdef USE_RPC
  DistAutogradContainer::forceCurrentContextId(saved_context_id);
#endif
}

} // namespace torch::jit
