#include <gtest/gtest.h>

#include "test/cpp/jit/test_utils.h"

#include <torch/csrc/jit/testing/file_check.h>
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"

namespace torch {
namespace jit {

// Collapses every node of the LSTM test graph into a single
// prim::DifferentiableGraph (visiting nodes backwards, then forwards), unmerges
// it again, and checks that after CSE the graph has the same node count as
// before the merge.
TEST(SubgraphUtilsTest, Basic) {
  auto graph = build_lstm();
  EliminateCommonSubexpression(graph);

  // Only the count is kept: the original Node* objects are destroyed once they
  // are merged into a subgraph, so storing the pointers would only leave
  // dangling references around for the second iteration.
  const size_t originalNodeCount =
      std::vector<Node*>(graph->nodes().begin(), graph->nodes().end()).size();

  for (bool reverse_iterate : {true, false}) {
    // Merge everything into a single subgraph. The first visited node seeds
    // the subgraph; every later node is merged into it.
    Node* subgraph = nullptr;
    auto it =
        reverse_iterate ? graph->nodes().rbegin() : graph->nodes().begin();
    auto end = reverse_iterate ? graph->nodes().rend() : graph->nodes().end();
    while (it != end) {
      if (subgraph == nullptr) {
        subgraph = SubgraphUtils::createSingletonSubgraph(
            *it, prim::DifferentiableGraph);
      } else {
        SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph);
      }
      // Continue from the node next to the subgraph in the iteration
      // direction. The loop condition is re-checked before `*it` is used, so
      // a graph with a single node never dereferences the end sentinel.
      it = reverse_iterate ? ++subgraph->reverseIterator()
                           : ++subgraph->iterator();
    }
    // An empty graph would leave nothing to unmerge; fail loudly instead of
    // handing a null pointer to unmergeSubgraph.
    ASSERT_NE(subgraph, nullptr);

    // Unmerge and compare with the original node count. CSE runs again here,
    // mirroring the CSE applied before originalNodeCount was taken.
    SubgraphUtils::unmergeSubgraph(subgraph);
    EliminateCommonSubexpression(graph);

    std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
    ASSERT_EQ(originalNodeCount, newNodes.size());
  }
}

// Splits a small graph into two adjacent prim::DifferentiableGraph nodes:
// everything before the first aten::tanh, and everything from that tanh to the
// end. It then merges one subgraph into the other (in both directions) and
// checks that the merged subgraph, and the graph after unmerging, still list
// the operators in their original order.
TEST(SubgraphUtilsTest, MergeSubgraphs) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> parse_map;
  parseIR(
      R"IR(
graph(%a : Tensor, %b : Tensor, %c : Tensor):
  %x : Tensor = aten::sigmoid(%a)
  %y : Tensor = aten::mul(%a, %b)
  %p : Tensor = aten::div(%c, %b)
  %q1 : Tensor = aten::mul(%p, %a)
  %q2 : Tensor = aten::tanh(%q1)
  %q3 : Tensor = aten::tanh(%q2)
  %q4 : Tensor = aten::tanh(%q3)
  %q5 : Tensor = aten::hardsigmoid(%q4)
  return (%x, %y, %q5))IR",
      &*graph,
      parse_map);

  // Lints `g` and checks that it contains the test graph's operators in their
  // original order. Loop-invariant, so it is built once, outside the loop.
  auto run_file_check = [](const std::shared_ptr<Graph>& g) {
    g->lint();
    testing::FileCheck()
        .check("aten::sigmoid")
        ->check("aten::mul")
        ->check("aten::div")
        ->check("aten::mul")
        ->check_count("aten::tanh", 3)
        ->check("aten::hardsigmoid")
        ->run(*g);
  };

  // Only the count is kept: the original Node* objects are destroyed by the
  // merges below, so the pointers themselves would dangle.
  const size_t originalNodeCount =
      std::vector<Node*>(graph->nodes().begin(), graph->nodes().end()).size();

  for (bool reverse_merge : {true, false}) {
    // Merge everything into two adjacent subgraphs. graph1 absorbs every node
    // up to (but excluding) the first aten::tanh.
    Node* graph1 = SubgraphUtils::createSingletonSubgraph(
        *graph->nodes().begin(), prim::DifferentiableGraph);
    for (Node* next = graph1->next(); next->kind() != aten::tanh;
         next = graph1->next()) {
      // Never walk onto (and try to merge) the block's return node if the
      // graph unexpectedly contains no aten::tanh.
      ASSERT_NE(next, graph->return_node());
      SubgraphUtils::mergeNodeIntoSubgraph(next, graph1);
    }

    // graph2 starts at that aten::tanh and absorbs every remaining node; the
    // return node marks the end of the block's node list.
    Node* graph2 = SubgraphUtils::createSingletonSubgraph(
        graph1->next(), prim::DifferentiableGraph);
    while (graph2->next() != graph->return_node()) {
      SubgraphUtils::mergeNodeIntoSubgraph(graph2->next(), graph2);
    }

    // Merge one subgraph into the other; `subgraph` is the surviving node.
    Node* const subgraph = reverse_merge ? graph1 : graph2;
    Node* const absorbed = reverse_merge ? graph2 : graph1;
    SubgraphUtils::mergeNodeIntoSubgraph(absorbed, subgraph);
    run_file_check(subgraph->g(attr::Subgraph));

    // Unmerge and compare with the original node listing.
    SubgraphUtils::unmergeSubgraph(subgraph);
    EliminateCommonSubexpression(graph);
    run_file_check(graph);

    std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
    ASSERT_EQ(originalNodeCount, newNodes.size());
  }
}

// Exercises SubgraphUtils::generateNameForGraph. With a length limit of 80 the
// expected name is the "graph" prefix followed by each node's op name in
// program order, joined by '_'. With a limit of 10 the result must not be
// longer than that full name.
TEST(SubgraphUtilsTest, GraphName) {
  const std::string ir_source = R"IR(
graph(%a : Tensor, %b : Tensor, %c : Tensor):
  %x : Tensor = aten::tanh(%a)
  %y : Tensor = aten::mul(%a, %b)
  %p : Tensor = aten::div(%c, %b)
  %q1 : Tensor = aten::mul(%p, %a)
  %q2 : Tensor = aten::tanh(%q1)
  %q3 : Tensor = aten::tanh(%q2)
  %q4 : Tensor = aten::tanh(%q3)
  %q5 : Tensor = aten::tanh(%q4)
  return (%x, %y, %q5))IR";

  auto g = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> value_by_name;
  parseIR(ir_source, g.get(), value_by_name);

  // Generous limit: the full, untruncated name is expected.
  const std::string expected_name =
      "graph_tanh_mul_div_mul_tanh_tanh_tanh_tanh";
  const std::string untruncated =
      SubgraphUtils::generateNameForGraph(g, 80, "graph");
  ASSERT_EQ(untruncated, expected_name);

  // Tight limit: the generated name must not exceed the full name's length.
  const std::string shortened =
      SubgraphUtils::generateNameForGraph(g, 10, "graph");
  ASSERT_LE(shortened.size(), expected_name.size());
}

} // namespace jit
} // namespace torch
