#include <torch/csrc/jit/ir/ir.h>

#include <ATen/core/builtin_function.h>
#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/serialization/python_print.h>

#include <algorithm>
#include <iostream>
#include <locale>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace torch::jit {

namespace utils {
std::string getNodesModuleHierarchy(const Node& n) {
  if (!n.callstack().has_value()) {
    return std::string();
  }
  InlinedCallStackPtr callstack_ptr = n.callstack().value();
  std::string module_hierarchy;
  for (auto& entry : callstack_ptr->vec()) {
    const auto& opt_module_info = std::get<kModuleInstanceInfo>(entry);
    if (opt_module_info.has_value()) {
      const auto& module_instance_info = opt_module_info.value();
      if (!module_hierarchy.empty()) {
        module_hierarchy.append(".");
      }
      module_hierarchy.append(utils::get_module_info(module_instance_info));
    } else {
      module_hierarchy += ".UNKNOWN_INSTANCE(UNKNOWN_TYPE)";
    }
  }
  return module_hierarchy;
}
} // namespace utils

namespace {

// Constants relating to maintaining the topological index of nodes.
//
// Lower and upper bounds of the index. Inclusive range.
// Block's constructor pins its Param sentinel at kLowerBound and its Return
// sentinel at kUpperBound.
constexpr topo_position_t kLowerBound = INT64_MIN;
constexpr topo_position_t kUpperBound = INT64_MAX;
constexpr topo_position_t kMidPoint = 0;

// How far away to space nodes that are appended to the graph.
// should be 2^n, where:
//   - n is the maximum number of repeated insertions without a re-index
//   - 2^(64-n) is the maximum number of appends to the end without reindex
// Block::reIndexTopology also uses this as the stride between consecutive
// nodes.
constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */;

// Writes the textual reference to a value: '%' followed by its debug name.
void printValueRef(std::ostream& out, const Value* n) {
  out << "%";
  out << n->debugName();
}

// True when `str` contains no character other than the ASCII digits 0-9.
// This also holds for the empty string.
bool isNumber(c10::string_view str) {
  const auto first_non_digit = str.find_first_not_of("0123456789");
  return first_non_digit == std::string::npos;
}

// Returns `field` as a std::string, prefixed with '_' when isNumber(field)
// holds; every other name is returned unchanged.
std::string normalizeAttrName(c10::string_view field) {
  std::string normalized{field};
  if (isNumber(field)) {
    normalized.insert(0, "_");
  }
  return normalized;
}

// Appends to `ret` every node of `block` whose kind equals `kind`, in block
// order. When `recurse` is set, the sub-blocks of each node are searched right
// after that node has been examined.
void findAllNodes(
    Block& block,
    Symbol kind,
    bool recurse,
    std::vector<Node*>& ret) {
  for (Node* candidate : block.nodes()) {
    if (candidate->kind() == kind) {
      ret.push_back(candidate);
    }
    if (!recurse) {
      continue;
    }
    for (auto nested_block : candidate->blocks()) {
      findAllNodes(*nested_block, kind, /*recurse=*/true, ret);
    }
  }
}

} // namespace

// NB: This overload will become ambiguous with the one Caffe2 provides in its
// logging, if they ever intersect.
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& nodes) {
  out << at::ArrayRef<T>{nodes};
  return out;
}

// Streams the reference of every element of `nodes`, separated by ", ".
template <typename T>
static std::ostream& printValueRefs(
    std::ostream& out,
    const at::ArrayRef<T> nodes) {
  bool is_first = true;
  for (auto value : nodes) {
    if (!is_first) {
      out << ", ";
    }
    is_first = false;
    printValueRef(out, value);
  }
  return out;
}

// Can't make these two overloads directly a template, it'll be ambiguous with
// the global printer for operator<<.

// Prints a list of values as comma-separated references (e.g. "%a, %b").
static std::ostream& operator<<(
    std::ostream& out,
    const at::ArrayRef<const Value*> nodes) {
  printValueRefs(out, nodes);
  return out;
}

// Stream-manipulator style wrapper: pairs a list of values with the delimiter
// to print between them. The matching operator<< prints each value reference,
// followed by its type when the current type verbosity asks for it.
struct const_value_list_with_types {
  // Non-owning view of the values to print.
  const ArrayRef<const Value*> values;
  // Text emitted between consecutive values.
  std::string delim;
  // `delim_` defaults to ", ".
  const_value_list_with_types(
      ArrayRef<const Value*> values,
      std::string delim_ = ", ")
      : values(values), delim(std::move(delim_)) {}
};

// Prints each value of `l` as "%name", followed by " : <type>" when the
// current type verbosity is at least TypeVerbosity::Type. Entries are
// separated by `l.delim`.
static std::ostream& operator<<(
    std::ostream& out,
    const const_value_list_with_types& l) {
  bool needs_delim = false;
  for (auto value : l.values) {
    if (needs_delim) {
      out << l.delim;
    }
    needs_delim = true;
    printValueRef(out, value);
    if (c10::type_verbosity() < c10::TypeVerbosity::Type) {
      continue;
    }
    out << " : ";
    out << *value->type();
  }
  return out;
}

// Prints a tensor attribute in compact form:
//   - a single-element tensor is printed as "{value}",
//   - a tensor with at most max_tensor_display_size elements is printed via
//     its stream operator with newlines replaced by spaces,
//   - anything larger is printed as "<Tensor>".
static void printAttribute(std::ostream& out, const at::Tensor& tensor) {
  const auto element_count = tensor.numel();

  // 1-elem tensors are usually boxed scalars, so print them like it
  if (element_count == 1) {
    auto scalar_tensor = tensor.view(std::vector<int64_t>{}).item();
    out << "{";
    if (scalar_tensor.isFloatingPoint()) {
      out << scalar_tensor.toDouble();
    } else if (scalar_tensor.isComplex()) {
      out << scalar_tensor.toComplexDouble();
    } else {
      out << scalar_tensor.toLong();
    }
    out << "}";
    return;
  }

  if (element_count > max_tensor_display_size) {
    out << "<Tensor>";
    return;
  }

  // TODO: This is awful code.  Also it doesn't work on Windows.
  std::ostringstream rendered;
  rendered << tensor;
  std::string flattened{rendered.str()};
  // Remove newlines
  std::replace(flattened.begin(), flattened.end(), '\n', ' ');
  out << flattened;
}

// Prints an IValue attribute through IValue::repr, overriding the rendering of
// tensors, tensor lists and non-module objects with compact forms.
static void printAttribute(std::ostream& out, const IValue& ival) {
  // Returns true when it handled `input` itself; false lets repr() use its
  // regular formatting.
  const auto customFormatter = [](std::ostream& ss, const IValue& input) {
    if (input.isTensor()) {
      printAttribute(ss, input.toTensor());
      return true;
    }
    if (input.isTensorList()) {
      ss << "[<Tensors>]";
      return true;
    }
    if (input.isObject() && !input.type()->is_module()) {
      ss << "object(" << &input.toObjectRef() << ")";
      return true;
    }
    return false;
  };
  ival.repr(out, customFormatter);
}

// Prints `items` as a bracketed, comma-separated list of types.
static void printTypeList(
    std::ostream& out,
    const std::vector<TypePtr>& items) {
  out << "[";
  bool is_first = true;
  for (auto& type : items) {
    if (!is_first) {
      out << ", ";
    }
    is_first = false;
    out << *type;
  }
  out << "]";
}

// Prints the value of attribute `name` to `out`, dispatching on the
// attribute's kind. Scalar and list kinds go through printAttribute; tensor
// lists, graphs and graph lists are printed as fixed placeholders.
void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
  switch (kindOf(name)) {
    case AttributeKind::c:
      printAttribute(out, c(name));
      break;
    case AttributeKind::cs:
      // Printing this kind is not implemented; reaching here trips the assert.
      // TODO(@anjali411): fix this
      AT_ASSERT(false);
      break;
    case AttributeKind::f:
      printAttribute(out, f(name));
      break;
    case AttributeKind::fs:
      printAttribute(out, fs(name));
      break;
    case AttributeKind::i:
      printAttribute(out, i(name));
      break;
    case AttributeKind::is:
      printAttribute(out, is(name));
      break;
    case AttributeKind::s:
      printAttribute(out, s(name));
      break;
    case AttributeKind::ss:
      printAttribute(out, ss(name));
      break;
    case AttributeKind::t:
      printAttribute(out, t(name));
      break;
    case AttributeKind::ts:
      // Tensor lists are never expanded.
      out << "[<Tensors>]";
      break;
    case AttributeKind::ival:
      printAttribute(out, ival(name));
      break;
    case AttributeKind::g:
      // Graph attributes are only printed as a placeholder here.
      out << "<Graph>";
      break;
    case AttributeKind::gs:
      out << "[<Graphs>]";
      break;
    case AttributeKind::ty:
      out << *ty(name);
      break;
    case AttributeKind::tys:
      printTypeList(out, tys(name));
      break;
  }
}

// Prints all attributes of this node as "[name=value, ...]". When
// `ignore_subgraph` is set, the attr::Subgraph attribute is left out.
void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
    const {
  out << "[";
  bool printed_any = false;
  for (auto attr_name : attributeNames()) {
    const bool skip = ignore_subgraph && attr_name == attr::Subgraph;
    if (skip) {
      continue;
    }
    if (printed_any) {
      out << ", ";
    }
    printed_any = true;
    // TODO: debugging mode to see the qualifier.  We definitely
    // don't want to print the qualifier since it should always
    // be attribute, but you might be able to track down a weird
    // bug by printing it out.
    out << attr_name.toUnqualString() << "=";

    printAttrValue(out, attr_name);
  }
  out << "]";
}

// Returns the stored source range, or a default-constructed SourceRange when
// none has been set on this node.
SourceRange Node::sourceRange() const {
  if (!source_range_) {
    return SourceRange();
  }
  return *source_range_;
}

// Writes `level` indentation units (two spaces each) and returns `out` so the
// call can be chained with further << operations.
static std::ostream& indent(std::ostream& out, size_t level) {
  for (size_t remaining = level; remaining > 0; --remaining) {
    out << "  ";
  }
  return out;
}

// Prints this node in textual IR form.
//
//   out                     stream to write to (also the return value)
//   level                   indentation depth, in units of two spaces
//   groups                  when non-null, nodes carrying attr::Subgraph are
//                           printed as "<kind>_<index>" and appended to this
//                           list so the caller can print their subgraphs later
//   print_source_locations  append " # file:line:col" when a location is known
//   print_attributes        print the "[name=value, ...]" attribute list
//   print_scopes            append ", scope: <name>" for non-empty scope names
//   print_body              terminate the line and print nested blocks; when
//                           false, only the node's own line is written, without
//                           a trailing newline
//
// The print_* options are forwarded to nodes in nested blocks so that the
// whole subtree is rendered consistently with the caller's request.
std::ostream& Node::print(
    std::ostream& out,
    size_t level,
    std::vector<const Node*>* groups,
    bool print_source_locations,
    bool print_attributes,
    bool print_scopes,
    bool print_body) const {
  auto outs = outputs();
  indent(out, level) << const_value_list_with_types(outs);
  out << " = ";
  if (kind() == prim::PythonOp) {
    auto* pyOp = static_cast<const ::torch::jit::PythonOp*>(this);
    out << "^" << pyOp->name();
    printAttributes(out, /*ignore_subgraph=*/false);
    pyOp->writeScalars(out);
  } else if (hasAttribute(attr::Subgraph) && groups) {
    // The index is the position this node is about to take in `groups`.
    out << kind().toQualString() << "_" << groups->size();
    if (print_attributes && numAttributes() > 1 &&
        kind() != prim::DifferentiableGraph) {
      printAttributes(out, /*ignore_subgraph=*/true);
    }

    groups->push_back(this);
  } else {
    out << kind().toQualString();
    if (print_attributes && hasAttributes()) {
      printAttributes(out);
    }
  }
  out << "(" << inputs() << ")";

  if (print_scopes) {
    std::string scName = scopeName();
    if (!scName.empty()) {
      out << ", ";
      out << "scope: " << scName;
    }
  }

  // In debug print, append file:line:col as a comment after each node
  if (print_source_locations) {
    // sourceRange() returns by value, so take a single copy and reuse it.
    SourceRange r = sourceRange();
    // Copy the source handle: `r` may be reassigned below.
    if (auto source = r.source()) {
      if (auto orig = source->findSourceRangeThatGenerated(r)) {
        r = *orig;
      }
    }
    if (auto file_line_col = r.file_line_col()) {
      auto [filename, line, col] = *file_line_col;
      out << " # " << filename << ":" << line << ":" << col;
    }
  }

  if (!print_body) {
    return out;
  }

  out << "\n";

  for (const auto i : c10::irange(blocks().size())) {
    auto b = blocks()[i];
    indent(out, level + 1) << "block" << i << "("
                           << const_value_list_with_types(b->inputs())
                           << "):\n";
    for (auto nested : b->nodes()) {
      // Forward the caller's options instead of silently falling back to the
      // defaults for everything below the top level.
      nested->print(
          out,
          level + 2,
          groups,
          print_source_locations,
          print_attributes,
          print_scopes,
          print_body);
    }
    indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
  }

  return out;
}

// Streams a node at indentation level 0 without subgraph-group tracking; all
// other print options keep their declared defaults.
std::ostream& operator<<(std::ostream& out, const Node& n) {
  return n.print(out, /*level=*/0, /*groups=*/nullptr);
}

std::ostream& Graph::print(std::ostream& out, bool print_source_locations)
    const {
  out << "graph(" << const_value_list_with_types(inputs(), ",\n      ")
      << "):\n";
  std::vector<const Node*> groups;
  for (auto n : nodes()) {
    n->print(out, 1, &groups, print_source_locations);
  }
  out << "  return (" << outputs() << ")\n";
  size_t i = 0;
  for (auto fg : groups) {
    out << "with " << fg->kind().toQualString() << "_" << i++ << " = "
        << *fg->g(attr::Subgraph);
  }
  out.flush();

  /*
  // Uncomment this to debug all_nodes issues
  {
    out << "\n";
    out << "all_nodes:\n";
    for (auto& n : all_nodes) {
      printNode(out, const_cast<Node*>(n), nullptr);
    }
  }
  */
  return out;
}

// Streams a graph with source locations enabled.
std::ostream& operator<<(std::ostream& out, const Graph& g) {
  return g.print(out, /*print_source_locations=*/true);
}

// Asserts device consistency across the tensor-typed inputs and outputs of
// `node`. The first tensor type that reports a device fixes the expected
// device; every other tensor type visited is asserted to report the same
// value as the currently recorded one.
static void checkSameDevice(const Node* node) {
  bool has_device = false;
  std::optional<at::Device> device = std::nullopt;
  auto checkValue = [&](const Value* v) {
    TensorTypePtr type = v->type()->cast<TensorType>();
    if (!type) {
      // Non-tensor values do not take part in the check.
      return;
    }
    if (type->device() && !has_device) {
      has_device = true;
      device = *type->device();
    } else {
      AT_ASSERT(device == type->device());
    }
  };
  for (auto in : node->inputs()) {
    checkValue(in);
  }
  for (auto result : node->outputs()) {
    checkValue(result);
  }
}

// Ordered set of nodes; the lint code passes these to std::includes, which
// requires sorted ranges.
using node_set = std::set<const Node*>;
// Expands to the begin()/end() iterator pair of `container`, for passing a
// whole container to iterator-pair algorithms and constructors.
#define ALL_OF(container) container.begin(), container.end()

// These functions purposely operate on the internal members directly, to force
// you to think about how the invariants change if you change the data
// representation (even if the external API does not change.)

// NB: This assert is written to assume you don't have any unattached
// nodes.  Unattached nodes can occur while manipulations to the
// graph are occurring.
// Checks the per-node invariants listed below with AT_ASSERT; a violated
// invariant fails the assertion.
void Node::lint() const {
  // Node invariants
  // - if node should live in list, nodes_iter is consistent
  // - Inputs are all marked as a use by the nodes they refer to
  // - Owning graph is non-null and consistent
  // - The "Select" invariant, when the node is MultiReturn
  //
  // The handle invariant:
  //    If a node takes a handle as an input, it is always the
  //    LAST input of the node.  There is at most one handle input.

  {
    // Every input must list (this, i) among its uses, where i is the input's
    // position in inputs_.
    size_t i = 0;
    for (auto input : inputs_) {
      // WARNING: O(n^2)
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      AT_ASSERT(
          std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) !=
          input->uses_.end());
      // Note: this membership check sits inside the loop, so it only runs for
      // nodes that have at least one input.
      AT_ASSERT(graph_->all_nodes.count(this) == 1);
      i++;
    }
  }

  for (auto o : outputs()) {
    for (auto use : o->uses()) {
      // Use invariants
      // - Use is consistent with inputs
      // - Every user node is live (checked in Graph)
      AT_ASSERT(use.user->inputs_[use.offset] == o);
    }
  }

  // Node subclass invariants
  switch (kind()) {
    case prim::Constant:
      AT_ASSERT(inputs_.empty());
      break;
    case prim::Return:
      // Return uses is zero
      AT_ASSERT(outputs().empty());
      break;
    case prim::Param:
      // Param inputs is zero
      AT_ASSERT(inputs_.empty());
      break;
    case prim::PythonOp: {
      // Python operator cconv is correct
      auto* value = static_cast<const PythonOp*>(this);
      value->lint_python();
      break;
    }
    case prim::Eval:
      // TODO: add invariants
      // TODO: It's not good for these ops to be top-level, it makes cases
      // longer.
      break;
    case prim::FusionGroup:
    case prim::CudaFusionGroup:
    case prim::oneDNNFusionGroup:
      // Fusion groups: check device consistency, then lint the fused subgraph.
      checkSameDevice(this);
      // TODO: Typecheck the parameters
      g(attr::Subgraph)->lint();
      break;
  }
}

// TODO: When lint fails, give better indication about which
// instruction triggered the failure.
// Checks the structural invariants of the whole graph with AT_ASSERT: scoping
// of values and nodes, uniqueness of value ids, consistency between uses and
// inputs, topological ordering inside each block, and agreement between the
// nodes reachable through blocks and Graph::all_nodes.
void Graph::lint() const {
  // Graph invariants

  // Uncomment the following to see the graph
  // std::cout << *const_cast<Graph*>(this);

  // nodes
  // - nodes_ is a valid topological ordering for inputs
  // - No repeated nodes
  // - Params and return do NOT occur in nodes
  // - next_unique_ is greater than all uniques in graph
  // - uniques in all_nodes are unique
  // - every use will occur later in the toposort

  // A lexical scope of values and nodes seen so far. Lookups fall back to the
  // parent scope; inserting something that is already visible asserts.
  struct LintScope {
    LintScope() = default;
    LintScope(std::unique_ptr<LintScope> parent) : parent(std::move(parent)) {}
    bool contains(const Value* v) {
      return values.count(v) > 0 || (parent && parent->contains(v));
    }
    bool contains(const Node* n) {
      return nodes.count(n) > 0 || (parent && parent->contains(n));
    }
    void insert(const Value* v) {
      AT_ASSERT(!contains(v));
      values.insert(v);
    }
    void insert(const Node* n) {
      AT_ASSERT(!contains(n));
      nodes.insert(n);
    }
    // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
    std::unique_ptr<LintScope> parent;

   private:
    std::unordered_set<const Value*> values;
    std::unordered_set<const Node*> nodes;
  };
  // Struct enables mutual recursion in linting methods.
  // Putting it inside Graph::lint enables access to private Graph members
  struct LintImpl {
    LintImpl(const Graph& g)
        : g(g),
          scope(new LintScope()),
          all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered*
    const Graph& g;
    // Innermost scope currently being checked.
    std::unique_ptr<LintScope> scope;
    // Value ids seen so far; used to detect duplicates.
    std::unordered_set<size_t> seen_uniques;
    // For each user node: number of uses announced by already-checked values,
    // or -1 once the node itself has been checked.
    std::unordered_map<const Node*, int64_t> anticipated_uses;
    // Ordered copy of g.all_nodes, suitable for std::includes.
    node_set all_nodes_set;
    // Union of all param, return and body nodes found while walking blocks.
    node_set sum_set;

    void check_value(const Value* v) {
      scope->insert(v);
      auto b2 = seen_uniques.insert(v->unique());
      AT_ASSERT(b2.second); // insertion took place
      AT_ASSERT(v->unique() < g.next_unique_);

      for (auto use : v->uses()) {
        // Users must come after the definition and be known to the graph.
        AT_ASSERT(!scope->contains(use.user));
        AT_ASSERT(g.all_nodes.count(use.user) == 1);
        anticipated_uses[use.user]++; // int default constructs to 0
      }
    }
    void check_node(const Node* n) {
      for (auto input : n->inputs_) {
        if (!scope->contains(input)) {
          AT_ASSERTM(0, input->unique(), " not in scope");
        }
      }
      AT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
      anticipated_uses[n] = -1; // we saw the anticipated user!
      scope->insert(n);
      for (auto block : n->blocks()) {
        // Each nested block gets its own child scope, popped afterwards.
        scope = std::make_unique<LintScope>(std::move(scope));
        check_block(block);
        scope = std::move(scope->parent);
      }
      size_t i = 0;
      for (auto o : n->outputs()) {
        AT_ASSERT(o->node() == n);
        AT_ASSERT(i++ == o->offset_);
        check_value(o);
      }
      n->lint();
    }
    void check_block(const Block* b) {
      // Check topological ordering
      AT_ASSERT(b->param_node()->isBefore(*b->nodes().begin()));
      auto curNode = *b->nodes().begin();
      while (curNode != b->return_node()) {
        AT_ASSERT(curNode->isBefore(curNode->next()));
        curNode = curNode->next();
      }

      for (auto input : b->inputs()) {
        check_value(input);
        AT_ASSERT(input->node()->kind_ == prim::Param);
      }

      for (auto n : b->nodes()) {
        AT_ASSERT(n->kind_ != prim::Param);
        AT_ASSERT(n->kind_ != prim::Return);
        check_node(n);
      }

      AT_ASSERT(b->output_->kind() == prim::Return);
      check_node(b->output_);

      // all_nodes
      // - inputs_, output_ and nodes_ are all included in all_nodes
      // - all_nodes does not contain dead nodes??? (likely to be temporarily
      // suspended).  Weaker: all_nodes contains all inputs and returns
      // - only one return node???

      node_set nodes_set(ALL_OF(b->nodes()));
      node_set inputs_set{b->input_};
      node_set output_set{b->output_};
      // TODO: Make a more type safe std::includes wrapper which disallows use
      // on non-ordered containers
      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set)));
      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set)));
      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set)));

      sum_set.insert(ALL_OF(nodes_set));
      sum_set.insert(ALL_OF(inputs_set));
      sum_set.insert(ALL_OF(output_set));
    }
    void check_graph() {
      // The all_nodes_set member (built in the constructor from g.all_nodes)
      // is used directly; no second, shadowing copy is needed.
      check_block(g.block_);
      for (const auto& kv : anticipated_uses) {
        // Every node that was announced as a user must have been visited.
        AT_ASSERT(kv.second == -1);
      }
      // Everything registered in all_nodes must be reachable from the blocks.
      AT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set)));
    }
  };
  LintImpl(*this).check_graph();
}

// Debug helper: prints the graph to std::cout followed by a newline.
void Graph::dump() const {
  std::cout << *this;
  std::cout << "\n";
}

void Graph::push_scope(const std::string& scope_name) {
  current_scope_ = current_scope_->push(Symbol::scope(scope_name));
  Node* block_node = insertNode(create(prim::TracedModuleForward, 0));
  block_node->s_(attr::scope, scope_name);
  Block* b = block_node->addBlock();
  setInsertPoint(b);
}
// Leaves the current scope. If the insertion point lies directly inside the
// block of a prim::TracedModuleForward node, the insertion point is moved to
// just after that node; otherwise it is left untouched.
void Graph::pop_scope() {
  current_scope_ = current_scope_->parent();
  // A block may have no owning node (Block's constructor accepts any pointer
  // for it), so guard against null before dereferencing.
  auto* enclosing = insertPoint()->owningBlock()->owningNode();
  if (enclosing != nullptr &&
      enclosing->kind() == prim::TracedModuleForward) {
    setInsertPoint(enclosing->next());
  }
}

void LintGraph(const std::shared_ptr<Graph>& graph) {
  graph->lint();
}

// Creates an empty block owned by `graph_`, with `node_` recorded as the node
// that owns the block. A prim::Return node (output_) and a prim::Param node
// (input_) are created as the block's sentinels.
Block::Block(Graph* graph_, Node* node_)
    : graph_(graph_),
      output_(graph_->create(prim::Return, 0)),
      input_(graph_->create(prim::Param, 0)),
      owning_node_(node_) {
  // Link the two sentinels to each other in both directions, forming an
  // otherwise empty circular list.
  input_->next() = output_;
  input_->prev() = output_;
  output_->next() = input_;
  output_->prev() = input_;

  graph_->all_blocks.emplace(this);
  // Pin the sentinels at the extremes of the topological index range.
  output_->owning_block_ = this;
  output_->topo_position_ = kUpperBound;
  input_->owning_block_ = this;
  input_->topo_position_ = kLowerBound;
}

void Block::reIndexTopology() {
  auto curPos = kLowerBound;
  for (auto node : nodes()) {
    AT_ASSERT(curPos <= (kUpperBound - kAppendInterval));
    curPos += kAppendInterval;
    node->topo_position_ = curPos;
  }
}

// Copies the inputs, nodes and outputs of `src` into this block.
//
// Values defined inside `src` are remapped to their clones; any other value is
// resolved through `value_map`. Type and debug name of inputs and node outputs
// are copied via Value::copyMetadata.
void Block::cloneFrom(Block* src, std::function<Value*(Value*)> value_map) {
  // Maps values of `src` to the corresponding values in this block.
  std::unordered_map<Value*, Value*> cloned;
  auto env = [&](Value* v) {
    auto found = cloned.find(v);
    if (found == cloned.end()) {
      return value_map(v);
    }
    return found->second;
  };

  auto graph = owningGraph();

  for (auto src_input : src->inputs()) {
    cloned[src_input] = this->addInput()->copyMetadata(src_input);
  }

  for (auto src_node : src->nodes()) {
    auto copy = this->appendNode(graph->createClone(src_node, env));
    for (size_t out_idx = 0; out_idx < src_node->outputs().size(); ++out_idx) {
      auto src_output = src_node->outputs()[out_idx];
      auto copy_output = copy->outputs()[out_idx];
      cloned[src_output] = copy_output;
      copy_output->copyMetadata(src_output);
    }
  }

  for (auto src_output : src->outputs()) {
    this->registerOutput(env(src_output));
  }
}

// Tears down this block: drops the return node's inputs, destroys all body
// nodes in reverse order, destroys the two sentinel nodes, and finally hands
// the block back to the owning graph via freeBlock.
void Block::destroy() {
  // we cannot destroy the output because it is used as the sentinel
  // for the nodes() list and has to remain valid for the loop
  output_->removeAllInputs();
  // Walk the node list back to front, destroying the node under the iterator.
  for (auto it = this->nodes().reverse().begin(),
            end = this->nodes().reverse().end();
       it != end;
       ++it) {
    it.destroyCurrent();
  }
  // Only now are the sentinels themselves destroyed.
  output_->destroy();
  input_->destroy();
  graph_->freeBlock(this);
}

// Clones the top-level block of `src` into this graph's top-level block. Any
// value that is not defined inside `src`'s block is reported as an error.
void Graph::cloneFrom(Graph& src) {
  auto fail_on_free_value = [](Value* v) -> Value* {
    AT_ERROR(
        "Graph::copy() encountered a use of a value " + v->debugName() +
        " not in scope. Run lint!");
  };
  Block* destination = block();
  destination->cloneFrom(src.block(), fail_on_free_value);
}

// Returns a newly allocated graph that is a clone of this one, held by
// shared_ptr.
std::shared_ptr<Graph> Graph::copy() {
  auto duplicate = std::make_shared<Graph>();
  duplicate->cloneFrom(*this);
  return duplicate;
}

// Returns a newly allocated graph that is a clone of this one, held by
// unique_ptr.
std::unique_ptr<Graph> Graph::copyUnique() {
  auto duplicate = std::make_unique<Graph>();
  duplicate->cloneFrom(*this);
  return duplicate;
}

void Block::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
  for (Value* input : inputs()) {
    input->setType(type_map(input->type()));
  }
  for (Node* node : nodes()) {
    for (Value* output : node->outputs()) {
      output->setType(type_map(output->type()));
    }
    for (Block* sub_block : node->blocks()) {
      sub_block->remapTypes(type_map);
    }
    for (Symbol name : node->attributeNames()) {
      if (node->kindOf(name) == AttributeKind::g) {
        node->g(name)->remapTypes(type_map);
      } else if (node->kindOf(name) == AttributeKind::gs) {
        for (const auto& g : node->gs(name)) {
          g->remapTypes(type_map);
        }
      }
    }
  }
}

void Graph::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
  block()->remapTypes(type_map);
}

// Sets this value's type to the TensorType created from `output`.
void Value::inferTypeFrom(const at::Tensor& output) {
  auto inferred = TensorType::create(output);
  setType(inferred);
}

// Sets this value's type to the type reported by the object `output`.
void Value::inferTypeFrom(
    const c10::intrusive_ptr<c10::ivalue::Object>& output) {
  auto object_type = output->type();
  setType(object_type);
}

// True when this value is statically known to be None: either its type is
// NoneType, or its producing node reports mustBeNone().
bool Value::mustBeNone() const {
  if (type()->cast<NoneType>()) {
    return true;
  }
  return node_->mustBeNone();
}
// True when this value is statically known not to be None. That requires all
// of: the producer is not prim::AutogradAdd, the type is not NoneType, the
// type is not an OptionalType, and the type is not a UnionType that can hold
// None.
bool Value::mustNotBeNone() const {
  if (node_->kind() == prim::AutogradAdd) {
    return false;
  }
  if (type() == NoneType::get()) {
    return false;
  }
  if (type()->cast<OptionalType>()) {
    return false;
  }
  const bool union_with_none = type()->cast<UnionType>() &&
      type()->expect<UnionType>()->canHoldType(*NoneType::get());
  return !union_with_none;
}

// Returns the debug name with a trailing ".<digits>" suffix removed. Names
// without such a suffix (no dot, a trailing dot, or non-digit characters after
// the last dot) are returned unchanged.
std::string Value::debugNameBase() const {
  std::string full_name = debugName();
  const auto dot = full_name.find_last_of('.');
  const bool has_tail =
      dot != std::string::npos && dot + 1 != full_name.size();
  if (has_tail &&
      full_name.find_first_not_of("0123456789", dot + 1) ==
          std::string::npos) {
    return full_name.substr(0, dot);
  }
  return full_name;
}

// A name is valid when it is empty or when it is not made up solely of
// digits.
bool Value::isValidName(const std::string& name) {
  // Empty strings are legal; purely numeric names are not.
  return name.empty() || !isNumber(name);
}

// Gives this value the debug name `name` and returns `this`.
//
// Throws std::runtime_error when `name` fails isValidName. An empty `name`
// just clears the current name. If another value in the same graph already
// owns `name`, that other value is renamed to "<base>.<suffix>" with the next
// unused suffix, and this value takes `name`.
Value* Value::setDebugName(const std::string& name) {
  if (!isValidName(name)) {
    throw std::runtime_error("Invalid name: '" + name + "'");
  }

  // Graph-wide map from debug name to the value that owns it.
  auto& names = node()->owningGraph()->unique_names_;

  // clear any old name from the map
  if (hasDebugName()) {
    names.erase(unique_name_);
    unique_name_ = "";
  }

  // allow "" to clear the uniquename
  if (name.empty()) {
    return this;
  }

  // if someone else has this name, then rename the other value
  auto old_owner_of_name = names.find(name);
  if (old_owner_of_name != names.end()) {
    size_t suffix = 1;
    std::string name_base = name;
    // Split "base.123" into base and numeric suffix; names without a purely
    // numeric tail after the last '.' keep the whole name as base.
    auto last_dot_pos = name.find_last_of('.');
    if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
      if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
          std::string::npos) {
        suffix = std::stoll(name.substr(last_dot_pos + 1));
        name_base = name.substr(0, last_dot_pos);
      }
    }

    // Per-base bookkeeping of suffixes, so the search below does not have to
    // start from scratch every time.
    auto& names_suffixes = node()->owningGraph()->name_base_suffix_;
    auto it = names_suffixes.find(name_base);
    if (it != names_suffixes.end()) {
      suffix = std::max(suffix, it->second + 1);
    }

    // Verify that new name is not used and find next usable name in case
    // suffix is used.
    std::string replacement_name;
    do {
      std::stringstream ss;
#ifndef _WIN32
      // Protect 12345 integer from becoming "1,2345" if some other process sets
      // global locale For more details see
      // https://github.com/pytorch/pytorch/issues/79583#issuecomment-1161260061
      static std::locale c_locale("C");
      ss.imbue(c_locale);
#endif
      ss << name_base << "." << suffix++;
      replacement_name = ss.str();
    } while (names.count(replacement_name) > 0);

    // Note: `suffix` was post-incremented, so the stored value is one past the
    // suffix that was just handed out.
    names_suffixes[name_base] = suffix;

    // Recursive call: moves the previous owner over to the replacement name.
    old_owner_of_name->second->setDebugName(replacement_name);
  }

  names[name] = this;
  unique_name_ = name;
  return this;
}

// Copies the type of `from` onto this value, and its debug name as well when
// `from` has one. Returns `this`.
Value* Value::copyMetadata(Value* from) {
  setType(from->type());
  if (!from->hasDebugName()) {
    return this;
  }
  setDebugName(from->debugName());
  return this;
}

// Redirects the first recorded use of this value to `newValue`: the user's
// input slot is rewritten, the use is appended to `newValue`'s use list, and
// it is removed from this value's use list.
//
// Preconditions (asserted): both values belong to the same graph, and this
// value has at least one use.
void Value::replaceFirstUseWith(Value* newValue) {
  AT_ASSERT(owningGraph() == newValue->owningGraph());
  // Indexing uses()[0] and erasing begin() on an empty list would be
  // undefined behavior, so fail loudly instead.
  AT_ASSERT(!uses().empty());
  auto u = uses()[0];
  u.user->inputs_[u.offset] = newValue;
  newValue->uses_.push_back(u);
  uses_.erase(uses_.begin());
}

// Redirects every use of this value to `newValue`. Replacing a value with
// itself is a no-op.
void Value::replaceAllUsesWith(Value* newValue) {
  // Moving a use from a value onto itself re-appends it to the same use list,
  // so the list would never become empty and the loop below would not
  // terminate.
  if (newValue == this) {
    return;
  }
  while (!uses().empty()) {
    replaceFirstUseWith(newValue);
  }
}

void Value::replaceAllUsesAfterNodeWith(const Node* node, Value* newValue) {
  std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) {
    if (u.user->isAfter(node)) {
      u.user->inputs_[u.offset] = newValue;
      newValue->uses_.push_back(u);
    }
  });

  uses_.erase(
      std::remove_if(
          uses_.begin(),
          uses_.end(),
          [&node](const Use& u) { return u.user->isAfter(node); }),
      uses_.end());
}

void Value::replaceAllUsesDominatedByNodeWith(
    const Node* node,
    Value* newValue) {
  std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) {
    if (u.user->isDominatedBy(node)) {
      u.user->inputs_[u.offset] = newValue;
      newValue->uses_.push_back(u);
    }
  });

  uses_.erase(
      std::remove_if(
          uses_.begin(),
          uses_.end(),
          [&node](const Use& u) { return u.user->isDominatedBy(node); }),
      uses_.end());
}

// Returns the position of the schema argument named `unqualName`.
// Throws std::runtime_error when the schema has no such argument.
static size_t findArgument(
    const FunctionSchema& the_schema,
    const std::string& unqualName) {
  const auto& arguments = the_schema.arguments();
  for (size_t position = 0; position < arguments.size(); ++position) {
    if (arguments[position].name() == unqualName) {
      return position;
    }
  }
  throw std::runtime_error(
      std::string("Couldn't find an argument called ") + unqualName);
}

static size_t findArgument(const FunctionSchema& the_schema, Symbol name) {
  const auto unqualName = name.toUnqualString();
  return findArgument(the_schema, unqualName);
}

// Returns the result of toIValue on the input bound to schema argument
// `name`.
std::optional<IValue> Node::get(Symbol name) const {
  Value* bound_input = namedInput(name);
  return toIValue(bound_input);
}

bool Node::hasNamedInput(const std::string& name) const {
  for (const auto& argument : schema().arguments()) {
    if (argument.name() == name) {
      return true;
    }
  }
  return false;
}

// Returns the input at the position of the schema argument named `unqualName`.
// Argument lookup failures propagate from findArgument.
Value* Node::namedInput(const std::string& unqualName) const {
  const size_t position = findArgument(schema(), unqualName);
  return input(position);
}
// Symbol overload: returns the input at the position of the schema argument
// identified by `name`.
Value* Node::namedInput(Symbol name) const {
  const size_t position = findArgument(schema(), name);
  return input(position);
}

bool Node::matches(const FunctionSchema& schema) const {
  if (isBlockListedSchema(schema)) {
    return false;
  }
  // wrong name
  if (kind().toQualString() != schema.name()) {
    return false;
  }
  at::ArrayRef<const Value*> actuals = inputs();
  const auto& formals = schema.arguments();

  // not enough inputs
  if (actuals.size() < formals.size()) {
    return false;
  }

  TypeEnv type_env;
  for (const auto i : c10::irange(formals.size())) {
    auto formal = formals[i].type();
    const MatchTypeReturn matched_type =
        matchTypeVariables(formal, actuals[i]->type(), type_env);
    if (!matched_type.success()) {
      return false;
    }

    TypePtr resolved = tryEvalTypeVariables(formal, type_env);
    if (resolved) {
      formal = resolved;
    }
    // note: it is possible at this point that type variable matching has
    // not resolved all type variables, e.g. if None was matched to Optional[T]
    // we will not succeed at matching T. However None <: Optional[T] so this
    // check can still succeed.

    if (!actuals[i]->type()->isSubtypeOf(*formal)) {
      return false;
    }
  }

  // too many inputs
  if (!schema.is_vararg() && actuals.size() != formals.size()) {
    return false;
  }

  return true;
}

bool Node::matches(
    const char* signature_literal,
    at::ArrayRef<Symbol> const_inputs) const {
  if (!matches(getOperatorForLiteral(signature_literal)->schema())) {
    return false;
  }
  for (Symbol s : const_inputs) {
    if (!is_constant(s)) {
      return false;
    }
  }
  return true;
}

// True when this node can be statically determined to produce None.
bool Node::mustBeNone() const {
  // It's an AutogradZero node, or ...
  if (kind_ == prim::AutogradZero) {
    return true;
  }
  // It has only one output and that output is NoneType, or ...
  if (outputs().size() == 1 && output()->type() == NoneType::get()) {
    return true;
  }
  // It's a constant optional with no value in the attributes.
  return kind_ == prim::Constant && !this->hasAttributes() &&
      output()->type()->cast<OptionalType>();
}

// Debug helper: prints the node to std::cout followed by a newline.
void Node::dump() const {
  std::cout << *this;
  std::cout << "\n";
}

// Returns the schema of this node's operator, using the cached operator when
// there is one and resolving it through getOperator() otherwise.
const FunctionSchema& Node::schema() const {
  if (!op_) {
    return getOperator().schema();
  }
  return op_->schema();
}

// Like schema(), but returns nullptr instead of failing when no operator
// matches this node.
const FunctionSchema* Node::maybeSchema() const {
  const Operator* resolved = maybeOperator();
  if (!resolved) {
    return nullptr;
  }
  return &resolved->schema();
}

// Returns the operator whose schema matches this node, or nullptr when none
// of the candidates registered for this kind matches. A successful lookup is
// cached in op_.
const Operator* Node::maybeOperator() const {
  if (op_) {
    return op_;
  }
  for (const auto& candidate : getAllOperatorsFor(kind())) {
    if (!matches(candidate->schema())) {
      continue;
    }
    // First match wins.
    op_ = candidate.get();
    break;
  }
  return op_;
}

// Returns the operator matching this node. When no operator matches, throws
// an ErrorReport that lists the node, its input types, the candidate schemas
// for this kind, and the owning graph.
const Operator& Node::getOperator() const {
  if (const Operator* resolved = maybeOperator()) {
    return *resolved;
  }

  auto er = ErrorReport(sourceRange());
  er << "Schema not found for node. File a bug report.\n";
  er << "Node: " << *this << "\n";
  er << "Input types:";
  bool first_input = true;
  for (const Value* in : inputs()) {
    if (!first_input) {
      er << ", ";
    }
    first_input = false;
    er << *in->type();
  }

  const auto& candidates = getAllOperatorsFor(kind());
  if (candidates.empty()) {
    er << "\nno candidates found\n";
  } else {
    er << "\ncandidates were:\n";
    for (auto& candidate : candidates) {
      er << "  " << candidate->schema() << "\n";
    }
  }
  er << "within the graph:\n";
  er << *owningGraph() << "\n";
  throw er;
}

// Returns the runnable operation for this node.
Operation Node::getOperation() const {
  // note: some operators require the node to produce a runnable operation,
  // which is why 'this' is passed here. getOperator() ensures that 'this'
  // matches the schema of the returned operator.
  const Operator& resolved = getOperator();
  return resolved.getOperation(this);
}

bool Node::isNondeterministic() const {
  const auto schema = maybeSchema();
  if (!kind().is_aten()) {
    return false;
  }
  // All aten ops are expecte to have a schema. However this is left as a
  // warning instead of an assert to ensure that previous use cases do not
  // break.
  if (!schema) {
    TORCH_WARN("aten Schema not found.");
    return false;
  }
  torch::utils::SchemaInfo schema_info(*schema);
  if (hasNamedInput("train")) {
    auto value = constant_as<bool>(namedInput("train"));
    if (value.has_value()) {
      schema_info.addArgumentValue("train", *value);
    }
  }
  return schema_info.is_nondeterministic();
}

// Reports whether this node must be treated as having side effects.
//
// The decision is made in three steps:
//   1. A fixed list of node kinds is always reported as side-effecting.
//   2. If no operator is registered for the node, the kind must be a prim op
//      (asserted) and the node is reported as side-effect free.
//   3. Otherwise the operator's AliasAnalysisKind decides: CONSERVATIVE means
//      side effects, every other handled kind means none.
bool Node::hasSideEffects() const {
  // Step 1: kinds that are unconditionally side-effecting.
  switch (kind_) {
    case prim::PythonOp:
    case prim::IgnoredPythonOp:
    case prim::Print:
    case prim::RaiseException:
    case aten::warn:
    case aten::save:
    case aten::manual_seed:
    case prim::AddStatValue:
    case prim::TimePoint:
    case prim::CallFunction:
    case prim::CallMethod:
    case prim::BailoutTemplate:
    case prim::BailOut:
    case prim::rpc_async: // It represents RPC message sent.
    case prim::rpc_sync: // It represents RPC message sent.
    case prim::rpc_remote: // It represents RPC message sent.
    case aten::wait: // It can represent RPC message received.
#if !defined(USE_ROCM)
    case cuda::set_stream:
    case cuda::_set_device:
    case cuda::_current_device:
    case cuda::synchronize:
#endif
    case prim::Enter:
    case prim::Exit:
      return true;
  }

  // Step 2: nodes without a registered operator.
  auto op = maybeOperator();
  if (!op) {
    TORCH_INTERNAL_ASSERT(
        kind_.is_prim(),
        "Only prim ops are allowed to not have a registered operator but ",
        kind_.toDisplayString(),
        " doesn't have one either. We don't know if this op has side effects.");
    return false;
  }

  // Sanity check on the alias analysis kind of prim::, aten:: and cuda:: ops.
  if (kind_.is_prim() || kind_.is_aten() || kind_.is_cuda()) {
    // TODO There is nothing in the system that relies on aten:: and prim::
    // ops using AliasAnalysisKind::FROM_SCHEMA,
    // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, or
    // AliasAnalysisKind::CONSERVATIVE but this is the intended behavior for all
    // current ops and a good error check. We can consider lifting this
    // constraint later if we have a use case for it.
    TORCH_INTERNAL_ASSERT(
        op->aliasAnalysisKind() == AliasAnalysisKind::INTERNAL_SPECIAL_CASE ||
            op->aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA ||
            op->aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE,
        "aten:: and prim:: ops should have AliasAnalysisKind::INTERNAL_SPECIAL_CASE"
        ", AliasAnalysisKind::FROM_SCHEMA or AliasAnalysisKind::CONSERVATIVE but ",
        kind_.toDisplayString(),
        " has ",
        toString(op->aliasAnalysisKind()));
  }

  // Step 3: only CONSERVATIVE operators are reported as side-effecting.
  switch (op->aliasAnalysisKind()) {
    case AliasAnalysisKind::PURE_FUNCTION:
    case AliasAnalysisKind::FROM_SCHEMA:
    case AliasAnalysisKind::INTERNAL_SPECIAL_CASE:
      return false;
    case AliasAnalysisKind::CONSERVATIVE:
      return true;
  }
  TORCH_INTERNAL_ASSERT(false, "Unhandled AliasAnalysisKind case");
  return false; // silence compiler warning
}

// Assign this node a topological position, to facilitate fast isBefore() and
// isAfter() queries. Must be called right after a node is inserted into the
// node list.
//
// The basic scheme is: assign every node a position (topo_position_t, a
// signed value in the inclusive range [kLowerBound, kUpperBound]).  The common
// case (appending to the end of the graph) is made more efficient by advancing
// a fixed interval past the previous node and placing `this` there. Otherwise,
// assign `this` a position at the midpoint between its prev() and next()
// nodes.
//
// If we ever run out of space (by, e.g. inserting too much in place), we
// reindex by spreading out all the nodes again.
void Node::assignTopoPosition() {
  // The block's param node and return node bracket the node list, so a
  // neighbor being one of them means `this` sits at that end of the list.
  bool is_first = prev() == owningBlock()->param_node();
  bool is_last = next() == owningBlock()->return_node();

  const auto prevPos = prev()->topo_position_;
  const auto nextPos = next()->topo_position_;

  // Append to the end of the graph
  if (is_last) {
    if (is_first) {
      // the node list is empty, assign the first position
      topo_position_ = kMidPoint;
      return;
    }

    if (prevPos >= (kUpperBound - kAppendInterval)) {
      // we're running off the edge
      owningBlock()->reIndexTopology();
      return;
    }

    topo_position_ = prevPos + kAppendInterval;

    // Prepend to the graph
  } else if (is_first) {
    // next() is the first element in the block list
    if (nextPos <= (kLowerBound + kAppendInterval)) {
      // we're running off the edge
      owningBlock()->reIndexTopology();
      return;
    }
    topo_position_ = nextPos - kAppendInterval;

    // insert between two existing nodes
  } else {
    int64_t remaining = nextPos - prevPos;
    AT_ASSERT(remaining > 0);
    if (remaining == 1) {
      // There was no room
      owningBlock()->reIndexTopology();
      return;
    }
    // When inserting right before the graph's insert point, the graph's
    // running insertion counter is read (and bumped) so that the gap is split
    // unevenly: a larger divisor places `this` closer to prev().
    int64_t predicted_future_insertions = 0;
    if (next() == graph_->insertPoint()) {
      predicted_future_insertions = graph_->predicted_insert_count_++;
    }
    // Always advance by at least 1 so the new position is strictly after prev.
    topo_position_ = prevPos +
        std::max(int64_t(1), remaining / (2 + predicted_future_insertions));
    AT_ASSERT(prevPos < topo_position_ && topo_position_ < nextPos);
  }
}

// Constructs a node of the given kind that belongs to `graph_` but is not yet
// placed in any block (owning_block_ is null). The node takes the graph's
// current scope, starts without a callstack or cached operator, and registers
// itself in the graph's all_nodes set.
//
// Note: the parameters deliberately share names with the members they
// initialize; inside the body, `graph_` refers to the parameter.
Node::Node(Graph* graph_, NodeKind kind_)
    : kind_(kind_),
      graph_(graph_),
      owning_block_(nullptr),
      scope_(graph_->current_scope_),
      callstack_(std::nullopt),
      op_(nullptr) {
  graph_->all_nodes.emplace(this);
}

void Node::eraseOutput(size_t i) {
  AT_ASSERT(i < outputs_.size());
  AT_ASSERT(outputs_[i]->uses().empty());
  op_ = nullptr;
  Value* n = outputs_[i];
  outputs_.erase(outputs_.begin() + i);
  owningGraph()->freeValue(n);
  for (const auto j : c10::irange(i, outputs_.size())) {
    outputs_[j]->offset_--;
  }
}

// Appends a new, empty block owned by this node and returns it. Invalidates
// the cached operator.
Block* Node::addBlock() {
  op_ = nullptr;
  Block* fresh = new Block(owningGraph(), this);
  blocks_.push_back(fresh);
  return fresh;
}

void Node::eraseBlock(size_t i) {
  AT_ASSERT(i < blocks_.size());
  op_ = nullptr;
  Block* n = blocks_[i];
  blocks_.erase(blocks_.begin() + i);
  n->destroy();
}

// Tears the node down completely: erases outputs and blocks from the back,
// drops all inputs, unlinks the node from its block list if it is in one, and
// finally hands the node back to the graph to be freed.
void Node::destroy() {
  for (size_t left = outputs().size(); left > 0; left = outputs().size()) {
    eraseOutput(left - 1);
  }
  for (size_t left = blocks().size(); left > 0; left = blocks().size()) {
    eraseBlock(left - 1);
  }
  removeAllInputs();
  if (inBlockList()) {
    removeFromList();
  }
  graph_->freeNode(this);
}

// Copies node-level metadata from `s`: source range, attributes and callstack.
// The scope is only taken over when `s` has a non-blank scope; otherwise this
// node keeps the scope it already has.
void Node::cloneFrom(Node* s) {
  source_range_ = s->source_range_;
  const bool source_has_scope = s->scope_ && !s->scope_->isBlank();
  if (source_has_scope) {
    scope_ = s->scope_;
  }
  copyAttributes(*s);
  callstack_ = s->callstack_;
}

// Redirects every use of each output of this node to the output of `n` at the
// same position. Both nodes must have the same number of outputs.
void Node::replaceAllUsesWith(Node* n) {
  const size_t count = outputs().size();
  AT_ASSERT(count == n->outputs().size());
  for (size_t idx = 0; idx < count; ++idx) {
    Value* replacement = n->outputs()[idx];
    outputs()[idx]->replaceAllUsesWith(replacement);
  }
}

// Inserts a node of kind `new_symbol` (with the insert point set to this node)
// that takes the same inputs, and redirects all uses of this node's outputs to
// freshly created outputs of the new node. Node metadata and attributes are
// copied over as well. This node itself is not removed.
//
// Asserts that the new node resolves to an operator exactly when this node
// did, and returns the new node.
Node* Node::replaceWithNewSymbol(Symbol new_symbol) {
  WithInsertPoint insert_guard{this};
  const bool resolved_before = (maybeOperator() != nullptr);
  auto graph = owningGraph();
  auto replace_node = graph->insertNode(graph->create(new_symbol, 0));
  for (Value* in : inputs()) {
    replace_node->addInput(in);
  }
  for (Value* old_out : outputs()) {
    Value* fresh_out = replace_node->addOutput();
    auto new_out = fresh_out->copyMetadata(old_out);
    old_out->replaceAllUsesWith(new_out);
  }
  replace_node->copyMetadata(this);
  replace_node->copyAttributes(*this);
  TORCH_INTERNAL_ASSERT(
      (replace_node->maybeOperator() != nullptr) == resolved_before,
      "invalid symbol replacement:",
      new_symbol,
      kind());
  return replace_node;
}

bool Node::isDominatedBy(const Node* dominator) const {
  const Node* node = this;
  while (node) {
    if (node->owningBlock() == dominator->owningBlock()) {
      return dominator->isBefore(node);
    }
    node = node->owningBlock()->owningNode();
  }
  return false;
}

// Inserts `value` as input number `i`, shifting the inputs currently at
// [i, #inputs) one slot to the right. `value` must belong to the same graph as
// this node and `i` must be in [0, #inputs]. Invalidates the cached operator
// and returns `value`.
Value* Node::insertInput(size_t i, Value* value) {
  AT_ASSERT(graph_ == value->owningGraph());
  // Inserting past the end would form an out-of-range iterator below; fail
  // loudly instead, like the other index-taking mutators do.
  AT_ASSERT(i <= inputs_.size());
  op_ = nullptr;
  // First we update the offsets for all existing inputs that will reside
  // after the one we're inserting. Concretely, these are the inputs at
  // indices [i, # input). Since we're inserting one input before all of
  // these inputs, increment their use offsets for this value by 1
  for (const auto use_itr : c10::irange(i, inputs_.size())) {
    // See Note [User node does not uniquely identify use]
    auto use = findUseForInput(use_itr);
    use->offset += 1;
  }
  // Insert the actual input at the specified index
  inputs_.insert(inputs_.begin() + i, value);
  // Register the new use of the value we're inserted as an input.
  value->uses_.emplace_back(this, i);
  return value;
}

// Appends `value` as the last input of this node and records the use on the
// value. `value` must belong to the same graph. Invalidates the cached
// operator and returns `value`.
Value* Node::addInput(Value* value) {
  AT_ASSERT(graph_ == value->owningGraph());
  op_ = nullptr;
  const size_t slot = inputs_.size();
  value->uses_.emplace_back(this, slot);
  inputs_.push_back(value);
  return value;
}

// Swaps input `i` for `newValue` (which must belong to the same graph),
// updating the use lists of both values. Returns the value that previously
// occupied the slot. Invalidates the cached operator.
Value* Node::replaceInput(size_t i, Value* newValue) {
  AT_ASSERT(newValue->owningGraph() == graph_);
  op_ = nullptr;
  Value* previous = dropInput(i);
  inputs_[i] = newValue;
  newValue->uses_.emplace_back(this, i);
  return previous;
}

// Replaces every occurrence of `from` among this node's inputs with `to`.
// Both values must belong to this node's graph. Invalidates the cached
// operator even when `from` is not an input.
void Node::replaceInputWith(Value* from, Value* to) {
  AT_ASSERT(from->owningGraph() == graph_);
  AT_ASSERT(to->owningGraph() == graph_);
  op_ = nullptr;
  const size_t count = inputs().size();
  for (size_t idx = 0; idx < count; ++idx) {
    if (inputs_[idx] == from) {
      replaceInput(idx, to);
    }
  }
}

// Appends a new output value to this node and returns it. Invalidates the
// cached operator.
Value* Node::addOutput() {
  Value* fresh = new Value(this, outputs_.size());
  outputs_.push_back(fresh);
  op_ = nullptr;
  return fresh;
}

// Creates a new output at position `i`, shifting the outputs currently at
// [i, #outputs) one slot to the right and bumping their offsets. `i` must be
// in [0, #outputs]. Invalidates the cached operator and returns the new value.
Value* Node::insertOutput(size_t i) {
  // Inserting past the end would form an out-of-range iterator below; fail
  // loudly instead, matching the bounds assertion in eraseOutput().
  AT_ASSERT(i <= outputs_.size());
  op_ = nullptr;
  outputs_.insert(outputs_.begin() + i, new Value(this, i));
  for (size_t itr = i + 1; itr < outputs_.size(); ++itr) {
    outputs_[itr]->setOffset(outputs_[itr]->offset() + 1);
  }
  return outputs_.at(i);
}

bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const {
  if (this->owningBlock() == n->owningBlock()) {
    if (moveSide == MoveSide::BEFORE) {
      return this->topo_position_ < n->topo_position_;
    }

    if (moveSide == MoveSide::AFTER) {
      return this->topo_position_ > n->topo_position_;
    }

    AT_ASSERT(this == n);
    return false;
  }

  // These nodes don't share a common block. Traverse the blockchains upward
  // until we find the first common block.
  auto lhs = this;
  while (lhs) {
    AT_ASSERT(lhs->owningBlock());

    auto rhs = n;
    while (rhs) {
      if (!rhs->owningBlock()) {
        break;
      }

      if (lhs->owningBlock() == rhs->owningBlock()) {
        return lhs->isBeforeOrAfter(rhs, moveSide);
      }
      rhs = rhs->owningBlock()->owningNode();
    }

    lhs = lhs->owningBlock()->owningNode();
  }
  // should never reach here, since both nodes are ultimately in the same graph
  AT_ASSERT(false);
}

// True when this node is ordered before `n`; see isBeforeOrAfter().
bool Node::isBefore(const Node* n) const {
  const bool precedes = isBeforeOrAfter(n, MoveSide::BEFORE);
  return precedes;
}

// True when this node is ordered after `n`; see isBeforeOrAfter().
bool Node::isAfter(const Node* n) const {
  const bool follows = isBeforeOrAfter(n, MoveSide::AFTER);
  return follows;
}

// Links this node into the block list directly in front of `n`, which must
// already be in a block list. Returns this node.
Node* Node::insertBefore(Node* n) {
  AT_ASSERT(n->inBlockList());
  Node* predecessor = n->prev();
  insertAfter(predecessor);
  return this;
}

// Links this node into the block list directly behind `n` and assigns it a
// topological position. This node must not already be in a list, `n` must be,
// and `n` must not be the Return node. Returns this node.
Node* Node::insertAfter(Node* n) {
  AT_ASSERT(!inBlockList() && n->inBlockList());
  AT_ASSERT(n->owningBlock());
  AT_ASSERTM(
      n->kind() != prim::Return,
      "Attempting to insert a Node after the Return node or before the Param node. Tried to insert",
      *this,
      " after ",
      *n,
      ".");
  Node* successor = n->next();
  owning_block_ = n->owningBlock();
  // Splice `this` in between `n` and its former successor.
  prev() = n;
  next() = successor;
  n->next() = this;
  successor->prev() = this;
  assignTopoPosition();
  return this;
}

// Relocates this node so it directly follows `n`: unlink first, then re-insert.
void Node::moveAfter(Node* n) {
  this->removeFromList();
  this->insertAfter(n);
}

// Relocates this node so it directly precedes `n`: unlink first, then
// re-insert.
void Node::moveBefore(Node* n) {
  this->removeFromList();
  this->insertBefore(n);
}

// Removes input `i` and closes the gap. Invalidates the cached operator.
void Node::removeInput(size_t i) {
  op_ = nullptr;
  dropInput(i);
  // All later inputs move one slot to the left, so the offsets recorded in
  // their uses have to follow before the slot itself is erased.
  const size_t count = inputs_.size();
  for (size_t later = i + 1; later < count; ++later) {
    findUseForInput(later)->offset--;
  }
  inputs_.erase(inputs_.begin() + i);
}

// Drops the use registered for every input, then empties the input list.
// Invalidates the cached operator.
void Node::removeAllInputs() {
  op_ = nullptr;
  const size_t count = inputs().size();
  for (size_t idx = 0; idx < count; ++idx) {
    dropInput(idx);
  }
  inputs_.clear();
}

// Erases every output, last one first. Invalidates the cached operator.
void Node::removeAllOutputs() {
  op_ = nullptr;
  for (size_t remaining = outputs_.size(); remaining > 0; --remaining) {
    eraseOutput(remaining - 1);
  }
}

void Node::permuteInputs(const std::vector<size_t>& new_order) {
  op_ = nullptr;
  AT_ASSERT(new_order.size() == inputs_.size());
  std::vector<Value*> new_inputs;
  new_inputs.reserve(new_order.size());
  for (const auto i : c10::irange(new_order.size())) {
    AT_ASSERTM(inputs_.at(new_order[i]) != nullptr, "Repeated index");
    new_inputs.push_back(inputs_.at(new_order[i]));
    auto it = findUseForInput(new_order[i]);
    it->offset = i;
    inputs_.at(new_order[i]) = nullptr;
  }
  inputs_ = std::move(new_inputs);
}

void Node::permuteOutputs(const std::vector<size_t>& new_order) {
  op_ = nullptr;
  AT_ASSERT(new_order.size() == outputs_.size());
  std::vector<Value*> new_outputs;
  new_outputs.reserve(new_order.size());
  for (const auto i : c10::irange(new_order.size())) {
    AT_ASSERTM(outputs_.at(new_order[i]) != nullptr, "Repeated index");
    new_outputs.push_back(outputs_.at(new_order[i]));
    outputs_.at(new_order[i])->setOffset(i);
    outputs_.at(new_order[i]) = nullptr;
  }
  outputs_ = std::move(new_outputs);
}

// Locates, in the use list of input `i`, the entry that records this node
// using it at offset `i`. Asserts that the entry exists.
use_list::iterator Node::findUseForInput(size_t i) {
  auto& uses_of_input = inputs_[i]->uses_;
  // Linear in the number of uses; unless nodes end up with 100+ uses, scanning
  // a vector is still probably faster than a linked list would be.
  const auto first = uses_of_input.begin();
  const auto last = uses_of_input.end();
  auto found = std::find(first, last, Use(this, i));
  AT_ASSERT(found != uses_of_input.end());
  return found;
}

// Unregisters this node's use of input `i` and nulls the slot, leaving the
// input list at the same length. Returns the value that was in the slot.
Value* Node::dropInput(size_t i) {
  AT_ASSERT(i < inputs_.size());
  Value* dropped = inputs_[i];
  dropped->uses_.erase(findUseForInput(i));
  inputs_[i] = nullptr;
  return dropped;
}

void Node::removeFromList() {
  AT_ASSERT(inBlockList());
  this->owning_block_ = nullptr;
  Node* next = this->next();
  Node* prev = this->prev();
  prev->next() = next;
  next->prev() = prev;
  this->next() = nullptr;
  this->prev() = nullptr;
}

// Returns the innermost block that contains both this node and `n` (directly
// or through nested blocks).
Block* Node::findCommonAncestorBlockWith(Node* n) {
  if (n->owningBlock() == owningBlock()) {
    return owningBlock();
  }

  Node* mine = this;
  Node* theirs = n;

  size_t my_depth = mine->blocksFromGraphBlock();
  size_t their_depth = theirs->blocksFromGraphBlock();

  // Lift whichever node is nested deeper until both sit at the same depth.
  while (my_depth > their_depth) {
    mine = mine->owningBlock()->owningNode();
    --my_depth;
  }
  while (their_depth > my_depth) {
    theirs = theirs->owningBlock()->owningNode();
    --their_depth;
  }

  // Now they are the same number of blocks from the graph block; climb in
  // lockstep until they land in the same block.
  while (mine->owningBlock() != theirs->owningBlock()) {
    mine = mine->owningBlock()->owningNode();
    theirs = theirs->owningBlock()->owningNode();

    AT_ASSERT(mine != nullptr);
    AT_ASSERT(theirs != nullptr);
  }
  return mine->owningBlock();
}

size_t Node::blocksFromGraphBlock() {
  Node* n = this;
  size_t dist = 0;
  while (n->owningBlock()->owningNode()) {
    n = n->owningBlock()->owningNode();
    ++dist;
  }
  return dist;
}

// Placeholder source range, built once on first use over an empty source
// string, for callers that have no real location to report.
inline const SourceRange& fakeRange() {
  static const SourceRange placeholder(
      std::make_shared<Source>(std::string("")), 0, 1);
  return placeholder;
}

// Emits a call to the builtin `opname` with the given positional and keyword
// arguments into this graph and returns the resulting value. When no source
// range is supplied, the placeholder from fakeRange() is used.
Value* Graph::insert(
    Symbol opname,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    const std::optional<SourceRange>& range) {
  const SourceRange& loc = range.has_value() ? *range : fakeRange();
  return emitBuiltinCall(loc, *this, opname, args, kwargs);
}

// Allocates a node of the given kind with `num_outputs` outputs. The node is
// owned by this graph but not inserted into any block.
Node* Graph::create(NodeKind kind, size_t num_outputs) {
  // NB: Node constructor adds node to all_nodes
  Node* created = new Node(this, kind);
  for (size_t remaining = num_outputs; remaining > 0; --remaining) {
    created->addOutput();
  }
  return created;
}

// Allocates a node of the given kind with `num_outputs` outputs and the given
// inputs, added in order. The node is not inserted into any block.
Node* Graph::create(
    NodeKind kind,
    ArrayRef<Value*> inputs,
    size_t num_outputs) {
  Node* created = create(kind, num_outputs);
  for (Value* in : inputs) {
    created->addInput(in);
  }
  return created;
}

// Creates (without inserting) a prim::AutogradZero node.
Node* Graph::createAutogradZero() {
  Node* zero = create(prim::AutogradZero);
  return zero;
}

// Creates (without inserting) a prim::Constant node whose output has NoneType.
Node* Graph::createNone() {
  Node* none_node = create(prim::Constant);
  Value* out = none_node->output();
  out->setType(NoneType::get());
  return none_node;
}

// Creates (without inserting) a prim::Uninitialized node whose output has the
// type `typ`.
Node* Graph::createUninitialized(TypePtr typ) {
  Node* uninit = create(prim::Uninitialized);
  Value* out = uninit->output();
  out->setType(std::move(typ));
  return uninit;
}

// Creates (without inserting) an output-less node of the given kind that
// carries a fresh, empty Graph in its Subgraph attribute. The new graph is
// constructed with this graph's current scope.
Node* Graph::createWithSubgraph(Symbol kind) {
  Node* holder = create(kind, 0);
  auto subgraph = std::make_shared<Graph>(current_scope());
  holder->g_(attr::Subgraph, std::move(subgraph));
  return holder;
}

// Creates (without inserting) a prim::TupleConstruct node over `values`.
// `tuple_type` may only be passed for named tuples (it must carry a schema);
// when it is null, the tuple type is derived from the types of `values`.
Node* Graph::createTuple(at::ArrayRef<Value*> values, TupleTypePtr tuple_type) {
  TORCH_INTERNAL_ASSERT(
      !tuple_type || tuple_type->schema(),
      "only pass tuple_type when creating a named tuple");
  if (!tuple_type) {
    std::vector<TypePtr> element_types;
    element_types.reserve(values.size());
    for (Value* v : values) {
      element_types.push_back(v->type());
    }
    tuple_type = TupleType::create(std::move(element_types));
  }

  Node* tuple_node = create(prim::TupleConstruct, values);
  tuple_node->output()->setType(tuple_type);
  return tuple_node;
}

// Creates (without inserting) a prim::TupleUnpack node over `v`, which must be
// tuple-typed. One output is added per tuple element, typed accordingly.
Node* Graph::createTupleUnpack(Value* v) {
  TupleTypePtr tuple_type = v->type()->expect<TupleType>();
  Node* unpack = create(prim::TupleUnpack, {v}, 0);
  for (const auto& element_type : tuple_type->elements()) {
    Value* out = unpack->addOutput();
    out->setType(element_type);
  }
  return unpack;
}

// Creates (without inserting) a prim::TupleIndex node reading element `idx` of
// `tup`; the output is given the caller-supplied `output_type`.
Node* Graph::createTupleIndex(
    Value* tup,
    Value* idx,
    const TypePtr& output_type) {
  Node* index_node = create(prim::TupleIndex, {tup, idx});
  Value* out = index_node->output();
  out->setType(output_type);
  return index_node;
}

// Builds a tuple out of `num_values` elements of `tup`, starting at index
// `beg` and advancing by `step_size`. For each selected element an index
// constant and a prim::TupleIndex node are inserted into the graph; the
// returned prim::TupleConstruct node itself is created but not inserted.
Node* Graph::createTupleSlice(
    Value* tup,
    int64_t beg,
    int64_t step_size,
    int64_t num_values) {
  std::vector<Value*> picked;
  TupleTypePtr tuple_type = tup->type()->expect<TupleType>();
  picked.reserve(num_values);

  int64_t pos = beg;
  for (int64_t taken = 0; taken < num_values; ++taken) {
    auto idx = insertConstant(IValue(static_cast<int64_t>(pos)));
    Node* index_node =
        insertNode(createTupleIndex(tup, idx, tuple_type->elements()[pos]));

    picked.push_back(index_node->output());
    pos += step_size;
  }

  return createTuple(picked);
}

// Creates (without inserting) a prim::EnumName node over `e`, whose output is
// a string. `e` must be enum-typed.
Node* Graph::createEnumName(Value* e) {
  // expect<EnumType>() is the type check here, the same one createEnumValue()
  // relies on; no additional assert on the cast is needed.
  e->type()->expect<EnumType>();
  auto n = create(prim::EnumName, {e});
  n->output()->setType(StringType::get());
  return n;
}

// Creates (without inserting) a prim::EnumValue node over `e`, which must be
// enum-typed. The output takes the enum's value type.
Node* Graph::createEnumValue(Value* e) {
  const auto enum_type = e->type()->expect<EnumType>();
  Node* value_node = create(prim::EnumValue, {e});
  Value* out = value_node->output();
  out->setType(enum_type->getValueType());
  return value_node;
}

// Creates (without inserting) a prim::ListConstruct node over `values` whose
// output is List[contained_type]. Every value must subtype `contained_type`;
// otherwise a TORCH_CHECK error is raised.
Node* Graph::createList(
    const TypePtr& contained_type,
    at::ArrayRef<Value*> values) {
  // Validate before creating the node: create() registers a use on every
  // input, so failing afterwards would leave an orphan node that still shows
  // up in the use lists of the caller's values.
  for (const auto& v : values) {
    TORCH_CHECK(
        v->type()->isSubtypeOf(*contained_type),
        "Expected a list element that subtypes '",
        contained_type->repr_str(),
        "' but got an element of type '",
        v->type()->repr_str(),
        "'");
  }
  auto n = create(prim::ListConstruct, values);
  n->output()->setType(ListType::create(contained_type));
  return n;
}

// Creates (without inserting) a prim::ListUnpack node over the list-typed
// value `v` with `size` outputs, each typed as the list's element type.
Node* Graph::createListUnpack(Value* v, size_t size) {
  ListTypePtr list_type = v->type()->expect<ListType>();
  TypePtr element_type = list_type->getElementType();
  Node* unpack = create(prim::ListUnpack, {v}, 0);
  for (size_t made = 0; made < size; ++made) {
    Value* out = unpack->addOutput();
    out->setType(element_type);
  }
  return unpack;
}

// Creates (without inserting) a prim::DictConstruct node whose inputs are the
// interleaved pairs key0, value0, key1, value1, ... and whose output is
// Dict[key_type, value_type]. `keys` and `values` must have the same length,
// and every key/value must subtype `key_type`/`value_type` respectively.
Node* Graph::createDict(
    const TypePtr& key_type,
    const TypePtr& value_type,
    at::ArrayRef<Value*> keys,
    at::ArrayRef<Value*> values) {
  AT_ASSERT(keys.size() == values.size());
  // Validate everything before creating the node: addInput() registers a use
  // on each value, so failing midway would leave an orphan node that still
  // shows up in the use lists of the caller's values.
  for (const auto i : c10::irange(keys.size())) {
    AT_ASSERT(keys[i]->type()->isSubtypeOf(*key_type));
    AT_ASSERT(values[i]->type()->isSubtypeOf(*value_type));
  }
  auto n = create(prim::DictConstruct, 1);
  for (const auto i : c10::irange(keys.size())) {
    n->addInput(keys[i]);
    n->addInput(values[i]);
  }
  n->output()->setType(DictType::create(key_type, value_type));
  return n;
}

// Creates (without inserting) a prim::NumToTensor node over `value`; the
// output tensor type is derived from the number type of `value`.
Node* Graph::createNumToTensor(Value* value) {
  Node* converted = create(prim::NumToTensor, {value});
  Value* out = converted->output();
  out->setType(TensorType::fromNumberType(*value->type()));
  return converted;
}

// Creates (without inserting) a prim::CreateObject node whose output has the
// class type `type`.
Node* Graph::createObject(const ClassTypePtr& type) {
  Node* object_node = create(prim::CreateObject);
  Value* out = object_node->output();
  out->setType(type);
  return object_node;
}

// Creates (without inserting) an output-less prim::SetAttr node with inputs
// (obj, newValue) and the attribute name stored as `field`.
Node* Graph::createSetAttr(
    Value* obj,
    const std::string& field,
    Value* newValue) {
  Node* setter = create(prim::SetAttr, {obj, newValue}, /*num_outputs=*/0);
  setter->s_(attr::name, field);
  return setter;
}

// Creates (without inserting) a prim::GetAttr node reading attribute `field`
// from the class-typed value `obj`. The output takes the attribute's type as
// declared on the class, and a debug name derived from the field name.
Node* Graph::createGetAttr(Value* obj, const std::string& field) {
  const auto class_type = obj->type()->expect<ClassType>();

  Node* getter = create(prim::GetAttr, {obj}, /*num_outputs=*/1);
  getter->s_(attr::name, field);

  Value* out = getter->output();
  out->setType(class_type->getAttribute(field));
  out->setDebugName(normalizeAttrName(field));
  return getter;
}

// Creates (without inserting) an output-less prim::Store node that takes `v`
// as input and records `name` in its name attribute.
Node* Graph::createStore(const std::string& name, Value* v) {
  Node* store = create(prim::Store, {v}, /*num_outputs*/ 0);
  store->s_(attr::name, name);
  return store;
}

// Creates (without inserting) an input-less prim::Load node that records
// `name` in its name attribute and has a single output of type `type`.
Node* Graph::createLoad(const std::string& name, const TypePtr& type) {
  Node* load = create(prim::Load, {}, /*num_outputs*/ 1);
  load->s_(attr::name, name);
  Value* out = load->output();
  out->setType(type);
  return load;
}

// Creates (without inserting) a prim::isinstance node over `v`. The candidate
// types are stored in the `types` attribute and the output is a bool.
Node* Graph::createIsInstance(Value* v, at::ArrayRef<TypePtr> types) {
  Node* check = create(prim::isinstance, {v}, /*num_outputs*/ 1);
  check->tys_(attr::types, types.vec());
  Value* out = check->output();
  out->setType(BoolType::get());
  return check;
}
// Inserts a prim::unchecked_cast node over `v` and returns its output, retyped
// to `type`.
Value* Graph::insertUncheckedCast(Value* v, TypePtr type) {
  Node* cast_node = insertNode(create(prim::unchecked_cast, {v}));
  Value* out = cast_node->output();
  out->setType(std::move(type));
  return cast_node->output();
}

// Inserts a prim::tolist node converting `v` to the (possibly nested) list
// type `type` and returns its output.
//
// The list nesting depth and a small integer code for the innermost element
// type (int=0, float=1, bool=2, complex=3) are inserted as constants and
// passed to the op. Any other element type is rejected with a TORCH_CHECK
// error.
Value* Graph::insertToList(Value* v, TypePtr type) {
  // Peel off List[...] layers to count the dimensions.
  int dim = 0;
  TypePtr ptr = type;
  for (auto list_type = ptr->cast<ListType>(); list_type;
       list_type = ptr->cast<ListType>()) {
    ptr = list_type->getElementType();
    ++dim;
  }

  // Map the innermost element type to its integer code; -1 marks "no match".
  int elem_ty = -1;
  if (ptr == IntType::get()) {
    elem_ty = 0;
  } else if (ptr == FloatType::get()) {
    elem_ty = 1;
  } else if (ptr == BoolType::get()) {
    elem_ty = 2;
  } else if (ptr == ComplexType::get()) {
    elem_ty = 3;
  }
  TORCH_CHECK(
      elem_ty >= 0,
      ptr->repr_str(),
      " is not one of the supported element types for tolist: int, float, complex, bool");

  // The op receives the dimension count and element-type code as arguments.
  Value* dim_val = insertConstant(IValue(dim));
  Value* elem_ty_val = insertConstant(IValue(elem_ty));
  Node* to_list = insertNode(create(prim::tolist, {v, dim_val, elem_ty_val}));
  Value* out = to_list->output();
  out->setType(std::move(type));
  return to_list->output();
}

// Inserts a call to `callee`: first a prim::Constant node standing for the
// function (named after it, typed as a FunctionType), then a
// prim::CallFunction node taking that constant followed by the matched inputs.
// Returns the call's output, typed with the first matched return type.
Value* Graph::insertFunctionCall(
    Function* callee,
    const MatchedSchema& matched) {
  std::string func_name = callee->name();
  Node* fn_node = insertNode(create(prim::Constant));
  fn_node->s_(attr::name, func_name);
  Value* fn_constant =
      fn_node->output()->setType(FunctionType::create(callee));

  std::vector<Value*> call_inputs;
  call_inputs.reserve(matched.inputs.size() + 1);
  call_inputs.push_back(fn_constant);
  call_inputs.insert(
      call_inputs.end(), matched.inputs.begin(), matched.inputs.end());

  Node* call_node = insertNode(create(prim::CallFunction, call_inputs));
  return call_node->output()->setType(matched.return_types.at(0));
}

// Inserts a prim::CallMethod node over the matched inputs, records the method
// name in its name attribute, and returns its output typed with the first
// matched return type.
Value* Graph::insertMethodCall(
    std::string method_name,
    const MatchedSchema& matched) {
  Node* call_node = insertNode(create(prim::CallMethod, matched.inputs));
  call_node->s_(attr::name, std::move(method_name));
  return call_node->output()->setType(matched.return_types.at(0));
}

// Creates (without inserting) a copy of `n` in this graph. Outputs are
// recreated with copied metadata, node metadata is cloned, and each input is
// translated through `value_map`. Nested blocks are cloned as well when
// `copy_blocks` is set.
Node* Graph::createClone(
    Node* n,
    const std::function<Value*(Value*)>& value_map,
    bool copy_blocks) {
  // n can be from a different graph
  Node* copy = n->allocNewInstance(this);
  for (Value* src_out : n->outputs()) {
    copy->addOutput()->copyMetadata(src_out);
  }
  copy->cloneFrom(n);
  for (Value* src_in : n->inputs()) {
    copy->addInput(value_map(src_in));
  }
  if (!copy_blocks) {
    return copy;
  }
  for (Block* src_block : n->blocks()) {
    copy->addBlock()->cloneFrom(src_block, value_map);
  }
  return copy;
}

// Member convenience wrapper around the free function jit::insertConstant,
// applied to this graph.
Value* Graph::insertConstant(
    const IValue& val,
    std::optional<SourceRange> loc,
    std::optional<ScopePtr> scope) {
  Value* constant =
      jit::insertConstant(*this, val, std::move(loc), std::move(scope));
  return constant;
}

// Renders the graph via print() into a string; `print_source_locations` is
// forwarded unchanged.
std::string Graph::toString(bool print_source_locations) const {
  std::ostringstream buffer;
  print(buffer, print_source_locations);
  std::string text = buffer.str();
  return text;
}

// Deletes everything still registered with the graph: nodes first, then
// values, then blocks.
Graph::~Graph() {
  for (const Node* node : all_nodes) {
    delete node;
  }
  for (const Value* value : all_values) {
    delete value;
  }
  for (const Block* block : all_blocks) {
    delete block;
  }
}

// Deletes `n` and removes it from the graph's node registry. `n` must be
// registered.
void Graph::freeNode(Node* n) {
  const auto entry = all_nodes.find(n);
  AT_ASSERT(entry != all_nodes.end());
  delete *entry;
  all_nodes.erase(entry);
}
// Clears the debug name of `v`, then deletes it and removes it from the
// graph's value registry. `v` must be registered.
void Graph::freeValue(Value* v) {
  v->setDebugName("");
  const auto entry = all_values.find(v);
  AT_ASSERT(entry != all_values.end());
  delete *entry;
  all_values.erase(entry);
}
// Deletes `b` and removes it from the graph's block registry. `b` must be
// registered.
void Graph::freeBlock(Block* b) {
  const auto entry = all_blocks.find(b);
  AT_ASSERT(entry != all_blocks.end());
  delete *entry;
  all_blocks.erase(entry);
}

// Returns the elements of the tuple value `v`. When `v` comes straight out of
// a prim::TupleConstruct, that node's inputs are returned directly; otherwise
// a prim::TupleUnpack node is inserted into the owning graph and its outputs
// are returned.
at::ArrayRef<Value*> createTupleUnpack(Value* v) {
  Node* producer = v->node();
  // small peephole optimization to ensure IntArrayRef attributes can still turn
  // into constants e.g. in x.expand([3, 4])
  if (producer->kind() == prim::TupleConstruct) {
    return producer->inputs();
  }
  Graph& graph = *v->owningGraph();
  Node* unpack = graph.insertNode(graph.createTupleUnpack(v));
  return unpack->outputs();
}

// Forward declaration: inlineCallStackOfBlock and inlineCallStackOfNode call
// each other.
void inlineCallStackOfNode(
    Node* n,
    std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
    Function* callee,
    Node* to_replace,
    const std::optional<ModuleInstanceInfo>& m_info);

// Applies inlineCallStackOfNode to every node of block `b`, passing the
// remaining arguments through unchanged.
static void inlineCallStackOfBlock(
    Block* b,
    std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
    Function* callee,
    Node* to_replace,
    const std::optional<ModuleInstanceInfo>& m_info) {
  for (Node* inner : b->nodes()) {
    inlineCallStackOfNode(inner, new_cs_entries, callee, to_replace, m_info);
  }
}

// Gives `new_node` a callstack that has the call being inlined (callee, the
// source range of `to_replace`, and `m_info`) added to whatever callstack the
// node already had. `new_cs_entries` caches the extended callstack per
// original callstack pointer (nullptr for "none") so nodes that shared a
// callstack keep sharing one. Nodes in nested blocks are processed too.
void inlineCallStackOfNode(
    Node* new_node,
    std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
    Function* callee,
    Node* to_replace,
    const std::optional<ModuleInstanceInfo>& m_info) {
  auto new_node_cs = new_node->callstack();

  InlinedCallStack* cache_key = nullptr;
  if (new_node_cs) {
    cache_key = new_node_cs->get();
  }

  auto cached = new_cs_entries.find(cache_key);
  if (cached == new_cs_entries.end()) {
    InlinedCallStackPtr extended = new_node_cs
        ? c10::make_intrusive<InlinedCallStack>(
              *new_node_cs, callee, to_replace->sourceRange(), m_info)
        : c10::make_intrusive<InlinedCallStack>(
              callee, to_replace->sourceRange(), m_info);
    cached = new_cs_entries.emplace(cache_key, std::move(extended)).first;
  }
  new_node->setCallStack(cached->second);
  // The callstack of new_node was just updated; the nodes inside its blocks
  // need the same treatment, otherwise e.g. the body of an If node would not
  // be annotated appropriately.
  for (Block* nested : new_node->blocks()) {
    inlineCallStackOfBlock(nested, new_cs_entries, callee, to_replace, m_info);
  }
}

// Replaces the call node `to_replace` with a copy of `callee_graph` inserted
// at the call site, and returns the values that stand in for the call's
// outputs.
//
// Every copied node gets the call recorded in its inlined callstack. For
// prim::CallMethod calls, the callstack entry also carries module instance
// info derived from the call's first input. Debug names of the old outputs
// are carried over, all uses are redirected, and `to_replace` is destroyed.
std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    Graph* callee_graph) {
  WithInsertPoint guard(to_replace);
  Graph* caller_graph = to_replace->owningGraph();
  std::unordered_map<Value*, Value*> value_map;
  std::vector<torch::jit::Value*> new_outputs = insertGraph(
      *caller_graph, *callee_graph, to_replace->inputs(), value_map);

  std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>
      new_callstack_entries;

  std::optional<ModuleInstanceInfo> module_instance_info = std::nullopt;
  if (to_replace->kind() == prim::CallMethod) {
    Value* self_value = to_replace->input(0);
    auto class_type_ptr = self_value->type()->cast<c10::ClassType>();
    if (self_value->node()->kind() == prim::GetAttr) {
      module_instance_info = std::make_optional(ModuleInstanceInfo(
          class_type_ptr, self_value->node()->s(attr::name)));
    } else if (
        !caller_graph->inputs().empty() &&
        self_value == caller_graph->inputs()[0]) {
      // The method is being called on the very object the enclosing graph
      // belongs to.
      module_instance_info =
          std::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF"));
    } else {
      // Not sure if it is possible to come here ever.
      // TODO: Remove this else. Or add assert
      module_instance_info = std::make_optional(
          ModuleInstanceInfo(class_type_ptr, "INSTANCE_NAME_UNKNOWN"));
    }
  }

  // TODO: We might need to use nodes_map instead of value_map. Otherwise, we
  // are missing nodes without outputs (e.g. prim::Print).
  std::unordered_set<Node*> updated_nodes;
  const auto callee_inputs = callee_graph->inputs();
  for (const auto& mapping : value_map) {
    /* Skip entries whose old value is a graph input of the callee.
     * value_map holds the callee's primary inputs in addition to the values
     * produced by its nodes, and those inputs map to values that already exist
     * in the caller. Processing them would create duplicates when the output
     * of one inlined graph feeds the next one.
     */
    const bool old_is_graph_input =
        std::find(callee_inputs.begin(), callee_inputs.end(), mapping.first) !=
        callee_inputs.end();
    if (old_is_graph_input) {
      continue;
    }

    Node* new_node = mapping.second->node();
    const bool first_visit = updated_nodes.insert(new_node).second;
    if (!first_visit) {
      continue;
    }

    inlineCallStackOfNode(
        new_node,
        new_callstack_entries,
        callee,
        to_replace,
        module_instance_info);
  }

  const auto& old_outputs = to_replace->outputs();
  AT_ASSERT(new_outputs.size() == old_outputs.size());
  for (size_t idx = 0; idx < old_outputs.size(); ++idx) {
    Value* old_out = old_outputs[idx];
    Value* new_out = new_outputs[idx];
    if (old_out->hasDebugName()) {
      new_out->setDebugName(old_out->debugName());
    }
    old_out->replaceAllUsesWith(new_out);
  }
  to_replace->destroy();

  return new_outputs;
}

// Overload that selects which graph of `callee` to inline: its optimized
// graph when `inline_optimized_graph` is true, its plain graph otherwise.
// The inline_optimized_graph argument is used in substitute function call for
// ONNX conversion.
std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    bool inline_optimized_graph /*=true*/) {
  auto chosen_graph =
      inline_optimized_graph ? callee->optimized_graph() : callee->graph();
  Graph* raw_graph = chosen_graph.get();
  return inlineCallTo(to_replace, callee, raw_graph);
}

// If `outputs` consists of exactly one tuple-typed value, returns that tuple's
// elements instead; any other input is returned unchanged.
std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs) {
  const bool is_single_tuple = outputs.size() == 1 &&
      outputs.at(0)->type()->kind() == TupleType::Kind;
  if (!is_single_tuple) {
    return outputs;
  }

  Value* tup = outputs[0];
  std::vector<Value*> elements;
  for (Value* element : createTupleUnpack(tup)) {
    elements.emplace_back(element);
  }
  // When the elements came straight from a TupleConstruct that nobody else
  // uses, remove that node right away so no DCE pass is needed for it.
  Node* producer = tup->node();
  if (producer->kind() == prim::TupleConstruct && !producer->hasUses()) {
    producer->destroy();
  }
  return elements;
}

// Collects the nodes of the given kind from every block in `array`, in block
// order; `recurse` is forwarded to the per-block search.
std::vector<Node*> findAllNodes(
    at::ArrayRef<Block*> array,
    Symbol kind,
    bool recurse) {
  std::vector<Node*> matches;
  for (Block* current : array) {
    findAllNodes(*current, kind, recurse, matches);
  }
  return matches;
}

// Single-block convenience overload of findAllNodes.
std::vector<Node*> findAllNodes(Block& block, Symbol kind, bool recurse) {
  std::vector<Node*> matches = findAllNodes({&block}, kind, recurse);
  return matches;
}

// Whole-graph convenience overload of findAllNodes: searches the graph's
// top-level block.
std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse) {
  Block& top_block = *g.block();
  return findAllNodes(top_block, kind, recurse);
}

// Copies every node of `callee` into `g` at the current insert point, with the
// callee's graph inputs bound to `inputs`. `value_map` is filled with the
// mapping from each callee value (inputs and node outputs) to its counterpart
// in `g`. Returns the counterparts of the callee's outputs.
std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs,
    std::unordered_map<Value*, Value*>& value_map) {
  auto translate = [&](Value* v) { return value_map.at(v); };
  AT_ASSERT(callee.inputs().size() == inputs.size());
  for (size_t idx = 0; idx < inputs.size(); ++idx) {
    value_map[callee.inputs()[idx]] = inputs[idx];
  }
  for (auto* src_node : callee.nodes()) {
    auto* copy = g.insertNode(g.createClone(src_node, translate));
    const size_t out_count = src_node->outputs().size();
    for (size_t idx = 0; idx < out_count; ++idx) {
      value_map[src_node->outputs()[idx]] = copy->outputs()[idx];
    }
  }

  std::vector<Value*> mapped_outputs;
  for (auto* callee_out : callee.outputs()) {
    mapped_outputs.push_back(translate(callee_out));
  }

  return mapped_outputs;
}

// Overload of insertGraph for callers that do not need the value mapping: a
// throwaway map is used internally.
std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs) {
  std::unordered_map<Value*, Value*> scratch_map;
  return insertGraph(g, callee, inputs, scratch_map);
}

// Clones the generic node state from `other_` and additionally copies its
// profiling callback. `other_` must be a ProfileOp.
void ProfileOp::cloneFrom(Node* other_) {
  // NOTE(review): cast<> presumably yields nullptr when other_ is not a
  // ProfileOp; check it before anything is mutated rather than dereferencing
  // a null pointer below.
  auto other = other_->cast<ProfileOp>();
  TORCH_INTERNAL_ASSERT(
      other != nullptr, "ProfileOp::cloneFrom expects a ProfileOp source node");
  Node::cloneFrom(other_);
  this->callback_ = other->getCallback();
}

// Allocates a fresh ProfileOp in graph `g` with an empty (null) callback.
Node* ProfileOp::allocNewInstance(Graph* g) {
  Node* fresh = new ProfileOp(g, {nullptr});
  return fresh;
}

// Clones the generic node state from `other_` and additionally copies its
// profiling callback. `other_` must be a ProfileIValueOp.
void ProfileIValueOp::cloneFrom(Node* other_) {
  // NOTE(review): cast<> presumably yields nullptr when other_ is not a
  // ProfileIValueOp; check it before anything is mutated rather than
  // dereferencing a null pointer below.
  auto other = other_->cast<ProfileIValueOp>();
  TORCH_INTERNAL_ASSERT(
      other != nullptr,
      "ProfileIValueOp::cloneFrom expects a ProfileIValueOp source node");
  Node::cloneFrom(other_);
  this->callback_ = other->getCallback();
}

// Allocates a fresh ProfileIValueOp in graph `g` with an empty (null)
// callback.
Node* ProfileIValueOp::allocNewInstance(Graph* g) {
  Node* fresh = new ProfileIValueOp(g, {nullptr});
  return fresh;
}

// Type of the wrapped item: taken from the graph Value when one is held,
// otherwise from the stored IValue.
TypePtr NamedValue::type() const {
  if (!value_) {
    return ivalue_.type();
  }
  return value_->type();
}

// Out-of-class definitions of the node-kind symbols that identify the two
// profiling node classes.
const Symbol ProfileOp::Kind = ::c10::prim::profile;
const Symbol ProfileIValueOp::Kind = ::c10::prim::profile_ivalue;

// Builds an operator set from schema signature literals; see insert().
OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
  this->insert(sig_literals);
}

// Returns all operators in the set as one flat list, concatenating the
// per-symbol lists in the iteration order of the underlying map.
std::vector<std::shared_ptr<Operator>> OperatorSet::getOps() const {
  std::vector<std::shared_ptr<Operator>> result;
  for (const auto& kv : ops) {
    // Bind by reference: copying the per-symbol vector would copy every
    // shared_ptr in it just to copy them again into `result`.
    const auto& ops_for_symbol = kv.second;
    result.insert(result.end(), ops_for_symbol.begin(), ops_for_symbol.end());
  }
  return result;
}

void OperatorSet::insert(std::initializer_list<const char*> sig_literals) {
  for (const char* sig : sig_literals) {
    auto op = getOperatorForLiteral(sig);
    ops[Symbol::fromQualString(op->schema().name())].push_back(op);
  }
}

// True when `os` holds, under this node's kind, at least one operator whose
// schema matches this node.
bool Node::isMemberOf(const OperatorSet& os) const {
  const auto bucket = os.ops.find(kind());
  if (bucket != os.ops.end()) {
    for (const auto& candidate : bucket->second) {
      if (matches(candidate->schema())) {
        return true;
      }
    }
  }
  return false;
}

} // namespace torch::jit
