#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type_base.h>
#include <c10/macros/Macros.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/ir/scope.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/guard_elimination.h>
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/insert_guards.h>
#include <torch/csrc/jit/passes/liveness.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/symbolic_script.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/jit.h>
#include <torch/script.h>

#include <onnx/onnx_pb.h>

#include <c10/util/Exception.h>
#include <c10/util/ThreadLocalDebugInfo.h>

#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}

// Streams a vector as "{e0, e1, ...}" (just "{}" when empty), using each
// element's own operator<<.
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& list) {
  out << "{";
  const char* separator = "";
  for (const auto& element : list) {
    out << separator << element;
    separator = ", ";
  }
  return out << "}";
}

TEST(InternedStringsTest, Basic) {
  // Built-in symbols map to their predefined constants and print correctly.
  ASSERT_EQ(prim::Param, Symbol::prim("Param"));
  ASSERT_EQ(prim::Return, Symbol::prim("Return"));
  ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return"));
  ASSERT_EQ(prim::Return.toQualString(), std::string("prim::Return"));

  // A freshly interned symbol, used as the base id for the following checks.
  Symbol fresh = Symbol::aten("__NEW_SYMBOL");
  size_t base = fresh;
  ASSERT_EQ(fresh.toQualString(), std::string("aten::__NEW_SYMBOL"));
  // TODO: This test is a bit too close to the implementation details.
  // New strings get consecutive ids, and re-interning returns the same id.
  for (int round = 0; round < 2; ++round) {
    ASSERT_EQ(Symbol::aten("What"), base + 1);
    ASSERT_EQ(Symbol::aten("What2"), base + 2);
  }
  ASSERT_EQ(Symbol(base + 2).toUnqualString(), std::string("What2"));
}

// Checks Symbol::fromQualString parsing for known and new namespaces, and
// that malformed qualified strings are rejected with an exception.
TEST(FromQualStringTest, Basic) {
  ASSERT_EQ(Symbol::fromQualString("prim::Param"), Symbol::prim("Param"));
  ASSERT_EQ(Symbol::fromQualString("aten::mm"), Symbol::aten("mm"));
  ASSERT_EQ(Symbol::fromQualString("onnx::LSTM"), Symbol::onnx("LSTM"));
  ASSERT_EQ(Symbol::fromQualString("attr::value"), Symbol::attr("value"));
  ASSERT_EQ(Symbol::fromQualString("scope::"), Symbol::scope(""));
  ASSERT_EQ(Symbol::fromQualString("::").toUnqualString(), std::string(""));
  ASSERT_EQ(
      Symbol::fromQualString("::").ns().toQualString(),
      std::string("namespaces::"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").toUnqualString(),
      std::string("param"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").ns().toUnqualString(),
      std::string("new_ns"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").ns(),
      Symbol::fromQualString("namespaces::new_ns"));

  // Strings without a "ns::" separator must throw. EXPECT_THROW replaces the
  // old `try { ...; ASSERT_TRUE(0); } catch (const std::exception&) {}`
  // pattern: under --gtest_throw_on_failure, gtest reports the ASSERT failure
  // by throwing an exception derived from std::runtime_error, which that
  // catch block swallowed, so the test could never fail.
  for (const char* input : {"scope", ":", ""}) {
    EXPECT_THROW((void)Symbol::fromQualString(input), std::exception)
        << "fromQualString accepted malformed input '" << input << "'";
  }
}

TEST(THNNConvTest, Basic) {
  // Compares the JIT-differentiated aten::_slow_conv2d_forward against the
  // eager forward/backward kernels on the same random data.
  std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
  std::vector<int64_t> kernel_size = {3, 5};
  std::vector<int64_t> stride = {1, 2};
  std::vector<int64_t> padding = {2, 1};
  constexpr int out_channels = 5;

  // Random inputs shared by the eager and JIT paths.
  at::Tensor input = torch::randn(input_size);
  at::Tensor weight = torch::randn(
      {out_channels, input_size[1], kernel_size[0], kernel_size[1]});
  at::Tensor bias = torch::randn({out_channels});

  // Eager reference: forward pass, then backward for all three inputs.
  at::Tensor output = at::_slow_conv2d_forward(
      input, weight, kernel_size, bias, stride, padding);
  at::Tensor grad_output =
      torch::randn_like(output, at::MemoryFormat::Preserve);
  auto [grad_input, grad_weight, grad_bias] = at::_slow_conv2d_backward(
      grad_output,
      input,
      weight,
      kernel_size,
      stride,
      padding,
      {true, true, true});

  // JIT graph: constants first, then the tensor inputs, then the op itself.
  auto graph = std::make_shared<Graph>();
  Value* kernel_size_const = graph->insertConstant(kernel_size);
  Value* stride_const = graph->insertConstant(stride);
  Value* padding_const = graph->insertConstant(padding);

  Value* input_value = graph->addInput("self");
  Value* weight_value = graph->addInput("weight");
  Value* bias_value = graph->addInput("bias");

  Value* conv = graph->insert(
      aten::_slow_conv2d_forward,
      {input_value,
       weight_value,
       kernel_size_const,
       bias_value,
       stride_const,
       padding_const});
  for (Value* graph_output : conv->node()->outputs()) {
    graph->registerOutput(graph_output);
  }
  LowerAllTuples(graph);
  graph->lint();

  // Differentiate. Tracing of some ops depends on the DCE trick.
  EliminateDeadCode(graph);
  ConstantPropagation(graph);
  auto grad_spec = differentiate(graph);
  LowerGradOf(*grad_spec.df);

  // Run forward + backward through the interpreter.
  tensor_list tensors_in{input, weight, bias};
  tensor_list tensor_grads_in{grad_output};
  auto [tensors_out, tensor_grads_out] =
      runGradient(grad_spec, tensors_in, tensor_grads_in);

  // The JIT results must match the eager reference.
  tensor_list expected_tensors_out{output};
  tensor_list expected_tensor_grads_out{grad_input, grad_weight, grad_bias};
  assertAllClose(tensors_out, expected_tensors_out);
  assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}

// Checks that differentiating aten::native_batch_norm through the JIT
// (differentiate + runGradient) reproduces the eager forward outputs, the
// in-place running-statistics updates, and the gradients from
// at::native_batch_norm_backward.
TEST(ATenNativeBatchNormTest, Basic) {
  // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor
  // running_mean, Tensor running_var, bool training, float momentum, float eps)
  // -> (Tensor, Tensor, Tensor)
  std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
  bool training = true;
  // momentum/eps are stored as float; both the eager call and the JIT
  // constants below receive these same float-rounded values.
  float momentum = 0.9;
  float eps = 1e-5;

  // make inputs
  // NOTE(review): running_var comes from randn and may be negative; presumably
  // harmless because training=true normalizes with batch statistics -- confirm.
  at::Tensor input = torch::randn(input_size);
  at::Tensor weight = torch::randn({input_size[1]});
  at::Tensor bias = torch::randn({input_size[1]});
  at::Tensor running_mean = torch::randn({input_size[1]});
  at::Tensor running_var = torch::randn({input_size[1]});

  // running_mean and running_var are changed in-place, so clone and send them
  at::Tensor running_mean_eager = running_mean.clone();
  at::Tensor running_var_eager = running_var.clone();
  at::Tensor running_mean_jit = running_mean.clone();
  at::Tensor running_var_jit = running_var.clone();

  // run forward eagerly
  auto [output, savemean, saveinvstd] = at::native_batch_norm(
      input,
      weight,
      bias,
      running_mean_eager,
      running_var_eager,
      training,
      momentum,
      eps);

  // make grad_outputs
  // Only the main output gets a random gradient; the saved statistics get
  // zero gradients.
  at::Tensor grad_output =
      torch::randn_like(output, at::MemoryFormat::Preserve);
  at::Tensor grad_savemean =
      torch::zeros_like(savemean, at::MemoryFormat::Preserve);
  at::Tensor grad_saveinvstd =
      torch::zeros_like(saveinvstd, at::MemoryFormat::Preserve);

  // run backward eagerly
  // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor
  // weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor
  // save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor,
  // Tensor, Tensor)
  auto [grad_input, grad_weight, grad_bias] = at::native_batch_norm_backward(
      grad_output,
      input,
      weight,
      running_mean_eager,
      running_var_eager,
      savemean,
      saveinvstd,
      training,
      eps,
      {true, true, true});

  // make JIT graph
  auto graph = std::make_shared<Graph>();
  auto training_val = graph->insertConstant(IValue(training));
  auto momentum_val = graph->insertConstant(IValue(momentum));
  auto eps_val = graph->insertConstant(IValue(eps));

  auto inputg = graph->addInput("self");
  auto weightg = graph->addInput("weight");
  auto biasg = graph->addInput("bias");
  auto running_meang = graph->addInput("running_mean");
  auto running_varg = graph->addInput("running_var");

  Value* bn = graph->insert(
      aten::native_batch_norm,
      {inputg,
       weightg,
       biasg,
       running_meang,
       running_varg,
       training_val,
       momentum_val,
       eps_val});
  auto outputs = bn->node()->outputs();
  for (auto output : outputs) {
    graph->registerOutput(output);
  }
  LowerAllTuples(graph);
  graph->lint();

  // differentiate JIT graph
  EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
  ConstantPropagation(graph);
  auto grad_spec = differentiate(graph);
  LowerGradOf(*grad_spec.df);

  // prepare JIT inputs / gradients
  // The order must match the graph->addInput calls above.
  tensor_list tensors_in;
  tensors_in.push_back(input);
  tensors_in.push_back(weight);
  tensors_in.push_back(bias);
  tensors_in.push_back(running_mean_jit);
  tensors_in.push_back(running_var_jit);

  tensor_list tensor_grads_in;
  tensor_grads_in.push_back(grad_output);
  tensor_grads_in.push_back(grad_savemean);
  tensor_grads_in.push_back(grad_saveinvstd);

  // Get outputs from the interpreter
  auto [tensors_out, tensor_grads_out] =
      runGradient(grad_spec, tensors_in, tensor_grads_in);

  // prepare expected structs
  tensor_list expected_tensors_out, expected_tensor_grads_out;
  expected_tensors_out.push_back(output);
  expected_tensors_out.push_back(savemean);
  expected_tensors_out.push_back(saveinvstd);
  expected_tensors_out.push_back(running_mean_eager);
  expected_tensors_out.push_back(running_var_eager);
  expected_tensor_grads_out.push_back(grad_input);
  expected_tensor_grads_out.push_back(grad_weight);
  expected_tensor_grads_out.push_back(grad_bias);

  // The graph only returns native_batch_norm's three outputs. Append the
  // JIT-side running stats (presumably updated in place by the interpreter
  // run) so they are compared against the eagerly-updated ones.
  tensors_out.push_back(running_mean_jit);
  tensors_out.push_back(running_var_jit);

  // Compare results
  assertAllClose(tensors_out, expected_tensors_out);
  assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}

// Fuses every non-Param node of a two-multiplication graph into a
// prim::FusionGroup and checks that the group's subgraph holds both muls.
TEST(CustomFusionTest, Basic) {
#if defined(FBCODE_CAFFE2)
  return;
#endif

  auto graph_string = R"IR(
    graph(%0 : Float(2, 3, 4),
          %1 : Float(2, 3, 4)):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = aten::mul(%2, %0)
      return (%3))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  {
    // CPU fusion is a process-wide setting. Restore the previous value when
    // this scope exits -- also if CustomFuseGraph throws -- instead of forcing
    // it to false, so the override cannot leak into other tests.
    struct RestoreCpuFusion {
      bool previous;
      ~RestoreCpuFusion() {
        torch::jit::overrideCanFuseOnCPU(previous);
      }
    } restore_cpu_fusion{torch::jit::canFuseOnCPU()};

    torch::jit::overrideCanFuseOnCPU(true);
    CustomFuseGraph(
        g,
        [](Node* n) { return n->kind() != prim::Param; },
        Symbol::fromQualString("prim::FusionGroup"));
  }

  const auto& nodes = g->nodes();
  auto fusion_group =
      std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
        return node->kind() == Symbol::fromQualString("prim::FusionGroup");
      });
  ASSERT_TRUE(fusion_group != nodes.end());

  auto subgraph = fusion_group->g(attr::Subgraph);
  int hits = 0;
  // two multiplications
  for (const auto& n : subgraph->nodes()) {
    (void)n;
    ++hits;
  }
  ASSERT_EQ(hits, 2);
}

TEST(CustomFusionTest, NestedBlocks) {
#if defined(FBCODE_CAFFE2)
  return;
#endif

  // The aten::mul nodes live only inside a prim::If block, so fusion must
  // descend into nested blocks to find them.
  auto graph_string = R"IR(
  graph(%0 : Float(2, 3, 4),
        %1 : Float(2, 3, 4),
        %2 : Float(2, 3, 4)):
    %3 : int = prim::Constant[value=1]()
    %4 : Tensor = prim::If(%2)
      block0():
        %5 : Tensor = aten::mul(%0, %2)
        %6 : Tensor = aten::mul(%5, %1)
        -> (%6)
      block1():
        %7 : Tensor = aten::add(%0, %2, %3)
        %8 : Tensor = aten::add(%7, %1, %3)
        -> (%8)
    %9 : Tensor = aten::add(%4, %2, %3)
    return (%4))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  CustomFuseGraph(
      g,
      [](Node* n) { return n->kind() == aten::mul; },
      Symbol::fromQualString("prim::FusionGroup"));

  // Searches `root` and all of its nested blocks for a node of kind `kind`,
  // using an explicit work stack instead of recursion.
  auto contains_kind = [](const Block* root, Symbol kind) {
    std::vector<const Block*> pending{root};
    while (!pending.empty()) {
      const Block* current = pending.back();
      pending.pop_back();
      for (const Node* node : current->nodes()) {
        if (node->kind() == kind) {
          return true;
        }
        for (const Block* nested : node->blocks()) {
          pending.push_back(nested);
        }
      }
    }
    return false;
  };

  AT_ASSERT(
      contains_kind(g->block(), Symbol::fromQualString("prim::FusionGroup")));
}

// TorchScript source compiled by ControlFlowTest.
//   if_test(a, b):    returns b if a < b, else a (an if/else that assigns c).
//   if_one(a, b):     returns a if a < b, else b (an if without an else).
//   while_test(a, i): squares a repeatedly while i < 3, incrementing i.
static const auto cf_examples = R"JIT(
  def if_test(a, b):
      # FIXME: use 0 instead of a.
      # c = 0
      c = a
      if bool(a < b):
        c = b
      else:
        c = a
      return c
  def if_one(a, b):
    c = b
    if bool(a < b):
      c = a
    return c
  def while_test(a, i):
    while bool(i < 3):
      a *= a
      i += 1
    return a
)JIT";

TEST(ControlFlowTest, Basic) {
  auto cu = compile(cf_examples);

  // Interprets the TorchScript function `name` from `cu` on `stack` and
  // returns the stack afterwards, which holds the function's outputs.
  auto run = [&](const std::string& name, std::vector<IValue> stack) {
    auto fn_graph = toGraphFunction(cu->get_function(name)).graph();
    Code code(fn_graph, "");
    InterpreterState state(code);
    state.run(stack);
    return stack;
  };

  // Box an integer as a tensor IValue, and unbox it again.
  auto boxed = [](int64_t value) {
    return IValue(scalar_to_tensor(at::Scalar(value)));
  };
  auto unboxed = [](IValue value) {
    return std::move(value).toTensor().item<int64_t>();
  };
  auto run_binary = [&](const std::string& name, int64_t lhs, int64_t rhs) {
    const auto result_stack = run(name, {boxed(lhs), boxed(rhs)});
    return unboxed(result_stack[0]);
  };

  ASSERT_EQ(2, run_binary("if_test", 1, 2));
  ASSERT_EQ(3, run_binary("if_test", 3, 2));
  ASSERT_EQ(2, run_binary("if_one", 2, 3));
  ASSERT_EQ(2, run_binary("if_one", 3, 2));
  ASSERT_EQ(256, run_binary("while_test", 2, 0));
}

#if !(C10_ASAN_ENABLED || C10_UBSAN_ENABLED)
// Smoke test: build an ONNX ModelProto and set one field.
// Compiled out under sanitizers because this test fails vptr UBSAN checks.

TEST(ProtoTest, Basic) {
  ::ONNX_NAMESPACE::ModelProto model;
  model.set_producer_name("foo");
}
#endif

// test a few features that are not directly used in schemas yet
TEST(SchemaParserTest, NestedArrays) {
  // Unwraps a list-of-lists argument type down to its innermost element type.
  auto innermost_element_type = [](const auto& schema) {
    const auto& outer = schema.arguments().at(0).type()->template expectRef<ListType>();
    return outer.getElementType()->template expectRef<ListType>().getElementType();
  };

  // Fixed-size outer list: the size annotation is kept in N().
  auto sized = parseSchema("at::what(int[][4] foo) -> ()");
  ASSERT_TRUE(sized.arguments().at(0).N() == 4);
  ASSERT_TRUE(IntType::get()->isSubtypeOf(*innermost_element_type(sized)));

  // Unsized nested list.
  auto unsized = parseSchema("at::what(int[][] foo) -> ()");
  ASSERT_TRUE(IntType::get()->isSubtypeOf(*innermost_element_type(unsized)));
}

TEST(SchemaParserTest, OutVariant) {
  // Keyword-only mutable tensors annotated (a!) / (b!) are out arguments.
  auto with_out = parseSchema(
      "at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)");
  ASSERT_TRUE(with_out.arguments().at(1).is_out());
  ASSERT_TRUE(with_out.arguments().at(2).is_out());

  // Neither a plain keyword argument nor a written-to positional `self`
  // counts as an out argument.
  const char* const non_out_schemas[] = {
      "at::foo(Tensor self, *, int scalar) -> (int)",
      "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))"};
  for (const char* text : non_out_schemas) {
    auto schema = parseSchema(text);
    for (const auto& arg : schema.arguments()) {
      ASSERT_FALSE(arg.is_out());
    }
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(SchemaParserTest, NamedReturns) {
  // The `Tensor!` annotation form must parse without error.
  parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");

  // Return values may be named; names are kept in declaration order.
  auto schema =
      parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
  const auto& returns = schema.returns();
  ASSERT_TRUE(returns.at(0).name() == "the_return");
  ASSERT_TRUE(returns.at(1).name() == "the_return2");
}

TEST(SchemaParserTest, Futures) {
  // Future(int) parses to a FutureType whose element type is int.
  auto schema = parseSchema("at::what(Future(int) foo) -> ()");
  const auto& future_type =
      schema.arguments().at(0).type()->expectRef<FutureType>();
  ASSERT_TRUE(IntType::get()->isSubtypeOf(*future_type.getElementType()));
}

TEST(SchemaParserTest, AnnotatedAliasSets) {
  // An argument and a return sharing the alias set `a` must parse.
  const char* const schema_text = "at::what(Tensor(a) foo) -> (Tensor(a))";
  parseSchema(schema_text);
}

// Checks alias annotations on a mutable tensor and on a mutable tensor list:
// the list carries no alias set of its own, but its single contained type is
// written to and belongs to alias set `b`.
TEST(SchemaParserTest, TensorListAnnotatedAliasSets) {
  const auto s = parseSchema(
      "at::foo(Tensor(a!) self, Tensor(b!)[] out)"
      " -> ()");
  const AliasInfo* selfAliasInfo = s.arguments().at(0).alias_info();
  const AliasInfo* outAliasInfo = s.arguments().at(1).alias_info();
  // Fail cleanly instead of crashing on a null dereference if the parser
  // dropped an annotation.
  ASSERT_NE(selfAliasInfo, nullptr);
  ASSERT_NE(outAliasInfo, nullptr);

  ASSERT_TRUE(
      selfAliasInfo->beforeSets() ==
      std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
  ASSERT_TRUE(selfAliasInfo->isWrite());

  ASSERT_TRUE(outAliasInfo->isWrite());
  ASSERT_TRUE(outAliasInfo->beforeSets().empty());
  ASSERT_EQ(outAliasInfo->containedTypes().size(), 1u);

  // Bind by reference; there is no need to copy the contained AliasInfo.
  const auto& containedType = outAliasInfo->containedTypes()[0];

  ASSERT_TRUE(containedType.isWrite());
  ASSERT_TRUE(
      containedType.beforeSets() ==
      std::unordered_set<Symbol>{Symbol::fromQualString("alias::b")});
}

TEST(SchemaParserTest, AnnotatedAliasWithoutBeforeSet) {
  // A `!` marker without an alias-set name is a parse error with a specific
  // message.
  const auto mentions_missing_ident = ::testing::Property(
      &std::runtime_error::what,
      ::testing::HasSubstr("expected ident but found '!' here"));
  EXPECT_THAT(
      []() { parseSchema("at::foo(Tensor(!) self) -> Tensor"); },
      ::testing::Throws<std::runtime_error>(mentions_missing_ident));
}

TEST(SchemaParserTest, BeforeAfterSets) {
  const auto schema = parseSchema(
      "at::what(Tensor(b|c)[](a!) list, Tensor(c) element)"
      " -> (Tensor(b|c)[](a!))");

  // The list itself is annotated with `a` and is written to.
  const AliasInfo* list_info = schema.arguments().at(0).alias_info();
  ASSERT_NE(list_info, nullptr);
  ASSERT_TRUE(
      list_info->beforeSets() ==
      std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
  ASSERT_TRUE(list_info->isWrite());

  // Its elements belong to {b, c} both before and after the call and are
  // not written to.
  ASSERT_FALSE(list_info->containedTypes().empty());
  const auto& element_info = list_info->containedTypes()[0];
  const std::unordered_set<Symbol> b_or_c{
      Symbol::fromQualString("alias::b"),
      Symbol::fromQualString("alias::c"),
  };
  ASSERT_TRUE(element_info.beforeSets() == b_or_c);
  ASSERT_TRUE(element_info.afterSets() == b_or_c);
  ASSERT_FALSE(element_info.isWrite());
}

TEST(SchemaParserTest, BeforeAfterSets2) {
  const auto schema = parseSchema(
      "at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)"
      " -> (Tensor(b|c)[](a!))");

  // The list itself is annotated with `a` (unchanged by the call) and is
  // written to.
  const AliasInfo* list_info = schema.arguments().at(0).alias_info();
  ASSERT_NE(list_info, nullptr);
  const std::unordered_set<Symbol> only_a{Symbol::fromQualString("alias::a")};
  ASSERT_EQ(list_info->beforeSets(), only_a);
  ASSERT_EQ(list_info->afterSets(), only_a);
  ASSERT_TRUE(list_info->isWrite());
  ASSERT_EQ(list_info->containedTypes().size(), 1u);

  // `b -> b|c`: elements start in {b} and may also alias {c} afterwards.
  ASSERT_FALSE(list_info->containedTypes().empty());
  const auto& element_info = list_info->containedTypes()[0];
  const std::unordered_set<Symbol> before{Symbol::fromQualString("alias::b")};
  const std::unordered_set<Symbol> after{
      Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")};
  ASSERT_TRUE(element_info.beforeSets() == before);
  ASSERT_TRUE(element_info.afterSets() == after);
  ASSERT_FALSE(element_info.isWrite());
}

TEST(TopologicalIndexTest, Basic) {
  Graph graph;
  auto make_node = [&graph]() { return graph.create(prim::AutogradZero); };
  Node* node1 = make_node();
  Node* node2 = make_node();
  Node* node3 = make_node();
  Node* node4 = make_node();

  // Insert out of creation order; the resulting order is node1..node4.
  graph.appendNode(node4);
  graph.prependNode(node1);
  node2->insertAfter(node1);
  node3->insertBefore(node4);

  // nodes should be in numerical order
  ASSERT_TRUE(node1->isBefore(node2));
  ASSERT_TRUE(node1->isBefore(node3));
  ASSERT_TRUE(node1->isBefore(node4));
  ASSERT_TRUE(node2->isAfter(node1));
  ASSERT_TRUE(node2->isBefore(node3));
  ASSERT_TRUE(node2->isBefore(node4));
  ASSERT_FALSE(node3->isBefore(node1));
  ASSERT_FALSE(node3->isBefore(node2));
  ASSERT_FALSE(node3->isAfter(node4));

  // Build up a block structure
  //  node3
  //   /\        ...
  //  A  B     block1
  //      \      ...
  //      C    block2
  Block* block1 = node3->addBlock();
  Node* A = make_node();
  block1->appendNode(A);
  Node* B = make_node();
  block1->appendNode(B);
  Block* block2 = B->addBlock();
  Node* C = make_node();
  block2->appendNode(C);

  // Ordering queries work across different block nesting levels.
  ASSERT_TRUE(node1->isBefore(A));
  ASSERT_TRUE(A->isBefore(B));
  ASSERT_TRUE(A->isBefore(C));

  // Destroying a node and inserting a replacement keeps the indices valid.
  node2->destroy();
  Node* node2p = make_node();
  node2p->insertAfter(node1);
  ASSERT_TRUE(node1->isBefore(node2p));
  ASSERT_TRUE(node2p->isBefore(node3));
}

// Repeatedly inserting right after the same anchor exhausts the gaps between
// topological indices and forces a reindex; ordering must survive it.
TEST(TopologicalIndexTest, Reindex) {
  Graph graph;
  constexpr size_t kNumNodes = 100;
  // Dense 0..kNumNodes-1 indices: a vector (not std::map, which this file
  // never includes) with size_t indices avoids int -> size_t key conversions.
  std::vector<Node*> nodes;
  nodes.reserve(kNumNodes);

  auto anchor = graph.create(prim::AutogradZero);
  graph.appendNode(anchor);
  // Inserting to the same place a lot will trigger reindexing
  for (size_t i = 0; i < kNumNodes; ++i) {
    auto n = graph.create(prim::AutogradZero);
    n->insertAfter(anchor);
    nodes.push_back(n);
  }

  // Each node was inserted directly after the anchor, so later insertions
  // come first: nodes end up in reverse creation order.
  for (size_t i = 0; i < kNumNodes; ++i) {
    for (size_t j = i + 1; j < kNumNodes; ++j) {
      ASSERT_TRUE(nodes[i]->isAfter(nodes[j]));
    }
  }
}

// Eager counterpart of invokeTestRecordFunctionJIT: opens a RecordFunction
// scope named "test" with `t` as its recorded input, then returns t.pow(2).
at::Tensor invokeTestRecordFunction(at::Tensor& t) {
  RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
  return t.pow(2);
}

// TorchScript methods defined on the module in invokeTestRecordFunctionJIT:
// forward() calls foo(), which computes t.pow(2), mirroring the eager
// invokeTestRecordFunction. Keeping foo() as a separate method lets the tests
// observe both "forward" and "foo" as TorchScript-function record scopes.
static const auto invokeTestRecordFunction_JIT = R"JIT(
  def foo(self, t):
    t2 = t.pow(2)
    return t2

  def forward(self, t):
    return self.foo(t)
)JIT";

// Scripted counterpart of invokeTestRecordFunction: inside a "test"
// RecordFunction scope, builds a fresh module from
// invokeTestRecordFunction_JIT and returns forward(t).
at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) {
  RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));

  auto unit = std::make_shared<script::CompilationUnit>();
  auto module =
      std::make_shared<script::Module>("RecordFunctionTestModule", unit);
  module->define(invokeTestRecordFunction_JIT);
  auto result = module->forward({t});
  return result.toTensor();
}

// One entry per recorded call: (function name, per-argument sizes). As filled
// by tracedInputsCallback / tracedOutputsCallback, tensor arguments contribute
// their sizes and scalar arguments contribute an empty vector.
using TracedTestValues =
    std::vector<std::tuple<std::string, std::vector<std::vector<int64_t>>>>;

// Validates the inputs recorded for the "test" scope, aten::pow and
// aten::mul. Each must appear at least once, and every occurrence must show
// the {1, 2, 3} input tensor used by the RecordFunction tests.
void checkTracedInputs(const TracedTestValues& inputs) {
  bool found_test = false;
  bool found_pow = false;
  bool found_mul = false;
  for (const auto& [fn, sizes] : inputs) {
    if (fn == "aten::pow") {
      // pow(tensor, scalar exponent): a tensor plus an empty scalar entry.
      found_pow = true;
      TORCH_CHECK(sizes.size() == 2);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
      TORCH_CHECK(sizes[1].empty());
    } else if (fn == "aten::mul") {
      found_mul = true;
      TORCH_CHECK(sizes.size() > 1);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
    } else if (fn == "test") {
      // RECORD_FUNCTION("test", {t}) records exactly one tensor.
      found_test = true;
      TORCH_CHECK(sizes.size() == 1);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
    }
  }
  TORCH_CHECK(found_test);
  TORCH_CHECK(found_pow);
  TORCH_CHECK(found_mul);
}

// Validates the outputs recorded for the "test" scope, aten::pow and
// aten::mul. Each must appear at least once. The "test" scope records no
// outputs; pow and mul each record a single {1, 2, 3} tensor.
void checkTracedOutputs(const TracedTestValues& outputs) {
  bool found_test = false;
  bool found_pow = false;
  bool found_mul = false;
  for (const auto& [fn, sizes] : outputs) {
    if (fn == "aten::pow") {
      found_pow = true;
      TORCH_CHECK(sizes.size() == 1);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
    } else if (fn == "aten::mul") {
      found_mul = true;
      TORCH_CHECK(sizes.size() == 1);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
    } else if (fn == "test") {
      found_test = true;
      TORCH_CHECK(sizes.empty());
    }
  }
  TORCH_CHECK(found_test);
  TORCH_CHECK(found_pow);
  TORCH_CHECK(found_mul);
}

// Set to true by checkScopeCallback whenever it sees a RecordFunction whose
// scope differs from the one it was instantiated for; never reset here.
static bool bad_scope = false;

// RecordFunction start callback: increments *cnt for functions in `scope`,
// and flags bad_scope for anything else.
template <RecordScope scope, size_t* cnt>
std::unique_ptr<at::ObserverContext> checkScopeCallback(
    const at::RecordFunction& fn) {
  if (fn.scope() != scope) {
    bad_scope = true;
    return nullptr;
  }
  ++*cnt;
  return nullptr;
}

// Globally registers checkScopeCallback<scope, cnt>, restricted to `scope`.
template <RecordScope scope, size_t* cnt>
void pushScopedCallback() {
  auto callback = at::RecordFunctionCallback(checkScopeCallback<scope, cnt>);
  callback.scopes({scope});
  at::addGlobalCallback(std::move(callback));
}

// Per-scope hit counters whose addresses are passed to pushScopedCallback as
// non-type template arguments. They live at namespace scope because, before
// C++17, a function-local variable could not be used that way.
static size_t fun_cnt = 0;
static size_t ts_fun_cnt = 0;
static size_t user_scope_cnt = 0;

// Registers callbacks that check each RecordScope (FUNCTION,
// TORCHSCRIPT_FUNCTION, USER_SCOPE) is reported correctly. It then opens one
// record scope of each kind and TORCH_CHECKs that every scoped callback fired
// exactly once for its own scope. The global callbacks stay registered; the
// caller is expected to clear them.
void checkScopeCallbacks() {
  // The flags are function-local statics so the captureless lambda below
  // (which must convert to a plain callback function pointer) can write them.
  static bool found_function_scope;
  static bool found_method_scope;
  static bool found_user_scope;
  found_function_scope = false;
  found_method_scope = false;
  found_user_scope = false;
  // Unscoped callback: records which expected (scope, name) pairs it saw.
  at::addGlobalCallback(at::RecordFunctionCallback(
      [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
        if (fn.scope() == at::RecordScope::FUNCTION &&
            std::string(fn.name()) == "test_function") {
          found_function_scope = true;
        }
        if (fn.scope() == at::RecordScope::TORCHSCRIPT_FUNCTION &&
            std::string(fn.name()) == "test_method") {
          found_method_scope = true;
        }
        if (fn.scope() == at::RecordScope::USER_SCOPE &&
            std::string(fn.name()) == "test_user_scope") {
          found_user_scope = true;
        }
        return nullptr;
      }));

  // Scoped counting callbacks; bad_scope is set if one of them is invoked for
  // a scope other than the one it was registered for.
  bad_scope = false;
  fun_cnt = 0;
  pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>();
  ts_fun_cnt = 0;
  pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>();
  user_scope_cnt = 0;
  pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>();

  TORCH_CHECK(at::hasCallbacks());

  // Exactly one record scope of each kind, nested inside the TorchScript one.
  {
    RECORD_TORCHSCRIPT_FUNCTION("test_method", {});
    { RECORD_FUNCTION("test_function", {}); }
    { RECORD_USER_SCOPE("test_user_scope"); }
  }

  TORCH_CHECK(!bad_scope);
  TORCH_CHECK(fun_cnt == 1);
  TORCH_CHECK(ts_fun_cnt == 1);
  TORCH_CHECK(user_scope_cnt == 1);

  TORCH_CHECK(found_function_scope);
  TORCH_CHECK(found_method_scope);
  TORCH_CHECK(found_user_scope);
}

// (name, per-argument sizes) entries for FUNCTION-scope records, appended by
// tracedInputsCallback and tracedOutputsCallback respectively.
static TracedTestValues traced_inputs{};
static TracedTestValues traced_outputs{};
// Names of TORCHSCRIPT_FUNCTION-scope records seen by the same two callbacks.
static std::unordered_set<std::string> ts_input_names{};
static std::unordered_set<std::string> ts_output_names{};

// RecordFunction start callback. For FUNCTION scopes it appends
// (name, sizes of each input) to traced_inputs, with tensors recorded by their
// sizes and scalars as an empty vector. For TORCHSCRIPT_FUNCTION scopes it
// records the function name in ts_input_names. Always returns no context.
std::unique_ptr<at::ObserverContext> tracedInputsCallback(
    const RecordFunction& fn) {
  if (fn.scope() == RecordScope::FUNCTION) {
    std::vector<std::vector<int64_t>> sizes;
    for (const auto& input : fn.inputs()) {
      if (input.isTensor()) {
        sizes.push_back(input.toTensor().sizes().vec());
      } else if (input.isScalar()) {
        sizes.emplace_back();
      }
    }
    // Construct the entry in place and move `sizes` instead of copying it.
    traced_inputs.emplace_back(fn.name(), std::move(sizes));
  } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
    ts_input_names.insert(fn.name());
  }
  return nullptr;
}

// RecordFunction end callback. For FUNCTION scopes it appends
// (name, sizes of each output) to traced_outputs, with tensors recorded by
// their sizes and scalars as an empty vector. For TORCHSCRIPT_FUNCTION scopes
// it records the function name in ts_output_names. `ctx_ptr` is unused.
void tracedOutputsCallback(
    const RecordFunction& fn,
    [[maybe_unused]] ObserverContext* ctx_ptr) {
  if (fn.scope() == RecordScope::FUNCTION) {
    std::vector<std::vector<int64_t>> sizes;
    for (const auto& output : fn.outputs()) {
      if (output.isTensor()) {
        sizes.push_back(output.toTensor().sizes().vec());
      } else if (output.isScalar()) {
        sizes.emplace_back();
      }
    }
    // Construct the entry in place and move `sizes` instead of copying it.
    traced_outputs.emplace_back(fn.name(), std::move(sizes));
  } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
    ts_output_names.insert(fn.name());
  }
}

// Records inputs/outputs of FUNCTION-scope calls and names of
// TorchScript-function calls for an eager run and a scripted run of the same
// computation. It then checks both traces have the expected shapes, and that
// the scripted run reported "forward" and "foo" as TorchScript functions.
TEST(RecordFunctionTest, TracedTestInputsOutputs) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  // The callbacks accumulate into file-level globals. Start from a clean state
  // so the test also passes when re-run in the same process (e.g. with
  // --gtest_repeat); otherwise the ts_*_names.empty() checks below fail.
  traced_inputs.clear();
  traced_outputs.clear();
  ts_input_names.clear();
  ts_output_names.clear();

  // Remove the global callbacks even if a TORCH_CHECK below throws, so they
  // cannot leak into later tests.
  struct ClearCallbacksOnExit {
    ~ClearCallbacksOnExit() {
      at::clearCallbacks();
    }
  } clear_callbacks_on_exit;

  // [(fn, [[sizes], [sizes], ...]), ...]
  addGlobalCallback(
      RecordFunctionCallback(tracedInputsCallback, tracedOutputsCallback)
          .needsInputs(true)
          .needsOutputs(true));

  TracedTestValues eager_inputs, eager_outputs, jit_inputs, jit_outputs;
  {
    auto t = torch::randn({1, 2, 3}, at::kCPU);
    t.set_requires_grad(true);
    auto t2 = invokeTestRecordFunction(t);
    t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
    eager_inputs = traced_inputs;
    eager_outputs = traced_outputs;
    traced_inputs.clear();
    traced_outputs.clear();

    // The eager path must not produce any TorchScript-function records.
    TORCH_CHECK(ts_input_names.empty());
    TORCH_CHECK(ts_output_names.empty());

    t = torch::randn({1, 2, 3}, at::kCPU);
    t.set_requires_grad(true);
    t2 = invokeTestRecordFunctionJIT(t);
    t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
    jit_inputs = traced_inputs;
    jit_outputs = traced_outputs;
    traced_inputs.clear();
    traced_outputs.clear();
  }

  TORCH_CHECK(ts_input_names.find("forward") != ts_input_names.end());
  TORCH_CHECK(ts_input_names.find("foo") != ts_input_names.end());
  TORCH_CHECK(ts_output_names.find("forward") != ts_output_names.end());
  TORCH_CHECK(ts_output_names.find("foo") != ts_output_names.end());

  checkTracedInputs(eager_inputs);
  checkTracedOutputs(eager_outputs);
  checkTracedInputs(jit_inputs);
  checkTracedOutputs(jit_outputs);
}

// Number of "test" records observed by sampledCallback.
static int sampled_cb_ctr = 0;
// Start callback registered with a sampling probability in SampledCallbacks;
// counts invocations for RecordFunctions named "test".
std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) {
  const std::string name(fn.name());
  if (name != "test") {
    return nullptr;
  }
  ++sampled_cb_ctr;
  return nullptr;
}

// Number of "test" events observed by nonSampledCallback.
static int non_sampled_cb_ctr = 0;
// Start callback registered without sampling; counts only events named "test"
// and never creates an observer context.
std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) {
  const std::string event_name(fn.name());
  if (event_name == "test") {
    non_sampled_cb_ctr += 1;
  }
  return nullptr;
}

// Checks callback sampling over 1000 invocations of invokeTestRecordFunction:
// with probability 0.5 the sampled callback fires for some but not all of
// them, with 0.0 it never fires and with 1.0 it always fires, while the
// unsampled callback fires every time. Finally exercises callback scopes via
// checkScopeCallbacks().
TEST(RecordFunctionTest, SampledCallbacks) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  // Both counters are file-level statics that keep their values between runs
  // of this test (e.g. under --gtest_repeat). Reset both so the cumulative
  // non_sampled_cb_ctr checks below (1000/2000/3000) start from zero.
  non_sampled_cb_ctr = 0;

  // test sampled callbacks
  sampled_cb_ctr = 0;
  auto setup_sampled_callback = [](double sampling_prob) {
    return addGlobalCallback(
        RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob));
  };

  addGlobalCallback(RecordFunctionCallback(nonSampledCallback));

  auto handle = setup_sampled_callback(0.5);

  auto run_test_function = []() {
    auto t = torch::randn({1, 2, 3}, at::kCPU);
    for (auto k = 0; k < 1000; k++) {
      invokeTestRecordFunction(t);
    }
  };

  run_test_function();
  TORCH_CHECK(non_sampled_cb_ctr == 1000);
  TORCH_CHECK(sampled_cb_ctr > 0 && sampled_cb_ctr < 1000);

  sampled_cb_ctr = 0;
  removeCallback(handle);
  handle = setup_sampled_callback(0.0);
  run_test_function();

  TORCH_CHECK(non_sampled_cb_ctr == 2000);
  TORCH_CHECK(sampled_cb_ctr == 0);

  sampled_cb_ctr = 0;
  removeCallback(handle);
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  handle = setup_sampled_callback(1.0);
  run_test_function();

  TORCH_CHECK(non_sampled_cb_ctr == 3000);
  TORCH_CHECK(sampled_cb_ctr == 1000);
  clearCallbacks();

  // test the scope of the callbacks
  checkScopeCallbacks();
  clearCallbacks();
}

// Verifies that RecordFunctionGuard / DisableRecordFunctionGuard nest
// correctly: of the four user scopes below, only "B" (opened while recording
// is re-enabled) may reach the callback.
TEST(RecordFunctionTest, RecordFunctionGuard) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  // Static so the capture-less callback below can reach them.
  static std::vector<std::string> fn_names;
  static std::mutex guard_mtx;
  // The statics outlive a single run of this test (e.g. --gtest_repeat);
  // start from an empty list so the size check below only sees this run.
  {
    std::lock_guard<std::mutex> lock(guard_mtx);
    fn_names.clear();
  }

  // check record function guard
  addGlobalCallback(RecordFunctionCallback(
      [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
        std::lock_guard<std::mutex> lock(guard_mtx);
        // NOLINTNEXTLINE(modernize-use-emplace)
        fn_names.push_back(fn.name());
        return nullptr;
      }));
  {
    // Recording disabled for everything in this scope unless re-enabled.
    RecordFunctionGuard g1(false);
    {
      RECORD_USER_SCOPE("A");
      {
        // Re-enabled for this scope only: "B" is expected to be recorded.
        RecordFunctionGuard g2(true);
        RECORD_USER_SCOPE("B");
        {
          DisableRecordFunctionGuard g3;
          RECORD_USER_SCOPE("C");
        }
      }
      // g2 has gone out of scope, so g1's "disabled" state applies again.
      { RECORD_USER_SCOPE("D"); }
    }
  }
  TORCH_CHECK(fn_names.size() == 1);
  TORCH_CHECK(fn_names[0] == "B");
  clearCallbacks();
}

static std::vector<size_t> ids;

template <size_t id>
auto add_remove_test_add_cb() {
  return addGlobalCallback(RecordFunctionCallback(
      [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
        ids.push_back(id);
        return nullptr;
      }));
}

// Exercises adding/removing global callbacks, mixing global and thread-local
// callbacks, and passing an ObserverContext from a start callback to the
// matching end callback (globally and thread-locally).
TEST(RecordFunctionTest, Callbacks) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  // `ids` is a file-level static that persists between runs of this test;
  // start empty so the size checks below only see events from this run.
  ids.clear();

  auto h1 = add_remove_test_add_cb<1>();
  add_remove_test_add_cb<2>();
  auto h3 = add_remove_test_add_cb<3>();

  { RECORD_USER_SCOPE("test"); }

  TORCH_CHECK(ids.size() == 3);
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());

  ids.clear();
  removeCallback(h1);

  { RECORD_USER_SCOPE("test"); }

  TORCH_CHECK(ids.size() == 2);
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());

  ids.clear();
  removeCallback(h3);

  { RECORD_USER_SCOPE("test"); }

  TORCH_CHECK(ids.size() == 1);
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());

  clearCallbacks();

  // thread local / global callbacks

  ids.clear();
  add_remove_test_add_cb<1>();

  { RECORD_USER_SCOPE("test"); }

  TORCH_CHECK(ids.size() == 1);
  TORCH_CHECK(ids[0] == 1);
  ids.clear();

  // The worker thread sees both the global callback (id 1) and its own
  // thread-local callback (id 2); the main thread is blocked in join() while
  // the worker appends to `ids`.
  auto th = std::thread([]() {
    addThreadLocalCallback(RecordFunctionCallback(
        [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
          ids.push_back(2);
          return nullptr;
        }));

    { RECORD_USER_SCOPE("test_thread"); }
  });
  th.join();
  TORCH_CHECK(ids.size() == 2);
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
  ids.clear();

  // Back on the main thread only the global callback is registered.
  { RECORD_USER_SCOPE("test"); }

  TORCH_CHECK(ids.size() == 1);
  TORCH_CHECK(ids[0] == 1);
  ids.clear();

  clearCallbacks();

  // START: thread local / global context check callbacks
  struct TestContext : public ObserverContext {
    int a{0};
    std::string b;
  };
  ids.clear();
  { // START: global test
    addGlobalCallback(RecordFunctionCallback(
        [](const RecordFunction&
           /* unused */) -> std::unique_ptr<at::ObserverContext> {
          auto ctx = std::make_unique<TestContext>();
          ctx->a = 123;
          ctx->b = "test_str";
          ids.push_back(1);
          return ctx;
        },
        [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
          auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
          TORCH_CHECK(ctx != nullptr);
          TORCH_CHECK(ctx->a == 123);
          TORCH_CHECK(ctx->b == "test_str");
        }));

    { RECORD_USER_SCOPE("test"); }

    TORCH_CHECK(ids.size() == 1);
    TORCH_CHECK(ids[0] == 1);
    ids.clear();
  } // END: global test
  { // START: thread local test
    auto ctx_th = std::thread([]() {
      addThreadLocalCallback(RecordFunctionCallback(
          [](const RecordFunction&
             /* unused */) -> std::unique_ptr<at::ObserverContext> {
            auto ctx = std::make_unique<TestContext>();
            ctx->a = 234;
            ctx->b = "test_thread_str";
            ids.push_back(2);
            return ctx;
          },
          [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
            auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
            // Check the cast result, not the raw pointer: a context of the
            // wrong dynamic type must fail here rather than dereference a
            // null `ctx` on the next lines.
            TORCH_CHECK(ctx != nullptr);
            TORCH_CHECK(ctx->a == 234);
            TORCH_CHECK(ctx->b == "test_thread_str");
          }));

      // Will call both global and thread local callbacks.
      { RECORD_USER_SCOPE("test_thread"); }
    });
    ctx_th.join();
    TORCH_CHECK(ids.size() == 2);
    TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
    TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
    ids.clear();
  } // END: thread local test

  clearCallbacks();
}

// Checks that a registered callback fires, stops firing once disabled, and
// fires again after being re-enabled.
TEST(RecordFunctionTest, ShouldRun) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  // Set to true by the callback each time it is invoked.
  static bool ran = false;
  auto on_start =
      [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
    ran = true;
    return nullptr;
  };
  auto handle = addGlobalCallback(RecordFunctionCallback(on_start));

  // Emits one user-scope "test" event on the current thread.
  auto emit_test_event = []() { RECORD_USER_SCOPE("test"); };

  emit_test_event();
  EXPECT_TRUE(ran) << "first run didn't happen";
  ran = false;

  // A disabled callback must stay silent...
  disableCallback(handle);
  emit_test_event();
  EXPECT_FALSE(ran) << "second run happened but shouldn't have";
  ran = false;

  // ...and fire again once re-enabled.
  reenableCallback(handle);
  emit_test_event();
  EXPECT_TRUE(ran) << "run after re-enable didn't happen";
  ran = false;

  clearCallbacks();
}

// Checks that thread-local callbacks follow a captured ThreadLocalState into a
// child thread, and that RecordFunction handles are only assigned when a
// callback requests them via needsIds(true).
TEST(RecordFunctionTest, Basic) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  // Static so the capture-less callbacks below can write to them.
  static std::string recorded_op;
  static bool has_ids = false;

  // test propagation of TLS callbacks
  std::thread t([]() {
    RecordFunctionGuard enable_rec_fn;
    auto handle = addThreadLocalCallback(RecordFunctionCallback(
        [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
          recorded_op = fn.name();
          return nullptr;
        }));
    // Capture this thread's state (which now includes the thread-local
    // callback) and install it in the child; the test expects the child's
    // scope to reach the callback.
    ThreadLocalState state;
    std::thread t_child([state]() {
      ThreadLocalStateGuard g_tls(state);
      RECORD_USER_SCOPE("test_in_thread");
    });
    t_child.join();
    EXPECT_EQ(recorded_op, "test_in_thread");
    removeCallback(handle);
  });
  t.join();
  clearCallbacks();

  // test set ids
  // With needsIds(true) the test expects a positive handle...
  addGlobalCallback(
      RecordFunctionCallback(
          [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
            has_ids = fn.handle() > 0;
            return nullptr;
          })
          .needsIds(true));
  { RECORD_USER_SCOPE("test"); }
  TORCH_CHECK(has_ids);
  clearCallbacks();
  has_ids = false;
  // ...and without it, no positive handle.
  addGlobalCallback(RecordFunctionCallback(
      [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
        has_ids = fn.handle() > 0;
        return nullptr;
      }));
  { RECORD_USER_SCOPE("test"); }
  TORCH_CHECK(!has_ids);
  clearCallbacks();
}

// Checks that FUNCTION-scope record functions expose the full operator name,
// including the overload name (e.g. "aten::pow.Tensor_Scalar").
TEST(RecordFunctionTest, OperatorNameOverload) {
  // Names collected by the callback; static so the lambda needs no capture.
  static std::set<std::string> operator_names;
  auto record_operator_name =
      [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
    const std::optional<c10::OperatorName> op_name = fn.operator_name();
    operator_names.insert(
        op_name.has_value() ? c10::toString(*op_name) : "No Operator Name");
    return nullptr;
  };
  at::addGlobalCallback(at::RecordFunctionCallback(record_operator_name)
                            .scopes({at::RecordScope::FUNCTION}));

  auto t = torch::randn({1, 2, 3}, at::kCPU);
  t.set_requires_grad(false);
  auto t2 = t.pow(2);

  at::clearCallbacks();
  EXPECT_EQ(operator_names.count("No Operator Name"), 0u)
      << "Expected that all traced operators had an associated OperatorName object";
  EXPECT_EQ(operator_names.count("aten::randn"), 1u)
      << "Expected aten::randn to have been called and recorded, but it was not";
  EXPECT_EQ(operator_names.count("aten::pow.Tensor_Scalar"), 1u)
      << "Expected aten::pow.Tensor_Scalar to have been called and recorded, but it was not";
}

class TestThreadLocalDebugInfo : public c10::DebugInfoBase {
 public:
  int getModelId() const {
    return model_id_;
  }

  void setModelId(int model_id) {
    model_id_ = model_id;
  }

  // NOLINTNEXTLINE(modernize-use-equals-default)
  virtual ~TestThreadLocalDebugInfo() override {}

 private:
  int model_id_ = 0;
};

// Fails (via TORCH_CHECK) unless the debug info currently registered for
// `kind` on this thread is a TestThreadLocalDebugInfo carrying `model_id`.
void checkDebugInfo(c10::DebugInfoKind kind, int model_id) {
  auto* const base_info = c10::ThreadLocalDebugInfo::get(kind);
  TORCH_CHECK(base_info != nullptr);
  const auto* const typed_info =
      dynamic_cast<TestThreadLocalDebugInfo*>(base_info);
  TORCH_CHECK(typed_info != nullptr);
  TORCH_CHECK(typed_info->getModelId() == model_id);
}

// Verifies that thread-local debug info is visible inside a DebugInfoGuard,
// reaches tasks started with at::launch and RecordFunction callbacks run
// during a backward pass, and that different debug-info kinds nest.
TEST(ThreadLocalDebugInfoTest, Basic) {
  // Set by asynchronously launched tasks when they finish; static so the
  // capture-less lambdas below can reach it.
  static std::atomic<bool> done{false};

  TORCH_CHECK(
      c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
  auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
  debug_info->setModelId(42);
  {
    c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
    checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
  }

  // check that thread local debug info is propagated through fork calls
  TORCH_CHECK(
      c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
  // `done` is static and may still be true from a previous run of this test.
  // Reset it before launching. Otherwise the wait below could return before
  // the task has run, and the stale task could later set `done` during the
  // backward-pass check.
  done = false;
  {
    c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
    at::launch([]() {
      checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
      done = true;
    });
  }
  while (!done) {
    std::this_thread::yield();
  }

  // check that thread local debug info is propagated through backward pass
  TORCH_CHECK(
      c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
  done = false;
  auto handle = addGlobalCallback(RecordFunctionCallback(
      [](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
        checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
        done = true;
        return nullptr;
      }));
  {
    c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
    auto t = torch::randn({1, 2, 3}, at::kCPU);
    t.set_requires_grad(true);
    auto t2 = t.pow(2);
    t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
  }
  removeCallback(handle);
  TORCH_CHECK(done);

  // check nested debug info
  TORCH_CHECK(
      c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
  {
    c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
    {
      checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
      {
        // Distinct names so the outer debug_info / guard are not shadowed.
        auto nested_debug_info = std::make_shared<TestThreadLocalDebugInfo>();
        nested_debug_info->setModelId(314);
        c10::DebugInfoGuard nested_guard(
            c10::DebugInfoKind::TEST_INFO_2, nested_debug_info);
        {
          checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
          checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
          done = false;
          at::launch([]() {
            checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
            checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
            done = true;
          });
          while (!done) {
            std::this_thread::yield();
          }
        }
      }
    }
  }
}

// Expanding with a SymInt shape must match expanding with the same shape
// given as plain integers.
TEST(TestSymIntArrayRef, BasicConversion) {
  const std::vector<int64_t> tgt_size_v{2, 4, 5};
  std::vector<c10::SymInt> tgt_size;
  tgt_size.reserve(tgt_size_v.size());
  for (const int64_t dim : tgt_size_v) {
    tgt_size.emplace_back(dim);
  }
  auto a = at::randn({1, 4, 1}, at::kCPU);
  auto b = a.expand_symint(tgt_size);
  auto c = a.expand(tgt_size_v);
  ASSERT_TRUE(torch::allclose(b, c));
}

// narrow_copy with a SymInt length must equal a plain narrow of that length.
TEST(TestSymInt, NarrowCopyWithSymbolicInt) {
  constexpr int64_t kLength = 5;
  auto a = at::randn({10}, at::kCPU);
  const c10::SymInt length_sym(kLength);
  auto b = a.narrow_copy_symint(0, 0, length_sym);
  auto c = a.narrow(0, 0, kLength);
  ASSERT_TRUE(torch::allclose(b, c));
}

// narrow_copy (materialized) must equal narrow (view) over the same range.
TEST(TestSymInt, NarrowCopy) {
  constexpr int64_t kLength = 5;
  auto a = at::randn({10}, at::kCPU);
  auto copied = a.narrow_copy(0, 0, kLength);
  auto viewed = a.narrow(0, 0, kLength);
  ASSERT_TRUE(torch::allclose(copied, viewed));
}

// Adding two concrete SymInts yields a concrete SymInt holding the sum.
TEST(TestSymInt, AddSymbolicInt) {
  const c10::SymInt lhs(5);
  const c10::SymInt rhs(3);
  const c10::SymInt sum = lhs + rhs;
  ASSERT_EQ(sum.expect_int(), 8);
}

// Replaces the optimized graph with a prim::FallbackGraph after profiling and
// checks that results are unchanged and that the fallback ends up being
// invoked via prim::CallFunction.
TEST(FallbackGraphsTest, Basic) {
  auto x = at::randn({1}, at::kCPU);
  auto y = at::randn({1}, at::kCPU);
  auto stack = createStack({x.clone(), y.clone()});

  auto graph_string = R"IR(
    graph(%0 : Float(1),
          %1 : Float(1)):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = aten::mul(%2, %0)
      return (%3))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  // Reference result from the plain interpreter.
  {
    Code code(graph, "");
    InterpreterState interpreter{code};
    interpreter.run(stack);
  }
  at::Tensor expected_tensor;
  pop(stack, expected_tensor);
  float expected = expected_tensor.item<float>();
  {
    EnableProfilingGuard epg;
    GraphFunction f("fallbackGraphs", graph, nullptr);
    for (size_t i = 0; i < getNumProfiledRuns() + 1; i++) {
      stack.emplace_back(x.clone());
      stack.emplace_back(y.clone());
      if (i == getNumProfiledRuns()) {
        // we will be modifying a profiled graph
        // before ProfilingGraphExecutor
        // will optimize it in the next iteration
        auto opt_graph = lastExecutedOptimizedGraph();
        // this is safe to do since we are done profiling
        ProfilingRecord::removeProfileCounter(opt_graph->block());
        replaceBlockWithFallbackGraph(opt_graph->block(), opt_graph->inputs());
        auto it = opt_graph->block()->nodes().begin();
        ASSERT_EQ(it->kind(), prim::FallbackGraph);
        auto fallback = *it++;
        ASSERT_EQ(it, opt_graph->block()->nodes().end());
        ASSERT_TRUE(fallback->hasAttribute(attr::Subgraph));
        testing::FileCheck()
            .check("Tensor = aten::mul")
            ->check("Tensor = aten::mul")
            ->run(*fallback->g(attr::Subgraph));
      }
      f.run(stack);
      // Previously named `at`, which shadowed the `at` namespace for the rest
      // of the loop body.
      at::Tensor actual_tensor;
      pop(stack, actual_tensor);
      float actual = actual_tensor.item<float>();
      ASSERT_EQ(actual, expected);
    }

    auto opt_graph = lastExecutedOptimizedGraph();
    testing::FileCheck()
        .check("(Tensor) = prim::CallFunction")
        ->run(*opt_graph);
  }
}

// TODO this test wasn't running and is broken.
// TEST(AutogradProfilerTest, Basic) {
//   constexpr int batch_size = 4;
//   constexpr int input_size = 256;
//   constexpr int seq_len = 32;

//   int hidden_size = 2 * input_size;
//   auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU);
//   auto hx = torch::randn({batch_size, hidden_size}, at::kCPU);
//   auto cx = torch::randn({batch_size, hidden_size}, at::kCPU);
//   auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU));
//   auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU));

//   std::stringstream ss;
//   {
//     RecordProfile guard(ss);
//     for (size_t i = 0; i < 100; ++i) {
//       std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
//     }
//   }

//   std::string result = ss.str();
//   size_t count = 0;
//   for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos;
//        count++, pos++) {
//   }
//   ASSERT_EQ(count, 200);
// }

// Registers an op returning None as `int?` and an op consuming `int?`, then
// checks constant propagation folds the pair without schema-matching errors.
TEST(NoneSchemaMatchTest, Basic) {
  RegisterOperators reg({
      Operator(
          "prim::test_none() -> int?",
          [](Stack& stack) { push(stack, IValue()); },
          aliasAnalysisFromSchema()),
      Operator(
          "prim::is_none(int? a) -> bool",
          [](Stack& stack) {
            bool is_none = pop(stack).isNone();
            push(stack, is_none);
          },
          aliasAnalysisFromSchema()),
  });

  // Constant propagation will run test_none and produce a None,
  // testing that its type is set appropriately and schema matching doesn't
  // fail when running is_none

  auto graph = std::make_shared<Graph>();
  auto maybe_int =
      graph->insert(Symbol::fromQualString("prim::test_none"), {});
  auto none_check =
      graph->insert(Symbol::fromQualString("prim::is_none"), {maybe_int});
  graph->registerOutput(none_check);
  ConstantPropagation(graph);

  // checking that constant propagation ran wo/failure: only the folded
  // constant should remain
  auto remaining = graph->block()->nodes();
  AT_ASSERT(std::distance(remaining.begin(), remaining.end()) == 1);
}

// Number of times fakePass has been invoked.
static int testPassValue = 0;
// Trivial custom graph pass: leaves the graph untouched and bumps the counter.
void fakePass(std::shared_ptr<Graph>& /* unused */) {
  ++testPassValue;
}

// Registers fakePass as a custom pass; PassManagementTest checks the counter.
RegisterPass p(fakePass);

// Runs an identity graph through a GraphExecutor and, outside simple mode,
// checks that the registered custom pass (fakePass) was invoked.
TEST(PassManagementTest, Basic) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%a):
  return (%a))IR",
      graph.get());

  std::vector<IValue> inputs = {IValue(torch::randn({22}, at::kCPU))};
  // Executes `g` on a private copy of `args` and returns the resulting stack.
  auto execute = [](std::shared_ptr<Graph>& g, std::vector<IValue> args) {
    GraphExecutor executor(g, "");
    executor.run(args);
    return args;
  };
  execute(graph, inputs);
  // we will not run fusion in simple mode
  if (!getExecutorMode()) {
    AT_ASSERT(testPassValue);
  }
}

// Asserts that `typ` is a TensorType whose concrete sizes equal `expected`.
// (`.value()` on the concrete sizes throws if they are not all known.)
static void checkShape(TypePtr typ, std::vector<int64_t> expected) {
  const auto tensor_type = typ->expect<TensorType>();
  const auto concrete = tensor_type->sizes().concrete_sizes();
  ASSERT_EQ(concrete.value(), expected);
}

// Checks the output type of `n` against `expected`; when `prev` is true the
// node producing n's first input (presumably a profiling node) is checked
// instead.
static void checkShape(
    Node* n,
    std::vector<int64_t> expected,
    bool prev = true) {
  Node* const source = prev ? n->inputs().at(0)->node() : n;
  checkShape(source->output()->type(), expected);
}

// Adds to `count` the number of nodes in `block` and, recursively, in all of
// its nested blocks for which `pred` returns true (pre-order traversal).
void count_(
    Block* block,
    const std::function<bool(Node* n)>& pred,
    size_t& count) {
  for (Node* node : block->nodes()) {
    count += pred(node) ? 1 : 0;
    for (Block* sub_block : node->blocks()) {
      count_(sub_block, pred, count);
    }
  }
}

size_t countNodes(
    const std::shared_ptr<Graph>& graph,
    const std::function<bool(Node* n)>& pred) {
  size_t count = 0;
  count_(graph->block(), pred, count);
  return count;
}

bool true_pred(Node* n) {
  return true;
};

// Returns true when `n` is a prim::Loop node.
bool is_loop(Node* n) {
  return n->kind() == prim::Loop;
}

TEST(LoopPeelerTest, NoInductionVariableUse) {
  // do not use an induction variable explicitly
  static const auto str_func_def = R"JIT(
    def test_peel_n_times():
      sum = 0
      for i in range(10):
        sum += 2
      return sum
    )JIT";

  auto cu = compile(str_func_def);
  auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
  auto stack = createStack({});

  // Returns a copy of f's graph with every loop peeled `iterations` times.
  auto peeled_copy = [&f](size_t iterations) {
    LoopsPeeler peeler(true_pred, iterations);
    auto copy = f.graph()->copy();
    peeler.run(copy);
    return copy;
  };
  // Number of loops directly in the graph's top-level block.
  auto top_level_loops = [](const std::shared_ptr<Graph>& graph) {
    return static_cast<int>(
        std::count_if(graph->nodes().begin(), graph->nodes().end(), is_loop));
  };
  // Interprets `graph` on the shared stack; returns the value left on top.
  auto run_on_stack = [&stack](const std::shared_ptr<Graph>& graph) {
    Code code(graph, "");
    InterpreterState interpreter{code};
    interpreter.run(stack);
    return stack.back().toInt();
  };

  // peeling loop once
  {
    auto copy = peeled_copy(1);
    ASSERT_EQ(top_level_loops(copy), 2);
    ASSERT_EQ(run_on_stack(copy), 20);
  }

  // test peeling more than one iteration
  {
    auto copy = peeled_copy(3);
    ASSERT_EQ(top_level_loops(copy), 2);
    ASSERT_EQ(run_on_stack(copy), 20);
  }
}

TEST(LoopPeelerTest, YesInductionVariableUse) {
  // uses the induction variable
  static const auto str_func_def = R"JIT(
    def test_peel_n_times():
      sum = 0
      for i in range(10):
        sum += i
      return sum
    )JIT";

  auto cu = compile(str_func_def);
  auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
  auto stack = createStack({});

  // Returns a copy of f's graph with every loop peeled `iterations` times.
  auto peeled_copy = [&f](size_t iterations) {
    LoopsPeeler peeler(true_pred, iterations);
    auto copy = f.graph()->copy();
    peeler.run(copy);
    return copy;
  };
  // Number of loops directly in the graph's top-level block.
  auto top_level_loops = [](const std::shared_ptr<Graph>& graph) {
    return static_cast<int>(
        std::count_if(graph->nodes().begin(), graph->nodes().end(), is_loop));
  };
  // Interprets `graph` on the shared stack; returns the value left on top.
  auto run_on_stack = [&stack](const std::shared_ptr<Graph>& graph) {
    Code code(graph, "");
    InterpreterState interpreter{code};
    interpreter.run(stack);
    return stack.back().toInt();
  };

  // peeling loop once
  {
    auto copy = peeled_copy(1);
    ASSERT_EQ(top_level_loops(copy), 2);
    ASSERT_EQ(run_on_stack(copy), 45);
  }

  // test peeling more than one iteration
  {
    auto copy = peeled_copy(3);
    ASSERT_EQ(top_level_loops(copy), 2);
    ASSERT_EQ(run_on_stack(copy), 45);
  }
}

TEST(LoopPeelerTest, LoopWithTerminationCondition) {
  // tests with explicit termination conditions
  static const auto str_func_def = R"JIT(
    def test_with_cond_times():
      sum = 0
      i = 0
      while (sum < 2):
        sum += i
        i += 1
      return sum
    )JIT";

  // the peel changes the termination condition to false
  // so the original loop doesn't run
  auto cu = compile(str_func_def);
  auto& f = toGraphFunction(cu->get_function("test_with_cond_times"));
  auto stack = createStack({});

  // Returns a copy of f's graph with every loop peeled `iterations` times.
  auto peeled_copy = [&f](size_t iterations) {
    LoopsPeeler peeler(true_pred, iterations);
    auto copy = f.graph()->copy();
    peeler.run(copy);
    return copy;
  };
  // Number of loops directly in the graph's top-level block.
  auto top_level_loops = [](const std::shared_ptr<Graph>& graph) {
    return static_cast<int>(
        std::count_if(graph->nodes().begin(), graph->nodes().end(), is_loop));
  };
  // Interprets `graph` on the shared stack; returns the value left on top.
  auto run_on_stack = [&stack](const std::shared_ptr<Graph>& graph) {
    Code code(graph, "");
    InterpreterState interpreter{code};
    interpreter.run(stack);
    return stack.back().toInt();
  };

  // peeling 5 iterations should update the termination
  // condition to false
  {
    auto copy = peeled_copy(5);
    ASSERT_EQ(top_level_loops(copy), 2);
    ASSERT_EQ(run_on_stack(copy), 3);
  }

  // the termination condition remains true
  {
    auto copy = peeled_copy(1);
    ASSERT_EQ(top_level_loops(copy), 2);
    ASSERT_EQ(run_on_stack(copy), 3);
  }
}

// tests simple nested loops
TEST(LoopPeelerTest, SimpleNestedLoops) {
  static const auto str_func_def = R"JIT(
    def test_nested_loops():
      sum = 0
      i = 0
      for i in range(10):
        for j in range(10):
          sum += i + j
      return sum
    )JIT";

  auto cu = compile(str_func_def);
  auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
  auto stack = createStack({});

  // Returns a copy of f's graph with every loop accepted by true_pred peeled
  // `iterations` times.
  auto peeled_copy = [&f](size_t iterations) {
    LoopsPeeler peeler(true_pred, iterations);
    auto copy = f.graph()->copy();
    peeler.run(copy);
    return copy;
  };
  // Interprets `graph` on the shared stack; returns the value left on top.
  auto run_on_stack = [&stack](const std::shared_ptr<Graph>& graph) {
    Code code(graph, "");
    InterpreterState interpreter{code};
    interpreter.run(stack);
    return stack.back().toInt();
  };

  // Loops are counted recursively here, including those in nested blocks.
  {
    auto copy = peeled_copy(1);
    ASSERT_EQ(countNodes(copy, is_loop), 5);
    ASSERT_EQ(run_on_stack(copy), 900);
  }

  {
    auto copy = peeled_copy(5);
    ASSERT_EQ(countNodes(copy, is_loop), 5);
    ASSERT_EQ(run_on_stack(copy), 900);
  }
}

TEST(LoopPeelerTest, SimpleNestedLoops2) {
  static const auto str_func_def = R"JIT(
    def test_nested_loops():
      sum = 0
      i = 0
      for i in range(10):
        j = 0
        while sum < 2:
          sum += i + j
          j += 1
      return sum
    )JIT";

  auto cu = compile(str_func_def);
  auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
  auto stack = createStack({});

  // Returns a copy of f's graph with every loop accepted by true_pred peeled
  // `iterations` times.
  auto peeled_copy = [&f](size_t iterations) {
    LoopsPeeler peeler(true_pred, iterations);
    auto copy = f.graph()->copy();
    peeler.run(copy);
    return copy;
  };
  // Interprets `graph` on the shared stack; returns the value left on top.
  auto run_on_stack = [&stack](const std::shared_ptr<Graph>& graph) {
    Code code(graph, "");
    InterpreterState interpreter{code};
    interpreter.run(stack);
    return stack.back().toInt();
  };

  // Loops are counted recursively here, including those in nested blocks.
  {
    auto copy = peeled_copy(1);
    ASSERT_EQ(countNodes(copy, is_loop), 5);
    ASSERT_EQ(run_on_stack(copy), 3);
  }

  {
    auto copy = peeled_copy(5);
    ASSERT_EQ(countNodes(copy, is_loop), 5);
    ASSERT_EQ(run_on_stack(copy), 3);
  }
}

// Traces the LSTM graph and checks that the traced graph has the expected
// input types and produces the same result as both the trace-time run and the
// original (scripted) graph.
TEST(JitTracing, Basic) {
  constexpr int batch_size = 4;
  constexpr int input_size = 256;

  int hidden_size = 2 * input_size;

  auto input = at::randn({batch_size, input_size}, at::kCPU);
  auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
  auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
  auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
  auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));

  auto graph = build_lstm();
  auto stack = createStack({input, hx, cx, w_ih, w_hh});
  auto traced = TraceGraph(graph, stack);

  // Check that the inputs of traced graph have the same type as the inputs
  // specified here.
  ASSERT_EQ(*traced->inputs().at(0)->type(), *TensorType::create(input));
  ASSERT_EQ(*traced->inputs().at(1)->type(), *TensorType::create(hx));
  ASSERT_EQ(*traced->inputs().at(2)->type(), *TensorType::create(cx));
  ASSERT_EQ(*traced->inputs().at(3)->type(), *TensorType::create(w_ih));
  ASSERT_EQ(*traced->inputs().at(4)->type(), *TensorType::create(w_hh));

  // Result left on the stack by the tracing run.
  Tensor prof_out;
  pop(stack, prof_out);

  {
    stack = createStack({input, hx, cx, w_ih, w_hh});
    Code cd(traced, "traced");
    InterpreterState is{cd};
    is.run(stack);
    Tensor traced_out;
    pop(stack, traced_out);
    // The comparison result used to be discarded, so a mismatch went
    // unnoticed.
    ASSERT_TRUE(torch::allclose(prof_out, traced_out));
  }

  {
    stack = createStack({input, hx, cx, w_ih, w_hh});
    Code cd(graph, "graph");
    InterpreterState is{cd};
    is.run(stack);
    Tensor scripted_out;
    pop(stack, scripted_out);
    ASSERT_TRUE(torch::allclose(prof_out, scripted_out));
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
  static const auto basic_example = R"JIT(
  def basic(x, y):
    a = x + y
    b = x * y
    c = x + 1
    d = a - c
    e = b - c
    return d + e
  )JIT";

  auto cu = compile(basic_example);
  auto& fun = toGraphFunction(cu->get_function("basic"));
  auto pr = ProfilingRecord::instrumentGraph(fun.graph());
  auto x = at::randn({2, 3}, at::kCPU);
  auto y = at::randn({2, 3}, at::kCPU);
  auto stack = createStack({x, y});
  // introduce some profiling information
  Code cd(pr->profiled_graph_, "");
  InterpreterState is{cd};
  is.run(stack);

  // Work on a copy of the profiled graph with the profiling counter removed.
  auto copy = pr->profiled_graph_->copy();
  ProfilingRecord::removeProfileCounter(copy->block());
  InsertGuards(copy);

  auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
  auto nodes = copy->block()->nodes();
  auto first_guard = std::find_if(nodes.begin(), nodes.end(), is_guard);
  ASSERT_NE(first_guard, nodes.end());
  // The guard's input has no known rank, while its own output carries the
  // profiled (2, 3) shape.
  ASSERT_EQ(
      first_guard->input()->type()->expectRef<TensorType>().sizes().size(),
      std::nullopt);
  checkShape(*first_guard, {2, 3}, false);
  int guards_before = std::count_if(nodes.begin(), nodes.end(), is_guard);
  ASSERT_EQ(guards_before, 12);

  // now eliminate as many guards as possible
  // we should be left with two guards on x and y's defs
  EliminateRedundantGuards(copy);
  int guards_after = std::count_if(nodes.begin(), nodes.end(), is_guard);
  ASSERT_EQ(guards_after, 2);
}

// After guard insertion/elimination, InsertBailOuts must turn each remaining
// guard into a prim::BailOut whose first input comes from a
// prim::BailoutTemplate.
TEST(InsertBailOutsTest, Basic) {
  static const auto basic_example = R"JIT(
  def basic_loop(x, y):

      a = x + 1
      b = y + 2
      c = x + y + 3

      for i in range(10):
          a = a + b
          # invariant
          d = b * c
          #
          a = a - d

      e = a + 4
      return e
  )JIT";

  auto cu = compile(basic_example);
  auto& fun = toGraphFunction(cu->get_function("basic_loop"));
  auto pr = ProfilingRecord::instrumentGraph(fun.graph());
  auto x = at::randn({2, 3}, at::kCPU);
  auto y = at::randn({2, 3}, at::kCPU);
  auto stack = createStack({x, y});
  // introduce some profiling information
  Code cd(pr->profiled_graph_, "");
  InterpreterState is{cd};
  is.run(stack);

  auto copy = pr->profiled_graph_->copy();
  ProfilingRecord::removeProfileCounter(copy->block());
  InsertGuards(copy);
  EliminateRedundantGuards(copy);

  auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
  auto is_bailout = [](Node* n) { return n->kind() == prim::BailOut; };
  auto nodes = copy->block()->nodes();
  auto guard_count = std::count_if(nodes.begin(), nodes.end(), is_guard);
  ASSERT_EQ(guard_count, 3);

  InsertBailOuts(copy);
  auto bailout_count = std::count_if(nodes.begin(), nodes.end(), is_bailout);
  ASSERT_EQ(guard_count, bailout_count);

  for (Node* n : nodes) {
    if (is_bailout(n)) {
      ASSERT_EQ(n->inputs().at(0)->node()->kind(), prim::BailoutTemplate);
    }
  }
}

// Profiles one run of the LSTM graph and checks that the prim::profile nodes
// feeding selected ops recorded the expected tensor shapes.
TEST(ProfilerTest, Basic) {
  constexpr int batch_size = 4;
  constexpr int input_size = 256;

  int hidden_size = 2 * input_size;

  auto input = at::randn({batch_size, input_size}, at::kCPU);
  auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
  auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
  auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
  auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));

  auto g = build_lstm();
  auto stack = createStack({input, hx, cx, w_ih, w_hh});

  auto& opt_graph = *g.get();
  ArgumentSpecCreator arg_spec_creator(opt_graph);
  ArgumentSpec spec =
      arg_spec_creator.create(autograd::GradMode::is_enabled(), stack);
  arg_spec_creator.specializeTypes(opt_graph, spec);
  auto pr = ProfilingRecord::instrumentGraph(g);
  Code cd(pr->profiled_graph_, "");
  InterpreterState is{cd};
  is.run(stack);

  // profiled types are stored as attributes and show up in the dump, e.g.
  // Tensor = prim::profile[profiled_type=Double(4, 256, strides=[256, 1],
  // requires_grad=0, device=cpu)
  testing::FileCheck()
      .check("Tensor = prim::profile[profiled_type")
      ->check_same("256")
      ->run(*pr->profiled_graph_);

  auto begin = pr->profiled_graph_->block()->nodes().begin();
  auto end = pr->profiled_graph_->block()->nodes().end();
  // The first aten::add's first input is expected to have shape
  // (batch_size, 4 * hidden_size) = (4, 2048), presumably the matmul output.
  auto add_n =
      std::find_if(begin, end, [](Node* n) { return n->kind() == aten::add; });
  ASSERT_NE(add_n, end);
  std::vector<int64_t> mm_expected{4, 2048};
  std::vector<int64_t> eltwise{4, 512};
  checkShape(add_n->inputs().at(0)->node()->ty(attr::profiled_type), mm_expected);
  auto mul_n =
      std::find_if(begin, end, [](Node* n) { return n->kind() == aten::mul; });
  ASSERT_NE(mul_n, end);
  checkShape(mul_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
  auto tanh_n =
      std::find_if(begin, end, [](Node* n) { return n->kind() == aten::tanh; });
  // Without this check, a missing tanh node would dereference `end` below
  // instead of failing the test.
  ASSERT_NE(tanh_n, end);
  checkShape(tanh_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
}

// Profiles an optional (`Tensor?`) input across two runs: first with a real
// bias tensor (shape recorded, seen_none == 0), then with None
// (seen_none == 1, previously recorded shape kept).
TEST(ProfilerTest, OptionalProfiling) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
graph(%inp : Tensor,
      %weight : Tensor,
      %bias : Tensor?):
  %1 : Tensor = aten::linear(%inp, %weight, %bias)
  return (%1))IR",
      &*graph,
      vmap);

  auto pr = ProfilingRecord::instrumentGraph(graph);
  pr->profiling_count_ = 2;

  // Locates the aten::linear node in the profiled graph.
  auto find_linear = [&pr]() {
    auto profiled_nodes = pr->profiled_graph_->block()->nodes();
    return std::find_if(
        profiled_nodes.begin(), profiled_nodes.end(), [](Node* n) {
          return n->kind() == aten::linear;
        });
  };

  auto input = torch::randn({1, 2});
  auto weight = torch::randn({2, 2});
  auto bias = torch::randn({1, 2});
  std::vector<int64_t> bias_expected_shape = {1, 2};

  // First run: a concrete bias tensor.
  auto stack = createStack({input, weight, bias});
  Code cd(pr->profiled_graph_, "");
  InterpreterState is{cd};
  is.run(stack);

  testing::FileCheck()
      .check_count("Tensor? = prim::profile[profiled_type", 1, true)
      ->run(*pr->profiled_graph_);

  // make sure we recorded the shape
  auto linear = find_linear();
  ASSERT_NE(linear, pr->profiled_graph_->block()->nodes().end());
  auto profiled_bias = linear->namedInput("bias")->node();
  checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
  ASSERT_EQ(0, profiled_bias->i(attr::seen_none));

  // Second run: bias is None.
  stack.clear();
  stack.emplace_back(input);
  stack.emplace_back(weight);
  stack.emplace_back(c10::IValue());
  is = InterpreterState{cd};
  is.run(stack);

  // make sure we recorded that "None" was seen.
  linear = find_linear();
  ASSERT_NE(linear, pr->profiled_graph_->block()->nodes().end());
  profiled_bias = linear->namedInput("bias")->node();
  checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
  ASSERT_EQ(1, profiled_bias->i(attr::seen_none));
}

TEST(CallStackTest, Basic) {
  const auto text = R"(
def ham(x):
    return x/7

def bar(x):
    return x*3

def baz(x):
    return ham(x)*x

def foo(x):
    return bar(x)*baz(x)*11
  )";
  auto cu = compile(text);

  // Address of a compiled function, as stored in callstack entries.
  auto function_ptr = [&cu](const char* name) {
    return &cu->get_function(name);
  };
  // True for prim::Constant nodes whose `value` attribute is an int.
  auto is_int_constant = [](Node* node) {
    return node->kind() == prim::Constant && node->hasAttribute(attr::value) &&
        node->kindOf(attr::value) == AttributeKind::i;
  };

  const auto& foo = toGraphFunction(cu->get_function("foo"));
  for (Node* node : foo.optimized_graph()->nodes()) {
    if (!is_int_constant(node)) {
      continue;
    }
    const int value = node->i(attr::value);
    if (value == 3) {
      // 3 comes from 'bar', inlined directly into 'foo': callstack is [bar].
      ASSERT_TRUE(node->callstack());
      auto frames = (*node->callstack())->vec();
      ASSERT_EQ(frames.size(), 1);
      ASSERT_EQ(std::get<0>(frames[0]), function_ptr("bar"));
    } else if (value == 7) {
      // 7 comes from 'ham', inlined into 'baz', which is inlined into 'foo':
      // callstack is [baz, ham].
      ASSERT_TRUE(node->callstack());
      auto frames = (*node->callstack())->vec();
      ASSERT_EQ(frames.size(), 2);
      ASSERT_EQ(std::get<0>(frames[0]), function_ptr("baz"));
      ASSERT_EQ(std::get<0>(frames[1]), function_ptr("ham"));
    } else if (value == 11) {
      // 11 belongs to 'foo' itself, so it has no callstack.
      ASSERT_FALSE(node->callstack());
    }
  }

  // Inlining 'baz' into 'foo' must not corrupt the callstacks in 'baz''s own
  // graph: there, 7 has a depth-1 callstack containing only 'ham'.
  const auto& baz = toGraphFunction(cu->get_function("baz"));
  for (Node* node : baz.optimized_graph()->nodes()) {
    if (!is_int_constant(node)) {
      continue;
    }
    const int value = node->i(attr::value);
    ASSERT_TRUE(value == 7);
    ASSERT_TRUE(node->callstack());
    auto frames = (*node->callstack())->vec();
    ASSERT_EQ(frames.size(), 1);
    ASSERT_EQ(std::get<0>(frames[0]), function_ptr("ham"));
  }
}

TEST(CallStackTest, Caching) {
  const auto text = R"(

def a(x):
    print("a1")
    print("a2")
    return x

def b(x):
    print("b1")
    print("b2")
    a(x)
    return x

def c(x):
    print("c1")
    print("c2")
    b(x)
    return x
  )";
  auto cu = compile(text);
  const auto& c_fn = toGraphFunction(cu->get_function("c"));

  // Maps each string constant that has a callstack to that callstack object.
  std::unordered_map<std::string, InlinedCallStack*> callstack_by_label;
  for (Node* node : c_fn.optimized_graph()->nodes()) {
    const bool is_string_constant = node->kind() == prim::Constant &&
        node->hasAttribute(attr::value) &&
        node->kindOf(attr::value) == AttributeKind::s;
    if (!is_string_constant || !node->callstack()) {
      continue;
    }
    callstack_by_label[node->s(attr::value)] = node->callstack()->get();
  }

  // "a1" and "a2" are both inlined into 'c' through the same chain (a->b->c),
  // so they must point at one shared callstack object, not two equal copies.
  ASSERT_TRUE(
      callstack_by_label.count("a1") && callstack_by_label.count("a2"));
  ASSERT_TRUE(callstack_by_label.at("a1") == callstack_by_label.at("a2"));
}

TEST(InlinedCallStackTest, BlockAnnotation) {
  Module a("A");
  a.define(R"(
    def forward(self, x, y, z: int):
      if (z == 1):
        return x + y
      else:
        return x * y
  )");
  Module b("B");
  b.define(R"(
    def forward(self, x):
      return x + 2
  )");
  Module c("C");
  c.register_module("A0", a);
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y, z: int):
      return self.A0.forward(x, y, z) + self.B0.forward(x)
  )");

  // Writes the source range of every inlined callstack frame of `node`,
  // followed by the node's own source range, into `out`.
  auto append_provenance = [](std::stringstream& out, Node* node) {
    for (const auto& frame : node->callstack().value()->vec()) {
      out << std::get<1>(frame);
    }
    out << node->sourceRange();
  };

  auto graph =
      toGraphFunction(c.get_method("forward").function()).optimized_graph();
  std::stringstream add_ss, mul_ss;
  for (Node* node : graph->nodes()) {
    if (node->kind() != prim::If) {
      continue;
    }
    for (Block* block : node->blocks()) {
      for (Node* inner : block->nodes()) {
        if (inner->kind() == aten::add) {
          append_provenance(add_ss, inner);
        } else if (inner->kind() == aten::mul) {
          append_provenance(mul_ss, inner);
        }
      }
    }
  }

  // Each branch must mention the call site in C and its own line in A.
  const std::string add_text = add_ss.str();
  const std::string mul_text = mul_ss.str();
  ASSERT_NE(add_text.find("line 3"), std::string::npos);
  ASSERT_NE(add_text.find("line 4"), std::string::npos);
  ASSERT_NE(
      add_text.find("return self.A0.forward(x, y, z)"), std::string::npos);
  ASSERT_NE(add_text.find("return x + y"), std::string::npos);
  ASSERT_NE(mul_text.find("line 3"), std::string::npos);
  ASSERT_NE(mul_text.find("line 6"), std::string::npos);
  ASSERT_NE(
      mul_text.find("return self.A0.forward(x, y, z)"), std::string::npos);
  ASSERT_NE(mul_text.find("return x * y"), std::string::npos);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(InlinedCallStackTest, SelfCallMethods) {
  Module a("A");
  a.define(R"(
    def my_new_method(self, x):
      return x * 3
    def forward_impl_(self, x, y):
      return self.my_new_method(x) + y
    def forward(self, x, y):
      y = y + 2
      return self.forward_impl_(x, y)
  )");
  Module b("B");
  b.define(R"(
    def forward(self, x):
      return x + 2
  )");
  Module c("C");
  c.register_module("A0", a);
  c.register_module("B0", b);
  c.define(R"(
    def call_b(self, x):
      return self.B0.forward(x)
    def forward(self, x, y):
      return self.A0.forward(x, y) + self.call_b(x)
  )");

  auto graph =
      toGraphFunction(c.get_method("forward").function()).optimized_graph();

  // Histogram of module-hierarchy strings over the nodes of the inlined graph.
  // operator[] value-initializes missing counts to 0.
  std::unordered_map<std::string, size_t> nodes_per_hierarchy;
  for (Node* node : graph->nodes()) {
    ++nodes_per_hierarchy[torch::jit::utils::getNodesModuleHierarchy(*node)];
  }

  // Calls through `self` are recorded as SELF(<type>) hops.
  ASSERT_EQ(nodes_per_hierarchy["A0(A)"], 2);
  ASSERT_EQ(nodes_per_hierarchy["A0(A).SELF(A).SELF(A)"], 2);
  ASSERT_EQ(nodes_per_hierarchy["A0(A).SELF(A)"], 1);
  ASSERT_EQ(nodes_per_hierarchy["SELF(C)"], 1);
  ASSERT_EQ(nodes_per_hierarchy["SELF(C).B0(B)"], 1);
}

TEST(AutogradSymbolsTest, Basic) {
  // canRunWithAutograd must accept aten:: and ordinary prim:: symbols, and
  // reject prim::FusionGroup and symbols from a custom namespace.
  // Uses gtest assertions (not TORCH_CHECK) so every case is checked and a
  // failure names the offending symbol instead of aborting via c10::Error.
  const std::vector<std::pair<const char*, bool>> cases = {
      {"aten::test_symbol", true},
      {"prim::test_symbol", true},
      {"prim::FusionGroup", false},
      {"custom::test_symbol", false},
  };
  Graph graph;
  for (const auto& [qual_name, expected] : cases) {
    Node* node = graph.create(Symbol::fromQualString(qual_name));
    EXPECT_EQ(canRunWithAutograd(node), expected) << qual_name;
  }
}

TEST(DefaultArgTypeHintingTest, Basic) {
  // `y=1` without a type hint is expected to be rejected by the compiler...
  const auto text_non_hinted = R"(

def a(x, y=1):
    print("a1")
    print("a2")
    return x
  )";

  // ...while the same default with an explicit `int` hint must compile.
  const auto text_hinted = R"(

def a(x, y:int=1):
    print("a1")
    print("a2")
    return x
  )";

  // ASSERT_THROW instead of try { ...; ASSERT_TRUE(0); } catch (...) {}: the
  // hand-written version could swallow gtest's own failure exception when
  // throw_on_failure is enabled.
  ASSERT_THROW(compile(text_non_hinted), std::exception);
  ASSERT_NO_THROW(compile(text_hinted));
}

// Basic set case.
TEST(FuturesTest, Basic) {
  auto future = c10::make_intrusive<Future>(IntType::get());
  // A fresh future is neither completed nor holding a value.
  ASSERT_FALSE(future->completed());
  ASSERT_FALSE(future->hasValue());

  int32_t early_calls = 0;
  int32_t late_calls = 0;

  // A callback registered before completion fires on markCompleted().
  future->addCallback([&](Future& /* unused */) { ++early_calls; });
  future->markCompleted(43);
  ASSERT_TRUE(future->completed());
  ASSERT_TRUE(future->hasValue());
  ASSERT_FALSE(future->hasError());
  ASSERT_EQ(early_calls, 1);
  ASSERT_EQ(future->constValue().toInt(), 43);
  ASSERT_EQ(future->value().toInt(), 43);

  // A callback added after completion runs right away, once. The earlier
  // callback is not re-run.
  future->addCallback([&](Future& /* unused */) { ++late_calls; });
  ASSERT_EQ(early_calls, 1);
  ASSERT_EQ(late_calls, 1);
}

// Sparse CUDA tensor test
TEST(FutureTest, SparseTensor) {
  // Skip test if CUDA is not available. The i == 1 iteration below builds a
  // sparse tensor with CUDA tensor options, so the test must actually return
  // here; previously it only logged and then carried on.
  bool has_cuda = at::globalContext().hasCUDA();
  if (!has_cuda) {
    LOG(INFO) << "CUDA not available, skipping test";
    return;
  }
  // Loop-invariant; only the i == 1 iteration uses it.
  const at::TensorOptions opts =
      at::TensorOptions().device(at::DeviceType::CUDA);
  for (int i = 0; i < 2; ++i) {
    auto f = c10::make_intrusive<Future>(TensorType::get());
    // i == 0: at::ones(10).to_sparse()
    // i == 1: a COO tensor built with the CUDA options above.
    auto sparse_tensor = i == 0 ? at::ones(10).to_sparse()
                                : at::sparse_coo_tensor(
                                      at::arange(10).unsqueeze(0).to(at::kLong),
                                      at::ones({10, 10}),
                                      opts);
    // Runs storage extraction for sparse CUDA tensors
    f->markCompleted(sparse_tensor);
    ASSERT_TRUE(f->completed());
    ASSERT_FALSE(f->hasError());
  }
}

// Basic error cases.
TEST(FuturesTest, Error) {
  auto future = c10::make_intrusive<Future>(IntType::get());
  int early_calls = 0;
  int late_calls = 0;

  // Completing with an error fires pending callbacks, just as a value would.
  future->addCallback([&](Future& /* unused */) { ++early_calls; });
  future->setError(
      std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
  ASSERT_EQ(early_calls, 1);
  ASSERT_TRUE(future->completed());
  ASSERT_TRUE(future->hasError());
  ASSERT_FALSE(future->hasValue());

  // value() rethrows the stored error verbatim.
  try {
    (void)future->value();
    ASSERT_TRUE(false); // Supposed to throw.
  } catch (const std::exception& e) {
    ASSERT_STREQ(e.what(), "Failed");
  }

  // Callbacks added after completion run once, immediately.
  future->addCallback([&](Future& /* unused */) { ++late_calls; });
  ASSERT_EQ(early_calls, 1);
  ASSERT_EQ(late_calls, 1);

  // setErrorIfNeeded() on an already-failed future keeps the first error and
  // does not re-run callbacks.
  future->setErrorIfNeeded(
      std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup")));
  ASSERT_STREQ(future->tryRetrieveErrorMessage().c_str(), "Failed");
  ASSERT_EQ(early_calls, 1);
  ASSERT_EQ(late_calls, 1);

  // constValue() also throws; its message must contain the original error.
  try {
    (void)future->constValue();
    ASSERT_TRUE(false); // Supposed to throw.
  } catch (const std::exception& e) {
    ASSERT_NE(std::string(e.what()).find("Failed"), std::string::npos);
  }
}

// then
TEST(FuturesTest, Then) {
  auto source = c10::make_intrusive<Future>(IntType::get());
  // Two chained continuations: add 1, then multiply by 3.
  auto plus_one = source->then(
      [](Future& parent) -> IValue { return parent.constValue().toInt() + 1; },
      IntType::get());
  auto times_three = plus_one->then(
      [](Future& parent) -> IValue { return parent.constValue().toInt() * 3; },
      IntType::get());

  bool done = false;
  times_three->addCallback([&done](Future& result) {
    ASSERT_EQ(result.constValue().toInt(), (42 + 1) * 3);
    done = true;
  });

  // Nothing runs before the source completes; afterwards the whole chain has
  // already run by the time markCompleted() returns.
  ASSERT_FALSE(done);
  source->markCompleted(42);
  ASSERT_TRUE(done);
}

// collectAll()
TEST(FuturesTest, CollectAll) {
  auto first = c10::make_intrusive<Future>(IntType::get());
  auto second = c10::make_intrusive<Future>(IntType::get());
  auto third = c10::make_intrusive<Future>(IntType::get());
  const auto int_future_type = FutureType::create(IntType::get());

  // True if the list held by a collectAll() result has element type
  // Future[int].
  auto has_int_future_elements = [&int_future_type](const auto& collected) {
    return *(collected->value().toList().elementType()) == *int_future_type;
  };
  // The int payload of the idx-th future inside a collectAll() result.
  auto nth_int = [](const auto& collected, size_t idx) {
    return collected->value().toList().get(idx).toFuture()->value().toInt();
  };

  // Empty input: completed immediately with an empty list.
  c10::List<intrusive_ptr<ivalue::Future>> futures(int_future_type);
  auto all_empty = collectAll(futures);
  ASSERT_TRUE(all_empty->completed());
  ASSERT_EQ(all_empty->value().toList().size(), 0);
  ASSERT_TRUE(has_int_future_elements(all_empty));

  // One input, not yet completed.
  futures.push_back(first);
  auto all_pending = collectAll(futures);
  ASSERT_FALSE(all_pending->completed());
  first->markCompleted(5);
  ASSERT_TRUE(all_pending->completed());
  ASSERT_EQ(all_pending->value().toList().size(), 1);
  ASSERT_TRUE(has_int_future_elements(all_pending));
  ASSERT_EQ(nth_int(all_pending, 0), 5);

  // One input, already completed.
  auto all_done = collectAll(futures);
  ASSERT_TRUE(all_done->completed());
  ASSERT_EQ(all_done->value().toList().size(), 1);
  ASSERT_EQ(nth_int(all_done, 0), 5);

  // Three inputs completed out of order; results keep input order.
  futures.push_back(second);
  futures.push_back(third);
  auto all_three = collectAll(futures);
  ASSERT_FALSE(all_three->completed());
  third->markCompleted(7);
  ASSERT_FALSE(all_three->completed());
  second->markCompleted(6);
  ASSERT_TRUE(all_three->completed());
  ASSERT_EQ(all_three->value().toList().size(), 3);
  ASSERT_EQ(nth_int(all_three, 0), 5);
  ASSERT_EQ(nth_int(all_three, 1), 6);
  ASSERT_EQ(nth_int(all_three, 2), 7);
  ASSERT_TRUE(has_int_future_elements(all_three));

  // A failing input propagates its error to the collected future.
  auto failing = c10::make_intrusive<Future>(IntType::get());
  futures.push_back(failing);
  auto all_with_error = collectAll(futures);
  ASSERT_FALSE(all_with_error->completed());
  failing->setError(
      std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
  ASSERT_TRUE(all_with_error->completed());
  try {
    all_with_error->value();
    ASSERT_TRUE(false); // supposed to throw
  } catch (const std::exception& e) {
    ASSERT_EQ(std::string(e.what()), "Failed");
  }
}

// collectAny()
TEST(FuturesTest, CollectAny) {
  auto first = c10::make_intrusive<Future>(IntType::get());

  // Empty input: collectAny() is completed immediately.
  c10::List<intrusive_ptr<ivalue::Future>> futures(
      FutureType::create(IntType::get()));
  auto any_empty = collectAny(futures);
  ASSERT_TRUE(any_empty->completed());

  // One pending input: completes with that input's value once it completes.
  futures.push_back(first);
  auto any_pending = collectAny(futures);
  ASSERT_FALSE(any_pending->completed());
  first->markCompleted(5);
  ASSERT_TRUE(any_pending->completed());
  ASSERT_TRUE(any_pending->value().isInt());
  ASSERT_EQ(any_pending->value().toInt(), 5);

  // One already-completed input: the result is completed right away.
  auto any_done = collectAny(futures);
  ASSERT_TRUE(any_done->completed());
  ASSERT_TRUE(any_done->value().isInt());
  ASSERT_EQ(any_done->value().toInt(), 5);

  // Two inputs: the first to complete wins; later completions don't change
  // the result.
  futures.clear();
  auto left = c10::make_intrusive<Future>(IntType::get());
  auto right = c10::make_intrusive<Future>(IntType::get());
  futures.push_back(left);
  futures.push_back(right);
  auto any_of_two = collectAny(futures);
  ASSERT_FALSE(any_of_two->completed());
  right->markCompleted(7);
  ASSERT_TRUE(any_of_two->completed());
  ASSERT_EQ(any_of_two->value().toInt(), 7);
  left->markCompleted(1);
  ASSERT_EQ(any_of_two->value().toInt(), 7);
}

TEST(TLSFutureCallbacksTest, Basic) {
  namespace prof = torch::autograd::profiler;

  // Turns on the legacy CPU profiler for the calling thread.
  auto enable_cpu_profiler = []() {
    prof::enableProfilerLegacy(
        prof::ProfilerConfig(prof::ProfilerState::CPU, false, false));
  };
  // Callback asserting that the profiler is enabled on the thread running it.
  auto expect_profiler_on = [](Future& /* unused */) {
    ASSERT_TRUE(prof::profilerEnabled());
  };

  // Case 1: a TLS-propagating callback, fired from another thread.
  {
    enable_cpu_profiler();
    auto fut = c10::make_intrusive<Future>(IntType::get());
    fut->addCallback(wrapPropagateTLSState(expect_profiler_on));
    std::thread worker([fut = std::move(fut)]() { fut->markCompleted(); });
    // Joining ensures every callback triggered by markCompleted() has run.
    worker.join();
    prof::disableProfilerLegacy();
  }
  // Case 2: the same check through then().
  {
    enable_cpu_profiler();
    auto fut = c10::make_intrusive<Future>(IntType::get());
    auto chained = fut->then(
        wrapPropagateTLSState([&expect_profiler_on](Future& parent) {
          expect_profiler_on(parent);
          return at::IValue(1);
        }),
        IntType::get());
    std::thread worker([fut = std::move(fut)]() { fut->markCompleted(); });
    worker.join();
    chained->wait();
    prof::disableProfilerLegacy();
  }
}

// Verifies that a callback wrapped with wrapPropagateTLSState sees the legacy
// profiler as enabled when it runs. It also checks that the callback can
// disable the profiler itself (consolidating results) and then find the ops it
// executed (aten::ones, aten::add) in the returned event lists.
// Scenario 1: the main thread disables the profiler (without consolidating)
// before the future is completed on another thread.
// Scenario 2: the future is completed inline on the main thread while the
// profiler is still enabled.
TEST(ProfilerDisableInCallbackTest, Basic) {
  // cb that verifies the profiler is enabled
  auto profilerEnabledCb = []() {
    ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
  };
  torch::autograd::profiler::enableProfilerLegacy(
      torch::autograd::profiler::ProfilerConfig(
          torch::autograd::profiler::ProfilerState::CPU, false, false));
  auto s1 = c10::make_intrusive<Future>(IntType::get());
  // wrapPropagateTLSState captures this thread's TLS state now, so the
  // callback runs with it wherever the future is completed.
  auto verifyProfilerCb =
      wrapPropagateTLSState([&profilerEnabledCb](Future& /* unused */) {
        // Ensure the profiler is still enabled in this thread.
        profilerEnabledCb();
        // Run a couple of ops so there is something to find in the results.
        auto t1 = torch::ones({2, 2});
        auto t2 = torch::ones({2, 2});
        torch::add(t1, t2);
        // Don't cleanup TLSState, and just consolidate.
        auto opts =
            torch::autograd::profiler::ProfilerDisableOptions(false, true);
        auto thread_event_lists =
            // NOLINTNEXTLINE(performance-move-const-arg)
            torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
        // Ensure that the events from this thread are still profiled and we
        // obtain the expected events in our consolidated list when calling
        // disableProfilerLegacy().
        bool found_ones = false;
        bool found_add = false;
        for (const auto& li : thread_event_lists) {
          for (const auto& evt : li) {
            if (strcmp(evt.name(), "aten::add") == 0) {
              found_add = true;
            } else if (strcmp(evt.name(), "aten::ones") == 0) {
              found_ones = true;
            }
          }
          // Stop scanning further lists once both ops have been seen.
          if (found_add && found_ones) {
            break;
          }
        }
        ASSERT_TRUE(found_ones);
        ASSERT_TRUE(found_add);
      });

  s1->addCallback(verifyProfilerCb);
  // Disable the profiler, but do not consolidate results in the main thread.
  auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
  // NOLINTNEXTLINE(performance-move-const-arg)
  torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
  // Complete the future on a separate thread; the join ensures the callback
  // (and its assertions) has finished before continuing.
  std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); });
  t.join();

  // Similar to above test, but verifies correctness in the case where
  // continuation runs on the main thread.
  torch::autograd::profiler::enableProfilerLegacy(
      torch::autograd::profiler::ProfilerConfig(
          torch::autograd::profiler::ProfilerState::CPU, false, false));
  // s1 was moved into the thread lambda above; give it a fresh future.
  s1 = c10::make_intrusive<Future>(IntType::get());
  s1->addCallback(verifyProfilerCb);
  // Runs callback inline
  s1->markCompleted(at::IValue(1));
  opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
  // NOLINTNEXTLINE(performance-move-const-arg)
  torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
}

TEST(RecordDebugHandles, Basic) {
  namespace prof = torch::autograd::profiler;

  // Kineto profiling of CPU activity on this thread.
  const std::set<prof::ActivityType> activities({prof::ActivityType::CPU});
  auto kineto_config = []() {
    return prof::ProfilerConfig(prof::ProfilerState::KINETO, false, false);
  };
  prof::prepareProfiler(kineto_config(), activities);
  prof::enableProfiler(kineto_config(), activities);

  {
    // Edge-scope record carrying debug handle 42.
    RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
    float numerator{5.9999}, denominator{2.1212};
    float ratio = numerator / denominator;
    (void)ratio;
  }
  {
    // User-scope record with no debug handle.
    RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
    float numerator{5.9999}, denominator{2.1212};
    float ratio = numerator / denominator;
    (void)ratio;
  }

  auto results = prof::disableProfiler();
  size_t matched{0};
  for (const auto& event : results->events()) {
    if (event.name() == "my_function") {
      ASSERT_EQ(event.debugHandle(), 42);
      ++matched;
    } else if (event.name() == "not_my_function") {
      // Records without a debug handle report -1.
      ASSERT_EQ(event.debugHandle(), -1);
      ++matched;
    }
  }
  ASSERT_EQ(matched, 2);
}

TEST(RecordDebugHandles, ScopedCallbacks) {
  // Phase 1: enableProfiler without a scope argument records ordinary ops.
  torch::autograd::profiler::prepareProfiler(
      torch::autograd::profiler::ProfilerConfig(
          torch::autograd::profiler::ProfilerState::KINETO, false, false),
      {torch::autograd::profiler::ActivityType::CPU});
  torch::autograd::profiler::enableProfiler(
      torch::autograd::profiler::ProfilerConfig(
          torch::autograd::profiler::ProfilerState::KINETO, false, false),
      {torch::autograd::profiler::ActivityType::CPU});

  {
    auto a = torch::rand({128, 128});
    auto b = torch::rand({128, 128});
    auto c = a + b;
  }
  auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
  ASSERT_FALSE(profiler_results_ptr->events().empty());

  // Phase 2: restricted to the LITE_INTERPRETER scope, the same ops produce
  // no events.
  torch::autograd::profiler::prepareProfiler(
      torch::autograd::profiler::ProfilerConfig(
          torch::autograd::profiler::ProfilerState::KINETO, false, false),
      {torch::autograd::profiler::ActivityType::CPU});
  torch::autograd::profiler::enableProfiler(
      torch::autograd::profiler::ProfilerConfig(
          torch::autograd::profiler::ProfilerState::KINETO, false, false),
      {torch::autograd::profiler::ActivityType::CPU},
      {at::RecordScope::LITE_INTERPRETER});
  {
    auto a = torch::rand({128, 128});
    auto b = torch::rand({128, 128});
    auto c = a + b;
  }
  profiler_results_ptr = torch::autograd::profiler::disableProfiler();
  ASSERT_TRUE(profiler_results_ptr->events().empty());

  // Phase 3: with the LITE_INTERPRETER filter, exactly one event is expected:
  // the edge scope "my_function" carrying debug handle 42.
  torch::autograd::profiler::prepareProfiler(
      torch::autograd::profiler::ProfilerConfig(
          torch::autograd::profiler::ProfilerState::KINETO, false, false),
      {torch::autograd::profiler::ActivityType::CPU});
  torch::autograd::profiler::enableProfiler(
      torch::autograd::profiler::ProfilerConfig(
          torch::autograd::profiler::ProfilerState::KINETO, false, false),
      {torch::autograd::profiler::ActivityType::CPU},
      {at::RecordScope::LITE_INTERPRETER});
  {
    RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
    auto a = torch::rand({128, 128});
    auto b = torch::rand({128, 128});
    auto c = a + b;
  }
  {
    RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
    auto a = torch::rand({128, 128});
    auto b = torch::rand({128, 128});
    auto c = a + b;
  }
  profiler_results_ptr = torch::autograd::profiler::disableProfiler();
  const auto& kineto_events = profiler_results_ptr->events();
  ASSERT_EQ(kineto_events.size(), 1);
  size_t my_function_events{0};
  for (const auto& e : kineto_events) {
    if (e.name() == "my_function") {
      ASSERT_EQ(e.debugHandle(), 42);
      my_function_events++;
    }
  }
  // Without this, the debug-handle check above passes vacuously when the single
  // recorded event is something other than "my_function".
  ASSERT_EQ(my_function_events, 1);
}

TEST(IValueKWargsTest, Basic) {
  const auto text = R"(
    def foo(a : int, b : int, c : int = 4):
      return a + 2*b + 3*c
  )";
  auto cu = compile(text);
  auto& foo = cu->get_function("foo");
  // `a` is positional, `b` is passed by keyword and `c` keeps its default:
  // 1 + 2 * 3 + 3 * 4 == 19.
  const auto result = foo({1}, {{"b", 3}});
  ASSERT_EQ(result.toInt(), 19);
}

TEST(ComputeFlopsTest, Basic) {
  std::unordered_map<std::string, c10::IValue> extra_args;

  // Flops reported for `op_name` given the current contents of extra_args.
  auto flops_of = [&extra_args](const char* op_name) {
    return torch::profiler::impl::computeFlops(std::string(op_name), extra_args);
  };
  // Stores an int-list argument (e.g. a size or stride) under `key`.
  auto set_ints = [&extra_args](
                      const char* key, const std::vector<int64_t>& values) {
    extra_args[key] = at::IValue(at::IntArrayRef(values));
  };

  // Test unknown operator
  ASSERT_EQ(flops_of("aten::unknown"), 0);

  // Test aten::conv2d
  extra_args.clear();
  set_ints("input_size", {4, 5, 6, 7});
  set_ints("weight_size", {3, 5, 2, 1});
  extra_args["groups"] = 1;
  set_ints("padding", {1, 0});
  set_ints("stride", {1, 1});
  set_ints("dilation", {0, 0});
  ASSERT_EQ(flops_of("aten::conv2d"), 13440);

  // Test aten::conv2d fail: a 3-element weight_size yields 0.
  set_ints("input_size", {4, 5, 6, 7});
  set_ints("weight_size", {4, 5, 6});
  ASSERT_EQ(flops_of("aten::conv2d"), 0);

  // Test aten::conv2d fail 2: a valid weight_size but zero strides yields 0.
  // (Previously weight_size was mistakenly set from input_size here.)
  set_ints("weight_size", {3, 5, 2, 1});
  set_ints("stride", {0, 0});
  ASSERT_EQ(flops_of("aten::conv2d"), 0);

  // Test aten::conv2d fail 3: only input_size is provided.
  extra_args.clear();
  set_ints("input_size", {4, 5, 6, 7});
  ASSERT_EQ(flops_of("aten::conv2d"), 0);

  // Test aten::mm
  extra_args.clear();
  set_ints("mat1_size", {3, 4, 5, 6});
  set_ints("mat2_size", {6, 5, 4, 3});
  ASSERT_EQ(flops_of("aten::mm"), 43200);

  // Test aten::addmm (same arguments as aten::mm)
  ASSERT_EQ(flops_of("aten::addmm"), 43200);

  // Test aten::bmm
  extra_args.clear();
  set_ints("mat1_size", {7, 5, 6});
  set_ints("mat2_size", {7, 6, 3});
  ASSERT_EQ(flops_of("aten::bmm"), 1260);

  // Test aten::baddbmm (same arguments as aten::bmm)
  ASSERT_EQ(flops_of("aten::baddbmm"), 1260);

  // Test aten::mm with no size arguments at all
  extra_args.clear();
  ASSERT_EQ(flops_of("aten::mm"), 0);

  // Test aten::add.Tensor
  extra_args.clear();
  set_ints("mat_size", {3, 4, 5, 6});
  ASSERT_EQ(flops_of("aten::add"), 360);

  // Test aten::mul.Tensor
  extra_args.clear();
  set_ints("mat_size", {3, 4, 5, 6});
  ASSERT_EQ(flops_of("aten::mul"), 360);
}

TEST(TestConstant, TensorGrad) {
  // A tensor that requires grad must not be inserted as a graph constant.
  auto graph = std::make_shared<Graph>();
  const IValue grad_tensor = torch::randn({3, 5}).requires_grad_(true);
  const auto inserted = tryInsertConstant(*graph, grad_tensor);
  ASSERT_FALSE(inserted.has_value());
}

TEST(TestMutation, Basic) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> value_map;
  parseIR(
      R"IR(
graph(%x.1 : Tensor):
  %2 : int = prim::Constant[value=1]()
  %9 : int = prim::Constant[value=4]()
  %x.3 : Tensor = aten::add(%x.1, %2, %2)
  %7 : Tensor = aten::add_(%x.3, %2, %2)
  %y.1 : Tensor = aten::add(%x.3, %9, %2)
  return (%y.1))IR",
      graph.get(),
      value_map);

  // A filter that rejects every node leaves the in-place aten::add_ alone...
  const auto reject_all = [](Node*) { return false; };
  RemoveTensorMutation(graph, reject_all);
  testing::FileCheck().check("aten::add_")->run(*graph);

  // ...while one that accepts every node lets the mutation be removed.
  const auto accept_all = [](Node*) { return true; };
  RemoveTensorMutation(graph, accept_all);
  testing::FileCheck().check_not("aten::add_")->run(*graph);
}

TEST(TestInplaceToFunctionalActivation, Basic) {
  const std::string ir = R"IR(
graph(%x.1 : Tensor):
  %2 : int = prim::Constant[value=1]()
  %x.3 : Tensor = aten::add(%x.1, %2, %2)
  %y : Tensor = aten::relu_(%x.3)
  return (%y))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> value_map;
  parseIR(ir, graph.get(), value_map);

  // The in-place aten::relu_ must be replaced by the functional aten::relu.
  InplaceToFunctionalActivation(graph);
  testing::FileCheck().check("aten::relu")->run(*graph);
  testing::FileCheck().check_not("aten::relu_")->run(*graph);
}

TEST(TestRegisterShapeOp, Basic) {
  std::unordered_map<std::string, Value*> vmap;

  // Shape function with no inputs that always returns the list [5, 5].
  auto shape_graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph():
  %2 : int = prim::Constant[value=5]()
  %3: int[] = prim::ListConstruct(%2, %2)
  return (%3))IR",
      shape_graph.get(),
      vmap);

  // Graph with a single prim::MakeTestTensor node whose output shape is
  // computed below.
  auto g2 = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph():
  %2 : Tensor = prim::MakeTestTensor()
  return (%2))IR",
      g2.get(),
      vmap);

  // Register the shape function for that node's schema, propagate shapes, and
  // expect the 5x5 shape to appear in the printed graph.
  const FunctionSchema& schema = g2->nodes().begin()->schema();
  torch::jit::RegisterShapeComputeGraphForSchema(schema, shape_graph);
  PropagateShapesOnGraph(g2);
  testing::FileCheck().check("5, 5")->run(*g2);
}

TEST(TestFunctionalToInplaceActivation, Basic) {
  const std::string ir = R"IR(
graph(%x.1 : Tensor):
  %2 : int = prim::Constant[value=1]()
  %x.3 : Tensor = aten::add(%x.1, %2, %2)
  %y : Tensor = aten::relu(%x.3)
  return (%y))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> value_map;
  parseIR(ir, graph.get(), value_map);

  // The functional aten::relu must be replaced by the in-place aten::relu_.
  FunctionalToInplaceActivation(graph);
  testing::FileCheck().check("aten::relu_")->run(*graph);
  testing::FileCheck().check_not("aten::relu(")->run(*graph);
}

TEST(TestFunctionExecutor, SimpleExecutorTest) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%x.1 : Tensor):
  %2 : int = prim::Constant[value=1]()
  %x.3 : Tensor = aten::add(%x.1, %2, %2)
  %y : Tensor = aten::relu(%x.3)
  return (%y))IR",
      &*graph);

  // Wraps `graph` in a GraphFunction using `mode` and runs it once on a random
  // 2x2x2 float tensor. The function is returned so that it stays alive while
  // the caller inspects the executed graph.
  auto run_once = [&graph](ExecutorExecutionMode mode) {
    auto func = std::make_unique<GraphFunction>(
        "name", graph, [](GraphFunction&) {}, mode);
    Stack stack = {
        at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat))};
    func->run(stack);
    return func;
  };

  // Profiling executor: the optimized graph contains prim::profile nodes.
  {
    auto func = run_once(ExecutorExecutionMode::PROFILING);
    auto g = lastExecutedOptimizedGraph();
    testing::FileCheck()
        .check("prim::profile")
        ->check("aten::add")
        ->check("aten::relu")
        ->run(*g);
  }
  // Simple executor: the graph has no prim::profile nodes.
  {
    auto func = run_once(ExecutorExecutionMode::SIMPLE);
    auto g = func->getDebugState().graph;
    testing::FileCheck()
        .check_not("prim::profile")
        ->check("aten::add")
        ->check("aten::relu")
        ->run(*g);
  }
}

TEST(TestFunctionExecutor, RunDecompositionTest) {
  // Executor for the aten::var decomposition; looked up once (function-local
  // static) and reused across iterations.
  static auto* var_decomposition = torch::jit::GetDecompositionExecutor(
      "aten::var(Tensor self, bool unbiased=True) -> Tensor");
  for (const bool unbiased : {true, false}) {
    auto input = at::rand({4, 4});
    Stack stack = {input, unbiased};
    var_decomposition->run(stack);
    const at::Tensor decomposed = pop(stack).toTensor();
    // The decomposition must agree with the eager implementation.
    ASSERT_TRUE(at::allclose(decomposed, input.var(unbiased)));
  }
}

TEST(TestShapeGraphLinting, Basic) {
  // Every registered shape compute graph must exist and pass the linter.
  for (const auto& schema : RegisteredShapeComputeSchemas()) {
    // arange does not actually support complex, leave as
    // union[int, float] for now
    if (schema->name() != "aten::arange") {
      auto shape_graph = shapeComputeGraphForSchema(*schema);
      TORCH_INTERNAL_ASSERT(shape_graph);
      LintShapeComputeGraph(schema, *shape_graph);
    }
  }
}

// TODO: move to test_kernel when global settings are explicit
// fusion parameters
class Composed : public ::testing::Test {
 public:
  // Clears the TensorExpr "must use LLVM on CPU" flag before each test in
  // this fixture.
  void SetUp() override {
    bool& must_use_llvm_on_cpu =
        torch::jit::tensorexpr::getTEMustUseLLVMOnCPU();
    must_use_llvm_on_cpu = false;
  }
};

TEST_F(Composed, ComposedOp) {
  // Scoped override of the CPU fuser switch. The previous setting is restored
  // in the destructor, so the override is undone even when a failing ASSERT_*
  // returns early from the test body.
  struct WithCPUFuser {
    explicit WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
      overrideCanFuseOnCPU(val);
    }

    ~WithCPUFuser() {
      overrideCanFuseOnCPU(cpuFuserEnabled);
    }

    WithCPUFuser(const WithCPUFuser&) = delete;
    WithCPUFuser& operator=(const WithCPUFuser&) = delete;

    bool cpuFuserEnabled;
  };

#ifdef TORCH_ENABLE_LLVM
  // Scoped override of the TensorExpr "must use LLVM on CPU" global flag,
  // restored on scope exit for the same early-return reason as WithCPUFuser.
  struct WithTEMustUseLLVMOnCPU {
    explicit WithTEMustUseLLVMOnCPU(bool val)
        : previous(torch::jit::tensorexpr::getTEMustUseLLVMOnCPU()) {
      torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = val;
    }

    ~WithTEMustUseLLVMOnCPU() {
      torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = previous;
    }

    WithTEMustUseLLVMOnCPU(const WithTEMustUseLLVMOnCPU&) = delete;
    WithTEMustUseLLVMOnCPU& operator=(const WithTEMustUseLLVMOnCPU&) = delete;

    bool previous;
  };

  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[1, 5], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %2)
        %4 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %3)
        return (%3, %4))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  // wrong input sizes so we hit the fallback path
  auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat))
               .transpose(0, 1);
  auto ref1 = a * (a * b);
  auto ref2 = a * ref1;

  // Declared in this order so teardown mirrors setup: the LLVM flag is
  // restored first, then the CPU fuser setting.
  WithCPUFuser cpu_fuser_guard(true);
  WithTEMustUseLLVMOnCPU llvm_on_cpu_guard(false);

  FuseTensorExprs(
      graph,
      /*min_group_size*/ 2,
      /*add_composed_op*/ true,
      /*fuse_to_dynamic_shapes*/ true);
  Code code(graph, "");
  InterpreterState interpreter{code};
  std::vector<IValue> stack = {a, b};
  interpreter.run(stack);
  at::Tensor out2 = pop(stack).toTensor();
  at::Tensor out1 = pop(stack).toTensor();
  ASSERT_TRUE(at::allclose(ref1, out1));
  ASSERT_TRUE(at::allclose(ref2, out2));

  // Second run: two extra tensors are pushed below the inputs. The assertions
  // after the run check that they were overwritten with the outputs.
  auto inp_1 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
  auto inp_2 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
  stack = {inp_1, inp_2, a, b};
  InterpreterState interpreter2{code};
  interpreter2.run(stack);
  out2 = pop(stack).toTensor();
  out1 = pop(stack).toTensor();
  ASSERT_TRUE(at::allclose(ref1, out1));
  ASSERT_TRUE(at::allclose(ref2, out2));
  // inp_1 is on the bottom of the stack, and corresponds
  // to the second output. inp_2 is on the top corresponds to first output
  ASSERT_TRUE(at::allclose(inp_1, ref2));
  ASSERT_TRUE(at::allclose(inp_2, ref1));
#endif
}

TEST(ConstantPropagation, CustomClassesCanBePropagated) {
  // Only built when QNNPACK is enabled (presumably required by the
  // quantized::linear_prepack kernel used below). Otherwise the body is empty
  // and the test trivially passes.
#ifdef USE_PYTORCH_QNNPACK
  const auto src = R"IR(
    graph():
        %none: NoneType = prim::Constant()
        %dim: int = prim::Constant[value=3]()
        %shape: int[] = prim::ListConstruct(%dim, %dim)
        %weight: Tensor = aten::ones(%shape, %none, %none, %none, %none)
        %scale: float = prim::Constant[value=1.]()
        %zero_point: int = prim::Constant[value=0]()
        %dtype: int = prim::Constant[value=12]()
        %weight_q: Tensor = aten::quantize_per_tensor(%weight, %scale, %zero_point, %dtype)
        %params: __torch__.torch.classes.quantized.LinearPackedParamsBase = quantized::linear_prepack(%weight_q, %none)
        return (%params)
  )IR";
  auto graph = std::make_shared<Graph>();
  // The value map is not needed here, so use the overload without it.
  parseIR(src, graph.get());

  // Every input to linear_prepack is derived from constants, and its result is
  // a custom-class object. Constant propagation is expected to evaluate the
  // call so that it no longer appears in the graph.
  ConstantPropagation(graph);

  testing::FileCheck().check_not("quantized::linear_prepack")->run(*graph);
#endif
}

} // namespace jit
} // namespace torch
