#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_undefinedness.h>
#include <torch/csrc/jit/runtime/custom_operator.h>

namespace torch {
namespace jit {

// Builds an interpreter Stack from `list`, moving every tensor into it.
// After the call, `list` holds moved-from tensors.
Stack createStack(std::vector<at::Tensor>&& list) {
  Stack result;
  result.reserve(list.size());
  for (auto& tensor : list) {
    result.emplace_back(std::move(tensor));
  }
  return result;
}

// gtest helper: fails the current test unless `a` and `b` have the same
// length and every pair of corresponding tensors has the same shape and
// passes allclose().
void assertAllClose(const tensor_list& a, const tensor_list& b) {
  ASSERT_EQ(a.size(), b.size());
  const size_t count = a.size();
  for (size_t i = 0; i != count; ++i) {
    // Check shape first so allclose() never sees mismatched sizes.
    ASSERT_TRUE(a[i].is_same_size(b[i]));
    ASSERT_TRUE(a[i].allclose(b[i]));
  }
}

// Runs `interp` with `inputs` pushed onto a fresh stack. Returns whatever the
// interpreter leaves on the stack, with each value converted to a Tensor.
std::vector<at::Tensor> run(
    InterpreterState& interp,
    const std::vector<at::Tensor>& inputs) {
  auto stack = fmap<IValue>(inputs);
  interp.run(stack);

  std::vector<at::Tensor> outputs;
  outputs.reserve(stack.size());
  for (const IValue& value : stack) {
    outputs.push_back(value.toTensor());
  }
  return outputs;
}

// Pops the tuple on top of `stack` and pushes its elements back in their
// original order, so that a tuple-returning graph's outputs appear flat.
static void unpackReturnTuple(Stack& stack) {
  auto tuple = pop(stack).toTuple();
  const auto& elems = tuple->elements();
  stack.insert(stack.end(), elems.begin(), elems.end());
}

// Executes the forward graph `grad_spec.f` on `tensors_in`, then the backward
// graph `grad_spec.df` on `tensor_grads_in` plus the values df captures.
// Returns {real outputs of f, outputs of df}.
// Note: ClearUndefinedness mutates grad_spec.df before it is compiled.
std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in) {
  // Converts every IValue on a stack into an at::Tensor.
  const auto to_tensors = [](const Stack& values) {
    return fmap(values, [](const IValue& v) { return v.toTensor(); });
  };

  ClearUndefinedness(grad_spec.df);
  Code forward_code{grad_spec.f, ""};
  Code backward_code{grad_spec.df, ""};
  InterpreterState forward_interp{forward_code};
  InterpreterState backward_interp{backward_code};

  // Forward pass: after run(), the stack holds f's outputs.
  Stack forward_stack = fmap<IValue>(tensors_in);
  forward_interp.run(forward_stack);

  // Backward inputs, in order: incoming gradients, then the forward inputs
  // selected by df_input_captured_inputs, then the forward outputs selected
  // by df_input_captured_outputs.
  Stack backward_stack(tensor_grads_in.begin(), tensor_grads_in.end());
  for (const auto idx : grad_spec.df_input_captured_inputs) {
    backward_stack.push_back(tensors_in[idx]);
  }
  for (const auto idx : grad_spec.df_input_captured_outputs) {
    backward_stack.push_back(forward_stack[idx]);
  }
  backward_interp.run(backward_stack);
  // df returns a tuple; flatten it so each gradient is its own stack entry.
  unpackReturnTuple(backward_stack);

  // Only the first f_real_outputs values on the forward stack are reported
  // as f's outputs; drop everything after them.
  forward_stack.erase(
      forward_stack.begin() + grad_spec.f_real_outputs, forward_stack.end());
  return std::make_pair(to_tensors(forward_stack), to_tensors(backward_stack));
}

std::shared_ptr<Graph> build_lstm() {
  const auto graph_string = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor,
          %2 : Tensor,
          %3 : Tensor,
          %4 : Tensor):
      %5 : Tensor = aten::mm(%0, %3)
      %6 : Tensor = aten::mm(%1, %4)
      %7 : int = prim::Constant[value=1]()
      %8 : Tensor = aten::add(%5, %6, %7)
      %9 : Tensor, %10 : Tensor, %11 : Tensor, %12 : Tensor = prim::ConstantChunk[chunks=4, dim=1](%8)
      %13 : Tensor = aten::sigmoid(%9)
      %14 : Tensor = aten::sigmoid(%12)
      %15 : Tensor = aten::tanh(%11)
      %16 : Tensor = aten::sigmoid(%10)
      %17 : Tensor = aten::mul(%16, %2)
      %18 : Tensor = aten::mul(%13, %15)
      %19 : int = prim::Constant[value=1]()
      %20 : Tensor = aten::add(%17, %18, %19)
      %21 : Tensor = aten::tanh(%20)
      %22 : Tensor = aten::mul(%14, %21)
      return (%22, %20))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  return g;
}

std::shared_ptr<Graph> build_mobile_export_analysis_graph() {
  // We use following two schemas for this graph:
  //   1. slice.Tensor(Tensor(a) self, int dim=0, int? start=None,
  //                   int? end=None, int step=1) -> Tensor(a)
  //   2. slice.str(str string, int? start=None, int? end=None,
  //                  int step=1) -> str
  // %3 and %4 use slice.Tensor while %5 use slice.str.
  // Since we can see %3 and %4 have the same last argument that is never used
  // (same as default value of schema), we know we can ignore that last arg. For
  // %5, we see that last three args are same as schema default, hence
  // unnecessary.

  const auto graph_string = R"IR(
    graph(%0 : Tensor):
      %1 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %20 : int = prim::Constant[value=0]()
      %21 : int = prim::Constant[value=9223372036854775807]()
      %22 : str = prim::Constant[value="value"]()
      %3 : Tensor  = aten::slice(%0, %1, %20, %2, %1)
      %4 : Tensor = aten::slice(%0, %2, %20, %21, %1)
      %5 : str = aten::slice(%22, %20, %21, %2)
      return (%3, %4, %5))IR";

  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  return g;
}

std::shared_ptr<Graph> build_mobile_export_with_out() {
  const auto graph_string = R"IR(
    graph(%x.1 : Tensor,
          %y.1 : Tensor):
      %8 : NoneType = prim::Constant()
      %6 : int = prim::Constant[value=1]()
      %7 : Tensor = aten::add(%x.1, %y.1, %6, %y.1)
      return (%8))IR";

  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  return g;
}

std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested() {
  // this is pretty much same test as build_mobile_export_analysis_graph(),
  // but some aten::slice operators are hidden under block statement to check
  // if we are correctly recursing all the nodes in graph.
  const auto graph_string = R"IR(
    graph(%0 : Tensor):
      %1 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %20 : int = prim::Constant[value=0]()
      %21 : int = prim::Constant[value=9223372036854775807]()
      %22 : str = prim::Constant[value="value"]()
      %3 : Tensor  = aten::slice(%0, %1, %20, %2, %1)
      %23 : bool = aten::Bool(%3)
      %c : Tensor = prim::If(%23)
        block0():
          %4 : Tensor = aten::slice(%0, %2, %20, %21, %1)
          %5 : str = aten::slice(%22, %20, %21, %2)
          %c.1 : Tensor = aten::slice(%0, %1, %20, %2, %1)
          -> (%c.1)
        block1():
          -> (%3)
      return (%3, %3))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  return g;
}

std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg() {
  const auto graph_string = R"IR(
    graph(%0 : Tensor):
      %1 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %3 : int = prim::Constant[value=3]()
      %4 : int[]  = prim::tolist(%1, %2)
      %5 : int[] = prim::tolist(%1, %2, %3)
      return (%4, %5))IR";

  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  return g;
}

std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const() {
  const auto graph_string = R"IR(
      graph(%input.1 : Tensor):
        %7 : int = prim::Constant[value=1]() # <string>:3:58
        %9 : int = prim::Constant[value=0]() # <string>:3:66
        %8 : int[] = prim::ListConstruct(%7, %7)
        %10 : int[] = prim::ListConstruct(%9, %9)
        %11 : int[] = prim::ListConstruct(%7, %7)
        %12 : Tensor = aten::conv2d(%input.1, %input.1, %input.1, %8, %10, %11, %7)
        return (%12))IR";

  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  return g;
}

at::Tensor t_use(at::Tensor x) {
  return x;
}
at::Tensor t_def(at::Tensor x) {
  return x.t();
}

// Relative-tolerance check. Returns true when the largest absolute entry of
// `diff` is below 2e-6 times the largest absolute entry over all `inputs`.
// An exactly-zero `diff` always passes, even if every input is zero (where
// the relative tolerance itself would be zero).
//
// Values are read back as double so that double tensors are not narrowed to
// float. Narrowing could overflow to inf or underflow a small nonzero
// difference to 0.
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
  double maxValue = 0.0;
  for (const auto& tensor : inputs) {
    // fmax ignores a NaN operand, so a NaN input does not poison maxValue.
    maxValue = fmax(tensor.abs().max().item<double>(), maxValue);
  }
  const double maxDiff = diff.abs().max().item<double>();
  // Without the exact-zero case, all-zero inputs give `0 < 0`, which is false.
  return maxDiff == 0.0 || maxDiff < 2e-6 * maxValue;
}

// Returns true when `a` and `b` have the same shape and their elementwise
// difference passes checkRtol against both tensors.
bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
  // `a - b` broadcasts, so tensors with different but broadcast-compatible
  // shapes would otherwise be reported as close.
  if (!a.is_same_size(b)) {
    return false;
  }
  // checkRtol reduces with max(), which cannot reduce an empty tensor. Two
  // empty tensors of the same shape are trivially equal.
  if (a.numel() == 0) {
    return true;
  }
  return checkRtol(a - b, {a, b});
}

// Returns true when `a` and `b` have the same shape and every element of
// `a - b` is exactly zero. NaN entries never compare equal.
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {
  // `a - b` broadcasts, so tensors with different shapes could otherwise
  // compare equal.
  if (!a.is_same_size(b)) {
    return false;
  }
  // max() cannot reduce an empty tensor. Two empty tensors of the same shape
  // are trivially equal.
  if (a.numel() == 0) {
    return true;
  }
  // Read back as double: converting through float would underflow a tiny
  // nonzero double difference to 0 and falsely report equality.
  return (a - b).abs().max().item<double>() == 0.0;
}

// Returns true when `a` and `b` have the same length and each pair of
// corresponding tensors satisfies the single-tensor exactlyEqual overload.
bool exactlyEqual(
    const std::vector<at::Tensor>& a,
    const std::vector<at::Tensor>& b) {
  if (a.size() != b.size()) {
    return false;
  }
  // Walk both vectors in lockstep, stopping at the first mismatch.
  auto rhs = b.begin();
  for (const auto& lhs : a) {
    if (!exactlyEqual(lhs, *rhs)) {
      return false;
    }
    ++rhs;
  }
  return true;
}

// Eager reference implementation of one LSTM cell. Returns {hy, cy}.
// The combined gate pre-activations are split into 4 chunks along dim 1, in
// the order: input, forget, cell, output.
std::pair<at::Tensor, at::Tensor> lstm(
    at::Tensor input,
    at::Tensor hx,
    at::Tensor cx,
    at::Tensor w_ih,
    at::Tensor w_hh) {
  const auto preact = input.mm(t_use(w_ih)) + hx.mm(t_use(w_hh));
  const auto parts = preact.chunk(/*chunks=*/4, /*dim=*/1);

  const auto in_gate = parts[0].sigmoid();
  const auto forget_gate = parts[1].sigmoid();
  const auto cell_gate = parts[2].tanh();
  const auto out_gate = parts[3].sigmoid();

  auto cell_state = (forget_gate * cx) + (in_gate * cell_gate);
  auto hidden_state = out_gate * cell_state.tanh();
  return {hidden_state, cell_state};
}

inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}

namespace {
// Registers prim::MakeTestTensor for use in JIT analysis and transformation
// pass unit tests, which often need Values of type Tensor. The operator
// always pushes a default-constructed (undefined) at::Tensor, so it should
// not be used in graphs that are actually executed.
RegisterOperators make_test_tensor_registration({
    Operator(
        "prim::MakeTestTensor() -> Tensor",
        [](Stack& s) { push(s, at::Tensor()); },
        aliasAnalysisFromSchema()),
});
} // namespace

// Interprets `graph` on `inputs` and returns its first output as tensors.
// Two output shapes are supported:
//   * a list of Tensors, which is returned as-is;
//   * a single Tensor, which is wrapped in a one-element vector.
// Any other output type trips an internal assert.
std::vector<at::Tensor> runGraph(
    std::shared_ptr<Graph> graph,
    const std::vector<at::Tensor>& inputs) {
  auto stack = fmap<IValue>(inputs);
  Code code(graph, "test");
  InterpreterState(code).run(stack);
  TORCH_INTERNAL_ASSERT(!stack.empty());

  if (!stack.front().isTensorList()) {
    TORCH_INTERNAL_ASSERT(stack.front().isTensor());
    return {stack.front().toTensor()};
  }
  return stack.front().toTensorVector();
}

} // namespace jit
} // namespace torch
