#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {
using namespace testing;

// Verifies that match filters passed to runOnGraph gate the c::ccc -> d::ddd
// rewrite: it is applied only when every supplied filter accepts the match.
TEST(SubgraphRewriterTest, FilterMatch) {
  auto graph = std::make_shared<Graph>();

  parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b : int = prim::Constant[value=1]()
  %c = c::ccc(%a, %b)
  return (%c))IR",
      graph.get());

  std::string pattern = R"IR(
graph(%a, %b):
  %c = c::ccc(%a, %b)
  return (%c))IR";

  // Filters receive `vmap` from the rewriter: it maps pattern value names
  // (e.g. "b") to pattern Values, and match.values_map maps those to the
  // matched graph Values. No separate parse of the pattern is needed here.
  auto b_is_constant = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_node = match_vmap.at(vmap.at("b"))->node();
    return b_node->kind() == prim::Constant;
  };

  // Builds a filter accepting the match only when the graph value bound to
  // pattern input %b folds to an int constant equal to `expected`.
  auto b_equals = [](int expected) {
    return [expected](
               const Match& match,
               const std::unordered_map<std::string, Value*>& vmap) {
      const auto& match_vmap = match.values_map;
      auto b_val = toIValue(match_vmap.at(vmap.at("b")));
      return b_val && b_val->isInt() && b_val->toInt() == expected;
    };
  };
  auto b_is_one = b_equals(1);
  auto b_is_two = b_equals(2);

  std::string replacement = R"IR(
graph(%a, %b):
  %d = d::ddd(%a, %b)
  return (%d))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);

  // b is constant, so the match will succeed
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, b_is_constant);
    FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g);
  }

  // b is constant and the value is one, the match will succeed
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, {b_is_constant, b_is_one});
    FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g);
  }

  // b is constant but the value is not two, the match will fail
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, {b_is_constant, b_is_two});
    FileCheck().check("c::ccc")->check_not("d::ddd")->run(*g);
  }
}

// Verifies that a filter rejecting the match leaves the original c::ccc node
// in place and produces no d::ddd replacement.
TEST(SubgraphRewriterTest, FilterNoMatch) {
  auto graph = std::make_shared<Graph>();
  // %b is annotated as int so the constant is well-typed, matching the other
  // tests in this file; the filter below only inspects the node kind.
  parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b : int = prim::Constant[value=1]()
  %c = c::ccc(%a, %b)
  return (%c))IR",
      graph.get());

  std::string pattern = R"IR(
graph(%a, %b):
  %c = c::ccc(%a, %b)
  return (%c))IR";

  // `vmap` is supplied by the rewriter (pattern value name -> pattern Value).
  // %b is bound to a prim::Constant, never a prim::Assign, so this filter
  // rejects the match and the rewrite is skipped.
  auto filter = [](const Match& match,
                   const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_node = match_vmap.at(vmap.at("b"))->node();
    return b_node->kind() == prim::Assign;
  };

  std::string replacement = R"IR(
graph(%a, %b):
  %d = d::ddd(%a, %b)
  return (%d))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  rewriter.runOnGraph(graph, filter);

  FileCheck().check("c::ccc")->check_not("d::ddd")->run(*graph);
}

// Exercises rewriting of patterns that return more than one value.
TEST(SubgraphRewriterTest, MultiOutput) {
  // Parses `source` into a fresh graph, registers the single
  // `pattern` -> `replacement` rule, and returns a rewritten *copy* of the
  // parsed graph.
  auto rewrite_copy = [](const std::string& source,
                         const std::string& pattern,
                         const std::string& replacement) {
    auto parsed = std::make_shared<Graph>();
    parseIR(source, parsed.get());

    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(pattern, replacement);

    auto rewritten = parsed->copy();
    rewriter.runOnGraph(rewritten);
    return rewritten;
  };

  // Basic case: both a::aaa/b::bbb pairs should be replaced by ab::ababab.
  {
    const std::string source = R"IR(
graph(%0, %1):
  %a1, %a2 = a::aaa(%0, %1)
  %b = b::bbb(%a1)
  %c = c::ccc(%b)

  %x1, %x2 = a::aaa(%c, %a2)
  %y = b::bbb(%x1)
  %z = d::ddd(%y)
  return (%z))IR";
    const std::string pattern = R"IR(
graph(%0, %1):
  %a1, %a2 = a::aaa(%0, %1)
  %b = b::bbb(%a1)
  return (%b, %a2))IR";
    const std::string replacement = R"IR(
graph(%a, %b):
  %x, %y = ab::ababab(%a, %b)
  return (%x, %y))IR";

    auto result = rewrite_copy(source, pattern, replacement);
    FileCheck().check("ab::ababab")->check("ab::ababab")->run(*result);
  }

  // Real-model-like layout: four aa/bb/dd chains sharing %k and %m, with extra
  // consumers (cc::ccc, ee::eee) of the bb::bbb outputs. At least two fused
  // ab::ababab nodes are expected.
  {
    const std::string source = R"IR(
    graph(%k, %m, %x1, %x2, %x3, %x4, %y1, %y2, %y3, %y4):
      %a1 = aa::aaa(%x1, %k)
      %b1_1, %b1_2 = bb::bbb(%y1, %a1)
      %a2 = aa::aaa(%x2, %k)
      %b2_1, %b2_2 = bb::bbb(%y2, %a2)
      %a3 = aa::aaa(%x3, %k)
      %b3_1, %b3_2 = bb::bbb(%y3, %a3)
      %a4 = aa::aaa(%x4, %k)
      %b4_1, %b4_2 = bb::bbb(%y4, %a4)
      %c = cc::ccc(%b4_1)
      %d1 = dd::ddd(%b1_2, %m)
      %e1 = ee::eee(%b1_1, %d1)
      %d2 = dd::ddd(%b2_2, %m)
      %e2 = ee::eee(%b2_1, %d2)
      %d3 = dd::ddd(%b3_2, %m)
      %e3 = ee::eee(%b3_1, %d3)
      %d4 = dd::ddd(%b4_2, %m)
      %e4 = ee::eee(%b4_1, %d4)
      return (%d1, %d2, %d3, %d4, %e1, %e2, %e3, %e4)
      )IR";
    const std::string pattern = R"IR(
    graph(%a, %b, %c, %d):
        %y0 = aa::aaa(%b, %c)
        %y1, %y2 = bb::bbb(%a, %y0)
        %y3 = dd::ddd(%y2, %d)
        return (%y3, %y1))IR";
    const std::string replacement = R"IR(
    graph(%a, %b, %c, %d):
      %x, %y = ab::ababab(%a, %b, %c, %d)
      return (%x, %y))IR";

    auto result = rewrite_copy(source, pattern, replacement);
    FileCheck().check("ab::ababab")->check("ab::ababab")->run(*result);
  }

  // Data dependencies must block this rewrite. The pattern output %b is used
  // by %e, and %e appears before the definition of the pattern input %c.
  // A single db::fused node that consumes %c cannot define %b ahead of %e.
  {
    const std::string source = R"IR(
    graph(%x, %y):
      %a = aa::aaa(%x)
      %b = bb::bbb(%a)
      %e = ee::eee(%b)
      %c = cc::ccc(%y)
      %d = dd::ddd(%b, %c)
      %f = ff::fff(%b, %d)
      return (%f)
      )IR";
    const std::string pattern = R"IR(
    graph(%a, %c):
        %b = bb::bbb(%a)
        %d = dd::ddd(%b, %c)
        return (%d, %b))IR";
    const std::string replacement = R"IR(
    graph(%a, %c):
      %d, %b = db::fused(%a, %c)
      return (%d, %b))IR";

    auto result = rewrite_copy(source, pattern, replacement);
    FileCheck().check_not("db::fused")->run(*result);
  }
}

// Verifies the type printed for the replacement's output. When the matched
// output carries shape info (Float(10, 20)), the d::ddd output shows it too.
// Without shape info, the d::ddd output is printed as a plain Tensor.
TEST(SubgraphRewriterTest, OutputType) {
  std::string pattern = R"IR(
graph(%a, %b):
  %c = c::ccc(%a, %b)
  return (%c))IR";

  // `vmap` is supplied by the rewriter (pattern value name -> pattern Value),
  // so the pattern does not need to be parsed separately here.
  auto b_is_constant = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_node = match_vmap.at(vmap.at("b"))->node();
    return b_node->kind() == prim::Constant;
  };

  std::string replacement = R"IR(
graph(%a, %b):
  %d = d::ddd(%a, %b)
  return (%d))IR";

  // A single rewriter is reused for both graphs below.
  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  {
    auto graph = std::make_shared<Graph>();

    parseIR(
        R"IR(
  graph(%0):
    %a : Float(10, 20) = a::aaa(%0)
    %b : int = prim::Constant[value=1]()
    %c : Float(10, 20) = c::ccc(%a, %b)
    return (%c))IR",
        graph.get());

    // The matched output has shape info, which the replacement output keeps.
    rewriter.runOnGraph(graph, b_is_constant);
    FileCheck()
        .check("Float(10, 20) = d::ddd")
        ->check_not("c::ccc")
        ->run(*graph);
  }
  {
    auto graph = std::make_shared<Graph>();

    parseIR(
        R"IR(
  graph(%0):
    %a = a::aaa(%0)
    %b : int = prim::Constant[value=1]()
    %c = c::ccc(%a, %b)
    return (%c))IR",
        graph.get());

    // The matched output has no shape info, so the replacement output is a
    // plain Tensor.
    rewriter.runOnGraph(graph, b_is_constant);
    FileCheck().check("Tensor = d::ddd")->check_not("c::ccc")->run(*graph);
  }
}

} // namespace jit
} // namespace torch
