#include <gtest/gtest.h>
#include "test/cpp/tensorexpr/test_base.h"

#include "test/cpp/tensorexpr/test_utils.h"
#include "torch/csrc/jit/tensorexpr/ir_simplifier.h"
#include "torch/csrc/jit/tensorexpr/registerizer.h"

#include <iostream>

namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;

// Can replace a simple scalar access with a local variable.
TEST(Registerizer, RegisterizerSimple) {
  BufHandle bufA("A", {1}, kInt);
  VarHandle xVar("x", kInt);

  // Build:
  //   A[0] = 0;
  //   for (int x = 0; x < 10; x++) {
  //     A[0] = (A[0]) + x;
  //   }
  StmtPtr init = Store::make(bufA, {0}, 0);
  StmtPtr accumulate =
      Store::make(bufA, {0}, Add::make(Load::make(bufA, {0}), xVar));
  StmtPtr loop = For::make(xVar, 0, 10, Block::make({accumulate}));
  StmtPtr stmt = Block::make({init, loop});

  stmt = registerize(stmt);

  // Every access targets the single fixed element A[0], so it is promoted to
  // the scalar A_1: initialized before the loop, updated inside it, and
  // written back to A[0] once the loop finishes.
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK:   A_1 =
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

// Won't replace an access whose index depends on the loop variable.
TEST(Registerizer, RegisterizerLoop) {
  BufHandle bufA("A", {10}, kInt);
  VarHandle xVar("x", kInt);

  // Build:
  //   A[0] = 0;
  //   for (int x = 0; x < 10; x++) {
  //     A[x] = (A[x]) + x;
  //   }
  StmtPtr init = Store::make(bufA, {0}, 0);
  StmtPtr accumulate =
      Store::make(bufA, {xVar}, Add::make(Load::make(bufA, {xVar}), xVar));
  StmtPtr loop = For::make(xVar, 0, 10, Block::make({accumulate}));
  StmtPtr stmt = Block::make({init, loop});

  // The loop body touches a different element on every iteration, so nothing
  // is promoted to a scalar: the program is expected to come back unchanged.
  stmt = registerize(stmt);

  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK-NOT: int
# CHECK: A[0] = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A_
# CHECK:   A[x] =
# CHECK-NOT: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

// Won't replace even if the load is a fixed scalar, since the store could
// invalidate it.
TEST(Registerizer, RegisterizerLoopFixedLoad) {
  BufHandle bufA("A", {1}, kInt);
  VarHandle xVar("x", kInt);

  // Build:
  //   A[0] = 0;
  //   for (int x = 0; x < 10; x++) {
  //     A[x] = (A[0]) + x;
  //   }
  StmtPtr init = Store::make(bufA, {0}, 0);
  StmtPtr scatter =
      Store::make(bufA, {xVar}, Add::make(Load::make(bufA, {0}), xVar));
  StmtPtr loop = For::make(xVar, 0, 10, Block::make({scatter}));
  StmtPtr stmt = Block::make({init, loop});

  // The store to A[x] writes A[0] when x == 0, so caching the A[0] load in a
  // scalar would be unsafe: the program is expected to come back unchanged.
  stmt = registerize(stmt);

  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK-NOT: int
# CHECK: A[0] = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A_
# CHECK:   A[x] =
# CHECK-NOT: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

// We can registerize accesses that occur entirely within inner scopes, even if
// they depend on the loop var.
TEST(Registerizer, RegisterizerLoopInternal) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(a, {x}, Add::make(Load::make(a, {x}), x)),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[x]) + x;
   *   A[x] = (A[x]) + x;
   * }
   */

  stmt = registerize(stmt);

  // TODO: the order of terms in addition changes and in general depends on
  // some hash value. This results in unpredictable swaps of the operands from
  // random changes, which is not great. Ideally, we should ensure some
  // specific order (ideally, the original one).
  //
  // Expected result (matches the CHECK pattern below): A[x] is promoted to a
  // scalar that lives only within one iteration of the loop body.
  /*
   * for (int x = 0; x < 10; x++) {
   *   int A_1 = A[x];
   *   A_1 = A_1 + x;
   *   A_1 = A_1 + x;
   *   A[x] = A_1;
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int x = 0; x < 10; x++)
# CHECK: int A_1 = A[x];
# CHECK:   A_1 = A_1 + x;
# CHECK:   A_1 = A_1 + x;
# CHECK:   A[x] = A_1;
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// An access can be overlapped by another read in the same Expr. In this case
// B[z] and B[y] overlap and prevent registerization of both accesses.
TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) {
  BufHandle bufA("A", {10}, kInt);
  BufHandle bufB("B", {10}, kInt);
  VarHandle xVar("x", kInt);
  VarHandle yVar("y", kInt);
  VarHandle zVar("z", kInt);

  // Build (then simplify):
  //   for (int x = 0; x < 10; x++) {
  //     A[x] = (B[y]) + (B[z]);
  //   }
  ExprHandle sum =
      Add::make(Load::make(bufB, {yVar}), Load::make(bufB, {zVar}));
  StmtPtr loop = For::make(xVar, 0, 10, Store::make(bufA, {xVar}, sum));
  StmtPtr stmt = IRSimplifier::simplify(Block::make({loop}));

  // Prints a statement to a string so two versions can be compared.
  auto render = [](const StmtPtr& s) {
    std::ostringstream os;
    os << *s;
    return os.str();
  };

  // B[y] and B[z] may name the same element, so neither load is registerized
  // and the printed program must be identical before and after.
  const std::string before = render(stmt);
  stmt = registerize(stmt);
  ASSERT_EQ(before, render(stmt));
}

TEST(Registerizer, RegisterizerLoopInternalRepeated) {
  BufHandle bufA("A", {1}, kInt);
  VarHandle xVar("x", kInt);

  // Each call builds a fresh copy of:
  //   for (int x = 0; x < 10; x++) {
  //     A[0] = (A[1]) + x;
  //     A[0] = (A[1]) + x;
  //   }
  auto makeLoop = [&]() -> StmtPtr {
    return For::make(
        xVar,
        0,
        10,
        Block::make(
            {Store::make(bufA, {0}, Add::make(Load::make(bufA, {1}), xVar)),
             Store::make(bufA, {0}, Add::make(Load::make(bufA, {1}), xVar))}));
  };
  StmtPtr firstLoop = makeLoop();
  StmtPtr secondLoop = makeLoop();
  StmtPtr stmt = Block::make({firstLoop, secondLoop});

  stmt = registerize(stmt);

  // A[1] is only read and A[0] is only written, both at fixed indices, so each
  // becomes a scalar spanning both loops: A_1 caches A[1], and A_2 holds A[0]
  // and is written back once after the second loop.
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int A_1 = A[1];
# CHECK: int A_2 = A[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK:   A_2 = A_1 + x;
# CHECK:   A_2 = A_1 + x;
# CHECK: }
# CHECK: for (int x = 0; x < 10; x++)
# CHECK:   A_2 = A_1 + x;
# CHECK:   A_2 = A_1 + x;
# CHECK: }
# CHECK-NOT: A[1]
# CHECK: A[0] = A_2;
# CHECK-NOT: A[1]
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) {
  BufHandle bufA("A", {1}, kInt);
  VarHandle xVar("x", kInt);

  // Each call builds a fresh copy of:
  //   for (int x = 0; x < 10; x++) {
  //     A[0] = (A[x]) + x;
  //     A[0] = (A[x]) + x;
  //   }
  auto makeLoop = [&]() -> StmtPtr {
    return For::make(
        xVar,
        0,
        10,
        Block::make(
            {Store::make(bufA, {0}, Add::make(Load::make(bufA, {xVar}), xVar)),
             Store::make(
                 bufA, {0}, Add::make(Load::make(bufA, {xVar}), xVar))}));
  };
  StmtPtr firstLoop = makeLoop();
  StmtPtr secondLoop = makeLoop();
  StmtPtr stmt = IRSimplifier::simplify(Block::make({firstLoop, secondLoop}));

  // Prints a statement to a string so two versions can be compared.
  auto render = [](const StmtPtr& s) {
    std::ostringstream os;
    os << *s;
    return os.str();
  };

  // A[x] aliases A[0] when x == 0, so nothing is registerized and the printed
  // program must be identical before and after.
  const std::string before = render(stmt);
  stmt = registerize(stmt);
  ASSERT_EQ(before, render(stmt));
}

TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  StmtPtr stmt = IRSimplifier::simplify(Block::make(
      {For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),
                Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),
                Store::make(a, {0}, Add::make(x, Load::make(a, {y})))}))

      }));

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[y]);
   *   A[0] = x + (A[y]);
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[y]);
   *   A[0] = x + (A[y]);
   * }
   *
   * (The simplifier may print the operands of + in a different order.)
   */

  std::ostringstream before;
  before << *stmt;

  // No change: y is unknown, so the load A[y] may alias the stored A[0] and
  // neither access can be registerized.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Will registerize multiple accesses of different items of the same buffer.
TEST(Registerizer, RegisterizerMultiVar) {
  BufHandle a("A", {2}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make({
      Store::make(a, {0}, 0),
      Store::make(a, {1}, 0),
      For::make(
          x,
          0,
          10,
          Block::make(
              {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
               Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
  });

  /*
   * A[0] = 0;
   * A[1] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   *   A[1] = (A[1]) - x;
   * }
   */

  stmt = registerize(stmt);

  // Expected result (consistent with the CHECK pattern below): A_1 holds A[0]
  // and A_2 holds A[1]; they are written back in reverse order. The operand
  // order of the addition may vary (see the TODO in RegisterizerLoopInternal).
  /*
   * int A_1 = 0;
   * int A_2 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   *   A_2 = A_2 - x;
   * }
   * A[1] = A_2;
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: int A_2 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK:   A_1 =
# CHECK:   A_2 =
# CHECK: A[1] = A_2
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Will registerize the valid accesses while skipping invalid replacements.
TEST(Registerizer, RegisterizerVariableLoad) {
  BufHandle bufA("A", {1}, kInt);
  BufHandle bufB("B", {10}, kInt);
  // Two distinct loop variables that share the name "x"; the second one is
  // printed as x_1.
  VarHandle fillVar("x", kInt);
  VarHandle sumVar("x", kInt);

  // Build:
  //   A[0] = 0;
  //   for (int x = 0; x < 10; x++) {
  //     B[x] = x;
  //   }
  //   for (int x_1 = 0; x_1 < 10; x_1++) {
  //     A[0] = (A[0]) + (B[x_1]);
  //   }
  StmtPtr init = Store::make(bufA, {0}, 0);
  StmtPtr fillLoop =
      For::make(fillVar, 0, 10, Store::make(bufB, {fillVar}, fillVar));
  ExprHandle sum =
      Add::make(Load::make(bufA, {0}), Load::make(bufB, {sumVar}));
  StmtPtr sumLoop =
      For::make(sumVar, 0, 10, Block::make({Store::make(bufA, {0}, sum)}));
  StmtPtr stmt = Block::make({init, fillLoop, sumLoop});

  stmt = registerize(stmt);

  // A[0] becomes the scalar A_1; the B accesses are indexed by loop variables
  // and remain buffer accesses.
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK:   B[x] = x
# CHECK: for (int x_1 = 0; x_1 < 10; x_1++)
# CHECK-NOT: A[
# CHECK:   A_1 =
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

// Can registerize variable accesses so long as the variable does not change.
TEST(Registerizer, RegisterizerSymbolicIndices) {
  VarHandle idx("i", kInt);
  VarHandle len("N", kInt);
  BufHandle bufA("A", {len}, kInt);
  VarHandle xVar("x", kInt);

  // Build:
  //   A[i] = 0;
  //   for (int x = 0; x < 10; x++) {
  //     A[i] = (A[i]) + x;
  //   }
  StmtPtr init = Store::make(bufA, {idx}, 0);
  StmtPtr accumulate =
      Store::make(bufA, {idx}, Add::make(Load::make(bufA, {idx}), xVar));
  StmtPtr loop = For::make(xVar, 0, 10, Block::make({accumulate}));
  StmtPtr stmt = Block::make({init, loop});

  stmt = registerize(stmt);

  // i is never assigned here, so A[i] names the same element throughout and is
  // promoted to the scalar A_1, which is written back to A[i] after the loop.
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK:   A_1 =
# CHECK: A[i] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

// Can registerize accesses dependent on multiple loop vars.
TEST(Registerizer, RegisterizerMultiLoop) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 0),
       For::make(
           x,
           0,
           10,
           For::make(
               y,
               0,
               10,
               Block::make({Store::make(
                   a,
                   {0},
                   Mul::make(Add::make(Load::make(a, {0}), x), y))})))});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   for (int y = 0; y < 10; y++) {
   *     A[0] = ((A[0]) + x) * y;
   *   }
   * }
   */

  stmt = registerize(stmt);

  // Expected result. Only "A_1 =" is checked inside the loop; the exact
  // printed form of the right-hand side may be rewritten or reordered.
  /*
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   for (int y = 0; y < 10; y++) {
   *     A_1 = (A_1 + x) * y;
   *   }
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK:   for (int y = 0; y < 10; y++)
# CHECK-NOT: A[
# CHECK:     A_1 =
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize correctly if scalars already exist in the program.
TEST(Registerizer, RegisterizerRepeated) {
  BufHandle a("A", {2}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make({
      Store::make(a, {0}, 0),
      Store::make(a, {1}, 0),
      For::make(
          x,
          0,
          10,
          Block::make(
              {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
               Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
  });

  /*
   * A[0] = 0;
   * A[1] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   *   A[1] = (A[1]) - x;
   * }
   */

  // Registerize manually to make sure we only replace a single target.
  {
    registerizer::RegisterizerAnalysis analysis;
    stmt->accept(&analysis);
    auto candidates = analysis.getCandidates();
    // Unsigned literal: size() is unsigned, and comparing it with a signed int
    // inside ASSERT_EQ triggers -Wsign-compare.
    ASSERT_EQ(candidates.size(), 2u);

    // Drop the last candidate so only the remaining one is replaced now.
    candidates.pop_back();
    registerizer::RegisterizerReplacer replacer(candidates);
    stmt = stmt->accept_mutator(&replacer);
  }

  // Re-analyze and replace the second target. Its scalar must get a name that
  // does not clash with the one introduced above (A_1_1 vs A_1).
  {
    registerizer::RegisterizerAnalysis analysis;
    stmt->accept(&analysis);
    auto candidates = analysis.getCandidates();
    ASSERT_EQ(candidates.size(), 1u);

    registerizer::RegisterizerReplacer replacer(candidates);
    stmt = stmt->accept_mutator(&replacer);
  }

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: int A_1_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK:   A_1 =
# CHECK:   A_1_1 =
# CHECK: A[1] = A_1_1;
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize A even when it is only ever stored to, never loaded.
TEST(Registerizer, RegisterizerNoLoads) {
  BufHandle bufA("A", {1}, kInt);
  VarHandle xVar("x", kInt);

  // Build:
  //   A[0] = 0;
  //   for (int x = 0; x < 10; x++) {
  //     A[0] = x + 1;
  //   }
  StmtPtr init = Store::make(bufA, {0}, 0);
  StmtPtr overwrite = Store::make(bufA, {0}, Add::make(xVar, 1));
  StmtPtr loop = For::make(xVar, 0, 10, Block::make({overwrite}));
  StmtPtr stmt = Block::make({init, loop});

  stmt = registerize(stmt);

  // A[0] is promoted to A_1 even though it is never read; it is written back
  // once after the loop.
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK:   A_1 =
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

// Can registerize the load of A but not the store of B.
TEST(Registerizer, RegisterizerNoRepeatedStores) {
  BufHandle bufA("A", {1}, kInt);
  BufHandle bufB("B", {10}, kInt);
  VarHandle xVar("x", kInt);

  // Build:
  //   A[0] = 0;
  //   for (int x = 0; x < 10; x++) {
  //     B[x] = (A[0]) + x;
  //   }
  StmtPtr init = Store::make(bufA, {0}, 0);
  StmtPtr scatter =
      Store::make(bufB, {xVar}, Add::make(Load::make(bufA, {0}), xVar));
  StmtPtr loop = For::make(xVar, 0, 10, Block::make({scatter}));
  StmtPtr stmt = Block::make({init, loop});

  stmt = registerize(stmt);

  // A[0] is promoted to A_1; each B[x] is stored only once, so B stays a
  // buffer access.
  // TODO: it's unnecessary to reorder the initializer of A[0], but it's not
  // actually worse, so let's not worry about it for now.
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A_
# CHECK:   B[x] =
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

// Won't registerize if there are multiple accesses which may overlap.
TEST(Registerizer, RegisterizerMultiVarOverlap) {
  BufHandle bufA("A", {2}, kInt);
  VarHandle xVar("x", kInt);

  // Build (then simplify):
  //   A[0] = 0;
  //   A[1] = 0;
  //   for (int x = 0; x < 10; x++) {
  //     A[x] = (A[0]) + x;
  //     A[x + 1] = (A[1]) - x;
  //   }
  StmtPtr initFirst = Store::make(bufA, {0}, 0);
  StmtPtr initSecond = Store::make(bufA, {1}, 0);
  StmtPtr addStore =
      Store::make(bufA, {xVar}, Add::make(Load::make(bufA, {0}), xVar));
  StmtPtr subStore =
      Store::make(bufA, {xVar + 1}, Sub::make(Load::make(bufA, {1}), xVar));
  StmtPtr loop = For::make(xVar, 0, 10, Block::make({addStore, subStore}));
  StmtPtr stmt =
      IRSimplifier::simplify(Block::make({initFirst, initSecond, loop}));

  // Prints a statement to a string so two versions can be compared.
  auto render = [](const StmtPtr& s) {
    std::ostringstream os;
    os << *s;
    return os.str();
  };

  // The stores to A[x] and A[x + 1] may hit A[0] or A[1], so nothing is
  // registerized and the printed program must be identical.
  const std::string before = render(stmt);
  stmt = registerize(stmt);
  ASSERT_EQ(before, render(stmt));
}

TEST(Registerizer, RegisterizerAllocs) {
  BufHandle bufA("A", {2}, kInt);
  BufHandle bufC("C", {1}, kInt);
  VarHandle xVar("x", kInt);

  // B's single dimension is the value loaded from C[0].
  BufHandle bufB("B", {Load::make(bufC, {0})}, kInt);

  // Build:
  //   Allocate(B, int, {C[0]});
  //   A[0] = C[0];
  //   B[0] = 0;
  //   for (int x = 0; x < 10; x++) {
  //     B[0] = (B[0]) + x;
  //     A[0] = C[0];
  //   }
  //   Free(B);
  StmtPtr alloc = Allocate::make(bufB);
  StmtPtr copyA = Store::make(bufA, {0}, Load::make(bufC, {0}));
  StmtPtr initB = Store::make(bufB, {0}, 0);
  StmtPtr body = Block::make(
      {Store::make(bufB, {0}, Add::make(Load::make(bufB, {0}), xVar)),
       Store::make(bufA, {0}, Load::make(bufC, {0}))});
  StmtPtr loop = For::make(xVar, 0, 10, body);
  StmtPtr release = Free::make(bufB);
  StmtPtr stmt = Block::make({alloc, copyA, initB, loop, release});

  stmt = registerize(stmt);

  // C[0] is hoisted into C_1 ahead of the Allocate, A[0] and B[0] become
  // the scalars A_1 and B_1, and both are written back before B is freed.
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int C_1 = C[0];
# CHECK: Allocate(B
# CHECK: int A_1 = C_1;
# CHECK: int B_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK:   B_1 =
# CHECK:   A_1 = C_
# CHECK: B[0] = B_1;
# CHECK: A[0] = A_1;
# CHECK: Free(B)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

TEST(Registerizer, RegisterizerNoInitializer) {
  BufHandle bufA("A", {1}, kInt);
  VarHandle xVar("x", kInt);

  // Build:
  //   for (int x = 0; x < 10; x++) {
  //     A[0] = (A[0]) + x;
  //   }
  StmtPtr accumulate =
      Store::make(bufA, {0}, Add::make(Load::make(bufA, {0}), xVar));
  StmtPtr loop = For::make(xVar, 0, 10, Block::make({accumulate}));
  StmtPtr stmt = Block::make({loop});

  stmt = registerize(stmt);

  // With no prior store to reuse, the scalar is initialized by loading A[0]
  // before the loop and written back after it.
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK:   A_1 =
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

TEST(Registerizer, RegisterizerNoInitializerLoopVar) {
  BufHandle bufA("A", {1}, kInt);
  VarHandle xVar("x", kInt);

  // Build (then simplify):
  //   for (int x = 0; x < 10; x++) {
  //     A[x] = (A[x]) + x;
  //   }
  StmtPtr update =
      Store::make(bufA, {xVar}, Add::make(Load::make(bufA, {xVar}), xVar));
  StmtPtr loop = For::make(xVar, 0, 10, Block::make({update}));
  StmtPtr stmt = IRSimplifier::simplify(Block::make({loop}));

  // Prints a statement to a string so two versions can be compared.
  auto render = [](const StmtPtr& s) {
    std::ostringstream os;
    os << *s;
    return os.str();
  };

  // Expected: no change.
  const std::string before = render(stmt);
  stmt = registerize(stmt);
  ASSERT_EQ(before, render(stmt));
}

TEST(Registerizer, RegisterizerLoadThenStore) {
  BufHandle bufA("A", {1}, kInt);
  BufHandle bufB("B", {1}, kInt);
  VarHandle xVar("x", kInt);

  // Build:
  //   for (int x = 0; x < 10; x++) {
  //     B[0] = (A[0]) + x;
  //     A[0] = B[0];
  //   }
  StmtPtr body = Block::make(
      {Store::make(bufB, {0}, Add::make(Load::make(bufA, {0}), xVar)),
       Store::make(bufA, {0}, Load::make(bufB, {0}))});
  StmtPtr stmt = Block::make({For::make(xVar, 0, 10, body)});

  stmt = registerize(stmt);

  // Both A[0] and B[0] become scalars, each initialized by a load before the
  // loop and written back after it.
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: int B_1 = B[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: B[
# CHECK:   B_1 =
# CHECK-NOT: A[
# CHECK:   A_1 = B_
# CHECK: B[0] = B_
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

TEST(Registerizer, RegisterizerParallelized) {
  BufHandle bufA("A", {1}, kInt);
  VarHandle xVar("x", kInt);

  // Bind the loop to GPU block index 0.
  LoopOptions gpuBlockOpts;
  gpuBlockOpts.set_gpu_block_index(0);

  // Build:
  //   A[0] = 0;
  //   for (int x = 0; x < 10; x++) {  // bound to GPU block index 0
  //     A[0] = (A[0]) + x;
  //   }
  StmtPtr init = Store::make(bufA, {0}, 0);
  StmtPtr accumulate =
      Store::make(bufA, {0}, Add::make(Load::make(bufA, {0}), xVar));
  StmtPtr loop =
      For::make(xVar, 0, 10, Block::make({accumulate}), gpuBlockOpts);
  StmtPtr stmt = Block::make({init, loop});

  // The registerizer must reject a program that still has a loop bound to a
  // GPU index.
  ASSERT_THROWS_WITH(
      registerize(stmt),
      "Registerization must occur after parallelism flattening");
}

// Should be able to registerize this since the scalar would exist before the
// branch.
TEST(Registerizer, RegisterizerConditionAfter) {
  BufHandle bufA("A", {5}, kInt);
  BufHandle bufB("B", {5}, kInt);
  BufHandle bufC("C", {5}, kInt);
  VarHandle xVar("x", kInt);

  // Build:
  //   A[x] = B[x];
  //   C[x] = A[x];
  //   if (x<5 ? 1 : 0) {
  //     A[x] = (A[x]) + 1;
  //   }
  StmtPtr copyIn = Store::make(bufA, {xVar}, Load::make(bufB, {xVar}));
  StmtPtr copyOut = Store::make(bufC, {xVar}, Load::make(bufA, {xVar}));
  ExprHandle inBounds =
      CompareSelect::make(xVar, 5, CompareSelectOperation::kLT);
  StmtPtr increment =
      Store::make(bufA, {xVar}, Add::make(Load::make(bufA, {xVar}), 1));
  StmtPtr guarded = Cond::make(inBounds, increment, nullptr);
  StmtPtr stmt = Block::make({copyIn, copyOut, guarded});

  stmt = registerize(stmt);

  // The scalar A_1 is introduced by the unconditional first store, so it
  // already exists when the branch uses it.
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int A_1 = B[x];
# CHECK: C[x] = A_1;
# CHECK: if (
# CHECK:   A_1 = A_1 + 1;
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

// Should be able to registerize this since the scalar exists in the same form
// after the branch and there is no overlap.
TEST(Registerizer, RegisterizerConditionBefore) {
  BufHandle bufA("A", {5}, kInt);
  BufHandle bufB("B", {5}, kInt);
  BufHandle bufC("C", {5}, kInt);
  VarHandle xVar("x", kInt);

  // Build:
  //   if (x<5 ? 1 : 0) {
  //     A[x] = (A[x]) + 1;
  //   }
  //   A[x] = B[x];
  //   C[x] = A[x];
  ExprHandle inBounds =
      CompareSelect::make(xVar, 5, CompareSelectOperation::kLT);
  StmtPtr increment =
      Store::make(bufA, {xVar}, Add::make(Load::make(bufA, {xVar}), 1));
  StmtPtr guarded = Cond::make(inBounds, increment, nullptr);
  StmtPtr copyIn = Store::make(bufA, {xVar}, Load::make(bufB, {xVar}));
  StmtPtr copyOut = Store::make(bufC, {xVar}, Load::make(bufA, {xVar}));
  StmtPtr stmt = Block::make({guarded, copyIn, copyOut});

  stmt = registerize(stmt);

  // Expected:
  //   int A_1 = A[x];
  //   if (x<5 ? 1 : 0) {
  //     A_1 = A_1 + 1;
  //   }
  //   A_1 = B[x];
  //   C[x] = A_1;
  //   A[x] = A_1;
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if (
# CHECK:   A_1 = A_1 + 1;
# CHECK: }
# CHECK: A_1 = B[x];
# CHECK: C[x] = A_1;
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

// Should be able to registerize this as the combination of the two above rules.
TEST(Registerizer, RegisterizerConditionInside) {
  BufHandle bufA("A", {5}, kInt);
  BufHandle bufB("B", {5}, kInt);
  BufHandle bufC("C", {5}, kInt);
  VarHandle xVar("x", kInt);

  // Build:
  //   A[x] = B[x];
  //   C[x] = A[x];
  //   if (x<5 ? 1 : 0) {
  //     A[x] = (A[x]) + 1;
  //   }
  //   B[x] = A[x];
  //   A[x] = C[x];
  StmtPtr copyIn = Store::make(bufA, {xVar}, Load::make(bufB, {xVar}));
  StmtPtr copyOut = Store::make(bufC, {xVar}, Load::make(bufA, {xVar}));
  ExprHandle inBounds =
      CompareSelect::make(xVar, 5, CompareSelectOperation::kLT);
  StmtPtr increment =
      Store::make(bufA, {xVar}, Add::make(Load::make(bufA, {xVar}), 1));
  StmtPtr guarded = Cond::make(inBounds, increment, nullptr);
  StmtPtr writeBack = Store::make(bufB, {xVar}, Load::make(bufA, {xVar}));
  StmtPtr restore = Store::make(bufA, {xVar}, Load::make(bufC, {xVar}));
  StmtPtr stmt =
      Block::make({copyIn, copyOut, guarded, writeBack, restore});

  stmt = registerize(stmt);

  // Expected:
  //   int A_1 = B[x];
  //   C[x] = A_1;
  //   if (x<5 ? 1 : 0) {
  //     A_1 = A_1 + 1;
  //   }
  //   B[x] = A_1;
  //   A_1 = C[x];
  //   A[x] = A_1;
  std::ostringstream printed;
  printed << *stmt;

  const std::string pattern =
      R"IR(
# CHECK: int A_1 = B[x];
# CHECK: C[x] = A_1;
# CHECK: if (
# CHECK:   A_1 = A_1 + 1;
# CHECK: }
# CHECK: B[x] = A_1;
# CHECK: A_1 = C[x];
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(pattern, printed.str());
}

// An example where an access is cut by an overlapping access inside a
// condition, and both sides are large enough to be registerized but cannot be
// because there is no safe place to put the initializer or finalizer.
TEST(Registerizer, RegisterizerConditionInsideOverlap1) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  StmtPtr stmt = Block::make(
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
      {Store::make(a, {x}, Load::make(b, {x})),
       Store::make(c, {x}, Load::make(a, {x})),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
               Store::make(a, {0}, 3),
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           }),
           nullptr),
       Store::make(b, {x}, Load::make(a, {x})),
       Store::make(a, {x}, Load::make(c, {x}))});

  /*
   * A[x] = B[x];
   * C[x] = A[x];
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   *   A[0] = 3;
   *   A[x] = (A[x]) + 1;
   * }
   * B[x] = A[x];
   * A[x] = C[x];
   */

  // The A[0] store may overlap A[x], cutting the region that could be
  // registerized into two groups.
  // Each group has 2 loads and 2 stores, so either could be registerized on
  // its own, but the first group would need to be finalized inside the
  // condition block and the second would need to be initialized inside it.
  // There is no safe place to put these that is visible to the other uses in
  // the group, so neither registerization is possible.

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Same as the above, but the access groups before the condition and after the
// condition are large enough to be registerized without needing the accesses
// inside the condition. Registerization occurs but does not include any
// accesses in the condition; the first group must be finalized before the
// Cond, and the second initialized after it.
TEST(Registerizer, RegisterizerConditionInsideOverlap2) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  StmtPtr stmt = Block::make(
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
      {Store::make(a, {x}, Load::make(b, {x})),
       Store::make(a, {x}, Load::make(b, {x + 1})),
       Store::make(c, {x}, Load::make(a, {x})),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
               Store::make(a, {0}, 3),
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           }),
           nullptr),
       Store::make(b, {x}, Load::make(a, {x})),
       Store::make(b, {x + 1}, Load::make(a, {x})),
       Store::make(a, {x}, Load::make(c, {x}))});

  /*
   * A[x] = B[x];
   * A[x] = B[x + 1];
   * C[x] = A[x];
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   *   A[0] = 3;
   *   A[x] = (A[x]) + 1;
   * }
   * B[x] = A[x];
   * B[x + 1] = A[x];
   * A[x] = C[x];
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = B[x];              // A_1 initializer
   * A_1 = B[x + 1];              //
   * C[x] = A_1;                  //
   * A[x] = A_1;                  // A_1 finalizer
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   *   A[0] = 3;
   *   A[x] = (A[x]) + 1;
   * }
   * int A_2 = A[x];              // A_2 initializer
   * B[x] = A_2;                  //
   * B[x + 1] = A_2;              //
   * A_2 = C[x];                  //
   * A[x] = A_2;                  // A_2 finalizer
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = B[x];
# CHECK: A_1 = B[x + 1];
# CHECK: C[x] = A_1;
# CHECK: A[x] = A_1;
# CHECK: if (
# CHECK-NOT:   A_1 = A_1 + 1;
# CHECK:   A[x] = (A[x]
# CHECK:   A[0] =
# CHECK:   A[x] = (A[x]
# CHECK: }
# CHECK: int A_2 = A[x];
# CHECK: B[x] = A_2;
# CHECK: B[x + 1] = A_2;
# CHECK: A_2 = C[x];
# CHECK: A[x] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// When accesses are within conditional blocks they are not visible to the wider
// program, because we don't know if the branch would be taken and if it isn't
// the accesses in it don't need to be valid (think size checks on the index).
// In this case the accesses cannot be registerized.
TEST(Registerizer, RegisterizerConditionHidden) {
  // Only buffer A is accessed in this test.
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);

  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   * if (x>5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change: every access to A[x] is inside a branch, so none of them is
  // visible at the enclosing scope.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// But... if the same access is found in a non-conditional scope, that means
// the access is valid in the higher scope (or at least, if it's not, that's
// the user's fault). It "unhides" the conditional accesses, allowing
// registerization to occur.
TEST(Registerizer, RegisterizerConditionUnhidden) {
  // Only buffer A is accessed by this program.
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);

  // The unconditional middle update of A[x] makes the conditional accesses on
  // either side visible to the enclosing block.
  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr),
       Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
           nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   * A[x] = (A[x]) + 1;            <-- this is doing the unhiding.
   * if (x>5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * if (x<5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * }
   * A_1 = A_1 + 1;
   * if (x>5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  // All three updates use the scalar A_1; A[x] is read once and written once.
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if (x<5
# CHECK:   A_1 = A_1 + 1;
# CHECK: }
# CHECK: A_1 = A_1 + 1;
# CHECK: if (x>5
# CHECK:   A_1 = A_1 + 1;
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize a load that occurs in the condition of a Cond.
TEST(Registerizer, RegisterizerCondCondition) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  // A[x] is written and read unconditionally, then read again inside the
  // Cond's condition expression (not inside its body).
  StmtPtr stmt = Block::make(
      {Store::make(a, {x}, Load::make(b, {x})),
       Store::make(c, {x}, Load::make(a, {x})),
       Cond::make(
           CompareSelect::make(
               Load::make(a, {x}), 5, CompareSelectOperation::kLT),
           Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),
           nullptr)});

  /*
   * A[x] = B[x];
   * C[x] = A[x];
   * if ((A[x])<5 ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = B[x];
   * int C_1 = A_1;
   * if (A_1<5 ? 1 : 0) {
   *   C_1 = C_1 + 1;
   * }
   * C[x] = C_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  // The condition's load of A uses the scalar A_1. The conditional update of C
  // uses C_1, which is written back to C[x] once at the end.
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = B[x];
# CHECK: int C_1 = A_1;
# CHECK: if (A_1<5
# CHECK:   C_1 = C_1 + 1;
# CHECK: C[x] = C_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Appearing in the condition of a Cond makes it visible to the enclosing scope,
// and so we can registerize internal usages.
TEST(Registerizer, RegisterizerCondConditionUnhidden) {
  // Only buffer A is accessed by this program.
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);

  // The load of A[x] in the condition is evaluated unconditionally, which
  // exposes the accesses in both branches to the enclosing block.
  StmtPtr stmt = Block::make({Cond::make(
      CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT),
      Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
      Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))});

  /*
   * if ((A[x])<5 ? 1 : 0) {
   *   A[x] = (A[x]) + 1;
   * } else {
   *   A[x] = (A[x]) + 10;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * if (A_1<5 ? 1 : 0) {
   *   A_1 = A_1 + 1;
   * } else {
   *   A_1 = A_1 + 10;
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if (A_1<5
# CHECK:   A_1 = A_1 + 1;
# CHECK: } else {
# CHECK:   A_1 = A_1 + 10;
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Conditional hiding also works for IfThenElse exprs.
TEST(Registerizer, RegisterizerIfThenElseHidden) {
  // A is read (conditionally) and B is written; no other buffer is used.
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Every read of A sits inside an IfThenElse branch, so none is visible to
  // the enclosing block.
  StmtPtr stmt = Block::make(
      {Store::make(
           b,
           {y},
           IfThenElse::make(
               CompareSelect::make(x, 5, CompareSelectOperation::kLT),
               Add::make(Load::make(a, {x}), 1),
               Add::make(Load::make(a, {x + 1}), 2))),
       Store::make(
           b,
           {y + 1},
           IfThenElse::make(
               CompareSelect::make(x, 5, CompareSelectOperation::kLT),
               Add::make(Load::make(a, {x}), 1),
               Add::make(Load::make(a, {x + 1}), 2)))});

  /*
   * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Conditional unhiding also works for IfThenElse exprs.
TEST(Registerizer, RegisterizerIfThenElseUnhidden) {
  // A is read and written, B is written; no other buffer is used.
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // The unconditional store to A[x] unhides the A[x] reads inside the
  // IfThenElse true values. A[x + 1] is never unhidden.
  StmtPtr stmt = Block::make({
      Store::make(a, {x}, 0),
      Store::make(
          b,
          {y},
          IfThenElse::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Add::make(Load::make(a, {x}), 1),
              Add::make(Load::make(a, {x + 1}), 2))),
      Store::make(
          b,
          {y + 1},
          IfThenElse::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Add::make(Load::make(a, {x}), 1),
              Add::make(Load::make(a, {x + 1}), 2))),
  });

  /*
   * A[x] = 0;
   * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Nested IfThenElse exprs can't promote to higher level scopes.
TEST(Registerizer, RegisterizerIfThenElseNested) {
  // A receives the selected value; B, C and D are only read, and every read
  // happens inside an IfThenElse branch.
  BufHandle bufA("A", {5}, kInt);
  BufHandle bufB("B", {5}, kInt);
  BufHandle bufC("C", {5}, kInt);
  BufHandle bufD("D", {5}, kInt);
  VarHandle idx("x", kInt);

  // Value chosen when x < 3: D[x] if x == 2, else B[x].
  ExprHandle lowSelect = IfThenElse::make(
      CompareSelect::make(idx, 2, CompareSelectOperation::kEQ),
      Load::make(bufD, {idx}),
      Load::make(bufB, {idx}));

  // Value chosen otherwise: C[x] if x == 5, else D[x].
  ExprHandle highSelect = IfThenElse::make(
      CompareSelect::make(idx, 5, CompareSelectOperation::kEQ),
      Load::make(bufC, {idx}),
      Load::make(bufD, {idx}));

  ExprHandle outerSelect = IfThenElse::make(
      CompareSelect::make(idx, 3, CompareSelectOperation::kLT),
      lowSelect,
      highSelect);

  StmtPtr stmt = Block::make({Store::make(bufA, {idx}, outerSelect)});

  /*
   * A[x] = IfThenElse(x<3 ? 1 : 0,
   *          IfThenElse(x==2 ? 1 : 0, D[x], B[x]),
   *            IfThenElse(x==5 ? 1 : 0, C[x], D[x]));
   */

  std::ostringstream printedBefore;
  printedBefore << *stmt;

  // Neither D[x] read can be promoted, so the pass must leave the IR as-is.
  StmtPtr result = registerize(stmt);

  std::ostringstream printedAfter;
  printedAfter << *result;

  ASSERT_EQ(printedBefore.str(), printedAfter.str());
}

// Cannot registerize an access completely contained within an IfThenElse
// branch, since it is not a Stmt and cannot hold variable definitions. We need
// to check that we don't promote the initializer/finalizer to the enclosing
// Block.
TEST(Registerizer, RegisterizerIfThenElseInternal) {
  // Making these floats so they don't get simplified to a single access.
  BufHandle a("A", {5}, kFloat);
  BufHandle b("B", {5}, kFloat);
  VarHandle x("x", kInt);

  // Both reads of B[x] in the true value live inside the IfThenElse
  // expression, which is not a Stmt and so cannot host a scalar definition.
  StmtPtr stmt = Block::make({Store::make(
      a,
      {x},
      IfThenElse::make(
          CompareSelect::make(x, 3, CompareSelectOperation::kLT),
          Add::make(Load::make(b, {x}), Load::make(b, {x})),
          Load::make(b, {x})))});

  /*
   * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]);
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());

  // If this was a Cond instead of an IfThenElse then we could registerize the
  // two accesses to B[x] in the True branch.

  // Actually, let's verify that.

  // Same accesses, but expressed with a statement-level Cond so each branch
  // is a Block that can hold a local scalar.
  stmt = Block::make({Cond::make(
      CompareSelect::make(x, 3, CompareSelectOperation::kLT),
      Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))),
      Store::make(a, {x}, Load::make(b, {x})))});

  /*
   * if (x<3 ? 1 : 0) {
   *   A[x] = (B[x]) + (B[x]);
   * } else {
   *   A[x] = B[x];
   * }
   */

  stmt = registerize(stmt);

  /*
   * if (x<3 ? 1 : 0) {
   *   float B_1 = B[x];
   *   A[x] = B_1 + B_1;
   * } else {
   *   A[x] = B[x];
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  // The scalar B_1 must be declared inside the true branch. No declaration
  // may appear before the `if`, and no write-back may follow it.
  const std::string& verification_pattern =
      R"IR(
# CHECK-NOT: int
# CHECK-NOT: float
# CHECK: if (x<3
# CHECK:   float B_1 =
# CHECK:   A[x] = B_1 + B_1
# CHECK: } else {
# CHECK:   A[x] = B[x]
# CHECK: }
# CHECK-NOT: A[x]
# CHECK-NOT: B[x])IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize a load that occurs in the condition of an IfThenElse.
TEST(Registerizer, RegisterizerIfThenElseCondition) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  // A[x] is read in the IfThenElse condition, which is evaluated
  // unconditionally. B[0] and C[0] appear only in the branch values.
  StmtPtr stmt = Block::make(
      {Store::make(a, {x}, Load::make(a, {x})),
       Store::make(
           a,
           {x},
           IfThenElse::make(
               CompareSelect::make(
                   Load::make(a, {x}), 5, CompareSelectOperation::kLT),
               Load::make(b, {0}),
               Load::make(c, {0})))});

  /*
   * A[x] = A[x];       <---- just here so there are enough accesses to combine.
   * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]);
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * A_1 = A_1;
   * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  // The self-assignment line shown above is not checked by this pattern.
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Appearing in the condition of an IfThenElse makes it visible to the
// enclosing scope, and so we can registerize internal usages.
TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) {
  // A is read and B is written; no other buffer is used.
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);

  // The read of A[x] in the condition unhides the reads in both branch
  // values, so all three can share one scalar.
  StmtPtr stmt = Block::make({Store::make(
      b,
      {x},
      IfThenElse::make(
          CompareSelect::make(
              Load::make(a, {x}), 5, CompareSelectOperation::kLT),
          Add::make(Load::make(a, {x}), 1),
          Add::make(Load::make(a, {x}), 10)))});

  /*
   * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10);
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Cannot promote accesses internal to IfThenElse branches even if the enclosing
// scope is conditional.
TEST(Registerizer, RegisterizerConditionBranchOnly) {
  BufHandle a("A", {5}, kInt);
  VarHandle x("x", kInt);
  // Inside a loop, a Cond whose branches each store an IfThenElse that reads
  // A only within its own true/false values.
  StmtPtr stmt = Block::make({For::make(
      x,
      0,
      10,
      Block::make({
          Cond::make(
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
              Store::make(
                  a,
                  {x},
                  IfThenElse::make(
                      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
                      Add::make(Load::make(a, {x}), x),
                      Add::make(Load::make(a, {x - 5}), x))),
              Store::make(
                  a,
                  {x - 5},
                  IfThenElse::make(
                      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
                      Add::make(Load::make(a, {x}), x),
                      Add::make(Load::make(a, {x - 5}), x)))),
      }))});
  // Simplify before taking the snapshot, so the before/after comparison below
  // reflects only what registerize does.
  stmt = IRSimplifier::simplify(stmt);

  std::ostringstream before;
  before << *stmt;

  /* for (int x = 0; x < 10; x++) {
   *   if (x<5 ? 1 : 0) {
   *     A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
   *   } else {
   *     A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
   *   }
   * }
   */

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// We can registerize an IfThenElse that appears in the condition branch of a
// Cond. This is a weird but valid thing to do.
TEST(Registerizer, RegisterizerCondIfThenElse) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);

  // The Cond's condition contains an IfThenElse whose own condition reads
  // A[x]. That read unhides the A[x] in its true value. B[x] and the C[x]
  // accesses each appear only once and stay as buffer accesses.
  StmtPtr stmt = Block::make({Cond::make(
      CompareSelect::make(
          IfThenElse::make(
              CompareSelect::make(
                  Load::make(a, {x}), 5, CompareSelectOperation::kLT),
              Load::make(a, {x}),
              Load::make(b, {x})),
          x,
          CompareSelectOperation::kEQ),
      Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),
      nullptr)});

  /*
   * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  stmt = registerize(stmt);

  // access to A can be registerized, but not B or C

  /*
   * int A_1 = A[x];
   * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) {
   *   C[x] = (C[x]) + 1;
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]
# CHECK:   C[x] = (C[x]) + 1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can registerize a conditional access in the RHS of a store unhidden by its
// LHS, and hoist it out of a loop.
TEST(Registerizer, RegisterizerIfThenElseLoop) {
  BufHandle a("A", {5}, kInt);
  // B is indexed by the loop variable y over [0, 10), so it needs 10 elements
  // for the fixture to be in bounds.
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // The store to A[x] unhides the conditional read of A[x] on its RHS. The
  // index does not depend on y, so the scalar can be hoisted out of the loop.
  StmtPtr stmt = For::make(
      y,
      0,
      10,
      Store::make(
          a,
          {x},
          IfThenElse::make(
              CompareSelect::make(x, 3, CompareSelectOperation::kLT),
              Load::make(a, {x}),
              Load::make(b, {y}))));

  /*
   * for (int y = 0; y < 10; y++) {
   *   A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]);
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[x];
   * for (int y = 0; y < 10; y++) {
   *   A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
   * }
   * A[x] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[x];
# CHECK: for (
# CHECK:   A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
# CHECK: }
# CHECK: A[x] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Cannot registerize if the RHS overlaps the access creating visibility.
TEST(Registerizer, RegisterizerIfThenElseLoopCut) {
  // A is indexed by the loop variable y over [0, 10), so it needs 10 elements
  // for the fixture to be in bounds.
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // A[y] in the false value may alias A[x], which blocks registerization.
  StmtPtr stmt = Block::make({For::make(
      y,
      0,
      10,
      Store::make(
          a,
          {x},
          IfThenElse::make(
              CompareSelect::make(x, 3, CompareSelectOperation::kLT),
              Load::make(a, {x}),
              Load::make(a, {y}))))});

  /*
   * for (int y = 0; y < 10; y++) {
   *   A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]);
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Simple case where an access is cut by an overlapping access later in the
// program, we can registerize up until the overlap.
TEST(Registerizer, RegisterizerPartialAfter) {
  // The second loop writes A[x] for x in [1, 10), so A needs 10 elements for
  // the fixture to be in bounds.
  BufHandle a("A", {10}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 0),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})),
       For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[0] = A_1;
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  // A_1 must be written back before the overlapping second loop runs.
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: for (
# CHECK:   A_1 = A_1 + x;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: for (
# CHECK:   A[x] = A[x - 1];
# CHECK: }
# CHECK-NOT: A)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// We can registerize an access which overlaps a previous access, the
// initializer must be inserted after the previous access.
TEST(Registerizer, RegisterizerPartialBefore) {
  // The first loop writes A[x] for x in [1, 10), so A needs 10 elements for
  // the fixture to be in bounds.
  BufHandle a("A", {10}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
       Store::make(a, {0}, 0),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});

  /*
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * for (int x = 1; x < 10; x++) {
   *   A[x] = A[x - 1];
   * }
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  // The scalar's declaration must appear only after the overlapping first
  // loop.
  const std::string& verification_pattern =
      R"IR(
# CHECK-NOT: int
# CHECK: for (
# CHECK:   A[x] = A[x - 1];
# CHECK: }
# CHECK: int A_1 = 0;
# CHECK: for (
# CHECK:   A_1 = A_1 + x;
# CHECK: }
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// The combination of the previous two tests, an access is cut by an overlapping
// access in both directions.
TEST(Registerizer, RegisterizerPartialInside) {
  // The middle loop writes A[x2] for x2 in [1, 10), so A needs 10 elements for
  // the fixture to be in bounds.
  BufHandle a("A", {10}, kInt);
  VarHandle x1("x1", kInt);
  VarHandle x2("x2", kInt);
  VarHandle x3("x3", kInt);
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 2),
       For::make(
           x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))),
       For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))),
       For::make(
           x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))});

  /*
   * A[0] = 2;
   * for (int x1 = 0; x1 < 10; x1++) {
   *   A[0] = (A[0]) + x1;
   * }
   * for (int x2 = 1; x2 < 10; x2++) {
   *   A[x2] = A[x2 - 1];
   * }
   * for (int x3 = 0; x3 < 10; x3++) {
   *   A[0] = (A[0]) + x3;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 2;
   * for (int x1 = 0; x1 < 10; x1++) {
   *   A_1 = A_1 + x1;
   * }
   * A[0] = A_1;
   * for (int x2 = 1; x2 < 10; x2++) {
   *   A[x2] = A[x2 - 1];
   * }
   * int A_2 = A[0];
   * for (int x3 = 0; x3 < 10; x3++) {
   *   A_2 = A_2 + x3;
   * }
   * A[0] = A_2;
   */

  std::ostringstream oss;
  oss << *stmt;

  // The middle loop splits A[0] into two scalars. A_2 must re-read A[0]
  // because the middle loop may have changed it.
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 2;
# CHECK: for (
# CHECK:   A_1 = A_1 + x1;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: for (
# CHECK:   A[x2] =
# CHECK: }
# CHECK: int A_2 = A[0];
# CHECK: for (
# CHECK:   A_2 = A_2 + x3;
# CHECK: }
# CHECK: A[0] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// An element could be registerized program wide but is cut by a conditional
// access, we should break this into two scalars and write back to the buffer
// before the condition.
TEST(Registerizer, RegisterizerPartialCondition) {
  // NOTE(review): A is declared with one element but is indexed by x and
  // x - 1 below. This test only inspects printed IR, so the mismatch is never
  // exercised at runtime.
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  // The conditional A[x] access sits between two loops over A[0] and cuts
  // the A[0] access into two scalars.
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 2),
       For::make(
           x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Store::make(a, {x}, Load::make(a, {x - 1})),
           nullptr),
       For::make(
           x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))});

  /*
   * A[0] = 2;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   * if (x<5 ? 1 : 0) {
   *   A[x] = A[x - 1];
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 2;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[0] = A_1;
   * if (x<5 ? 1 : 0) {
   *   A[x] = A[x - 1];
   * }
   * int A_2 = A[0];
   * for (int x = 0; x < 10; x++) {
   *   A_2 = A_2 + x;
   * }
   * A[0] = A_2;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 2;
# CHECK: for (
# CHECK:   A_1 = A_1 + x;
# CHECK: }
# CHECK: A[0] = A_1;
# CHECK: if (
# CHECK:   A[x] =
# CHECK: }
# CHECK: int A_2 = A[0];
# CHECK: for (
# CHECK:   A_2 = A_2 + x;
# CHECK: }
# CHECK: A[0] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Tests case where an access is cut by an internal conditional access which
// itself is registerized.
TEST(Registerizer, RegisterizerPartialConditionInternalCut) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  // Two runs of A[0] stores surround a Cond whose body stores to A[x] twice.
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 1),
       Store::make(a, {0}, 3),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
           nullptr),
       Store::make(a, {0}, 4),
       Store::make(a, {0}, 6)});

  /*
   * A[0] = 1;
   * A[0] = 3;
   * if (x<5 ? 1 : 0) {
   *   A[x] = 1;
   *   A[x] = 3;
   * }
   * A[0] = 4;
   * A[0] = 6;
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 1;
   * A_1 = 3;
   * A[0] = A_1;
   * if (x<5 ? 1 : 0) {
   *   int A_2 = 1;
   *   A_2 = 3;
   *   A[x] = A_2;
   * }
   * int A_3 = 4;
   * A_3 = 6;
   * A[0] = A_3;
   */

  std::ostringstream oss;
  oss << *stmt;

  // Each run gets its own scalar. A_2 is scoped entirely inside the condition.
  // Every statement is matched including its trailing ';'.
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 1;
# CHECK: A_1 = 3;
# CHECK: A[0] = A_1;
# CHECK: if (
# CHECK:   int A_2 = 1;
# CHECK:   A_2 = 3;
# CHECK:   A[x] = A_2;
# CHECK: }
# CHECK: int A_3 = 4;
# CHECK: A_3 = 6;
# CHECK: A[0] = A_3;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// First statement in condition closes outer access, but can be registerized
// with later statements.
TEST(Registerizer, RegisterizerPartialConditionInternalStart) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  // The A[x] stores inside the Cond close the A[0] run. They then combine
  // with the unconditional A[x] stores that follow.
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, 1),
       Store::make(a, {0}, 3),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
           nullptr),
       Store::make(a, {x}, 4),
       Store::make(a, {x}, 6)});

  /*
   * A[0] = 1;
   * A[0] = 3;
   * if (x<5 ? 1 : 0) {
   *   A[x] = 1;
   *   A[x] = 3;
   * }
   * A[x] = 4;
   * A[x] = 6;
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 1;
   * A_1 = 3;
   * A[0] = A_1;
   * int A_2 = A[x];    <--- must read from the input here.
   * if (x<5 ? 1 : 0) {
   *   A_2 = 1;
   *   A_2 = 3;
   * }
   * A_2 = 4;
   * A_2 = 6;
   * A[x] = A_2;
   */

  // TODO: I suppose we could refactor with a conditional initializer?

  std::ostringstream oss;
  oss << *stmt;

  // Every statement is matched including its trailing ';'.
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 1;
# CHECK: A_1 = 3;
# CHECK: A[0] = A_1;
# CHECK: int A_2 = A[x];
# CHECK: if (
# CHECK:   A_2 = 1;
# CHECK:   A_2 = 3;
# CHECK: }
# CHECK: A_2 = 4;
# CHECK: A_2 = 6;
# CHECK: A[x] = A_2;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// An access cuts two open overlaps and creates four scalar variables.
TEST(Registerizer, RegisterizerPartialOverlapsTwo) {
  // The loop writes A[x] for x in [1, 10) and the program also touches A[1],
  // so A needs 10 elements for the fixture to be in bounds.
  BufHandle a("A", {10}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {Store::make(a, {1}, Load::make(a, {0})),
       Store::make(a, {0}, Load::make(a, {1})),
       Store::make(a, {0}, Load::make(a, {1})),
       For::make(x, 1, 10, Store::make(a, {x}, x)),
       Store::make(a, {1}, Load::make(a, {0})),
       Store::make(a, {0}, Load::make(a, {1})),
       Store::make(a, {0}, Load::make(a, {1}))});

  /*
   * A[1] = A[0];
   * A[0] = A[1];
   * A[0] = A[1];
   * for (int x = 1; x < 10; x++) {
   *   A[x] = x;
   * }
   * A[1] = A[0];
   * A[0] = A[1];
   * A[0] = A[1];
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * int A_2 = A_1;
   * A_1 = A_2;
   * A_1 = A_2;
   * A[1] = A_2;
   * A[0] = A_1;
   * for (int x = 1; x < 10; x++) {
   *   A[x] = x;
   * }
   * int A_3 = A[0];
   * int A_4 = A_3;
   * A_3 = A_4;
   * A_3 = A_4;
   * A[1] = A_4;
   * A[0] = A_3;
   */

  std::ostringstream oss;
  oss << *stmt;

  // The loop cuts both the A[0] and A[1] accesses, giving two scalars on
  // each side of it.
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: int A_2 = A_1;
# CHECK: A_1 = A_2;
# CHECK: A_1 = A_2;
# CHECK: A[1] = A_2;
# CHECK: A[0] = A_1;
# CHECK: for (
# CHECK:   A[x] = x;
# CHECK: }
# CHECK: int A_3 = A[0];
# CHECK: int A_4 = A_3;
# CHECK: A_3 = A_4;
# CHECK: A_3 = A_4;
# CHECK: A[1] = A_4;
# CHECK: A[0] = A_3;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Nested blocks will automatically be flattened and do not prevent
// registerization of enclosed accesses.
TEST(Registerizer, RegisterizerNestedBlocks) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  // Four read-modify-writes of A[0], spread across unconditional nested
  // Blocks.
  StmtPtr stmt = Block::make(
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
      {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
       Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}),
       Block::make(
           {Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)),
            Block::make(
                {Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})});

  /*
   * A[0] = (A[0]) + 1;
   * {
   *   A[0] = (A[0]) + 2;
   * }
   * {
   *   A[0] = (A[0]) + 3;
   *   {
   *     A[0] = (A[0]) + 4;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * A_1 = A_1 + 1;
   * A_1 = A_1 + 2;
   * A_1 = A_1 + 3;
   * A_1 = A_1 + 4;
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  // All four updates share one scalar, with a single initializer and a single
  // finalizer.
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: A_1 = A_1 + 1;
# CHECK: A_1 = A_1 + 2;
# CHECK: A_1 = A_1 + 3;
# CHECK: A_1 = A_1 + 4;
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// The access can be registerized internally to a condition, but must ensure
// that both initializer and finalizer are within the same condition.
TEST(Registerizer, RegisterizerNestedConditions) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  // Every access to A[0] is inside the outer x<5 condition. The first is
  // unconditional within that scope and unhides the nested one.
  StmtPtr stmt = Block::make({Cond::make(
      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
      Block::make(
          {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
           Cond::make(
               CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
               Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
               nullptr)}),
      nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   A[0] = (A[0]) + 1;
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * if (x<5 ? 1 : 0) {
   *   int A_1 = A[0];
   *   A_1 = A_1 + 1;
   *   if (x==2 ? 1 : 0) {
   *     A_1 = A_1 + 1;
   *   }
   *   A[0] = A_1;
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  // The initializer and the finalizer must both appear before the outer
  // condition's closing brace.
  const std::string& verification_pattern =
      R"IR(
# CHECK: if (x<5
# CHECK:   int A_1 = A[0];
# CHECK:   A_1 = A_1 + 1;
# CHECK:   if (x==2
# CHECK:     A_1 = A_1 + 1;
# CHECK:   }
# CHECK: A[0] = A_1;
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// If an access exists outside the scope of the condition then we can lift
// nested conditional usages into the same scalar.
TEST(Registerizer, RegisterizerNestedConditionsUnhidden) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  // The top-level update of A[0] unhides the doubly nested one. A[1] is a
  // different element and does not cut the A[0] access.
  StmtPtr stmt = Block::make(
      {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make(
               {Store::make(a, {1}, 1),
                Cond::make(
                    CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
                    Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
                    nullptr)}),
           nullptr)});

  /*
   * A[0] = (A[0]) + 1;
   * if (x<5 ? 1 : 0) {
   *   A[1] = 1;
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * A_1 = A_1 + 1;
   * if (x<5 ? 1 : 0) {
   *   A[1] = 1;
   *   if (x==2 ? 1 : 0) {
   *     A_1 = A_1 + 1;
   *   }
   * }
   * A[0] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: A_1 = A_1 + 1;
# CHECK: if (x<5
# CHECK:   A[1] = 1;
# CHECK:   if (x==2
# CHECK:     A_1 = A_1 + 1;
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// A[0] is only ever accessed inside x==2 conditions: first at the top level,
// then nested under x<5. Every access stays hidden, so nothing is registerized,
// and running the pass again must still change nothing.
TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
           Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
           nullptr),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Cond::make(
               CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
               Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
               nullptr)}),
           nullptr)});

  /*
   * if (x==2 ? 1 : 0) {
   *   A[0] = (A[0]) + 1;
   * }
   * if (x<5 ? 1 : 0) {
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());

  // A second run over the pass's own output must also be a no-op. Checking
  // the result means this call is verified rather than a dead store.
  stmt = registerize(stmt);

  std::ostringstream again;
  again << *stmt;

  ASSERT_EQ(before.str(), again.str());
}

// Mirror of the hidden-first case: the nested x==2 access comes first and the
// top-level x==2 access second. Both stay hidden, so nothing is registerized,
// and running the pass again must still change nothing.
TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Cond::make(
               CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
               Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
               nullptr)}),
           nullptr),
       Cond::make(
           CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
           Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
           nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   * if (x==2 ? 1 : 0) {
   *   A[0] = (A[0]) + 1;
   * }
   */

  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());

  // A second run over the pass's own output must also be a no-op. Checking
  // the result means this call is verified rather than a dead store.
  stmt = registerize(stmt);

  std::ostringstream again;
  again << *stmt;

  ASSERT_EQ(before.str(), again.str());
}

// A write to A[x] that sits inside a condition block still cuts the scope of
// the surrounding A[0] accesses, so nothing here is registerized and the IR is
// expected to come back unchanged.
TEST(Registerizer, RegisterizerNestedConditionsCut) {
  BufHandle bufA("A", {1}, kInt);
  VarHandle idx("x", kInt);

  // Produces a fresh "A[0] = (A[0]) + 1" statement on every call.
  auto incrementA0 = [&]() -> StmtPtr {
    return Store::make(bufA, {0}, Add::make(Load::make(bufA, {0}), 1));
  };

  StmtPtr innerGuard = Cond::make(
      CompareSelect::make(idx, 2, CompareSelectOperation::kEQ),
      incrementA0(),
      nullptr);
  StmtPtr outerGuard = Cond::make(
      CompareSelect::make(idx, 5, CompareSelectOperation::kLT),
      Block::make({Store::make(bufA, {idx}, 1), innerGuard}),
      nullptr);
  StmtPtr root = Block::make({incrementA0(), outerGuard});

  /*
   * A[0] = (A[0]) + 1;
   * if (x<5 ? 1 : 0) {
   *   A[x] = 1;
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  std::ostringstream original;
  original << *root;

  // Expected to be a no-op.
  root = registerize(root);

  std::ostringstream transformed;
  transformed << *root;

  ASSERT_EQ(original.str(), transformed.str());
}

// A guarded increment of A[0] at the top level, followed by a loop whose body
// repeats the same guarded increment. The test asserts that the registerizer
// leaves this IR unchanged.
TEST(Registerizer, RegisterizerNestedConditionLoopHidden) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);

  // Produces a fresh "if (x==2) { A[0] = (A[0]) + 1; }" on every call.
  auto guardedIncrement = [&]() -> StmtPtr {
    return Cond::make(
        CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
        Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
        nullptr);
  };

  // B[x] = 0 exists only to prevent Loop/Cond reordering.
  StmtPtr loopBody = Block::make({Store::make(b, {x}, 0), guardedIncrement()});
  StmtPtr loop = For::make(x, 0, 10, loopBody);
  StmtPtr stmt = Block::make({guardedIncrement(), loop});

  /*
   * if (x==2 ? 1 : 0) {
   *   A[0] = (A[0]) + 1;
   * }
   * for (int x = 0; x < 10; x++) {
   *   B[x] = 0;
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  std::ostringstream printedBefore;
  printedBefore << *stmt;

  // Expected to be a no-op.
  stmt = registerize(stmt);

  std::ostringstream printedAfter;
  printedAfter << *stmt;

  ASSERT_EQ(printedBefore.str(), printedAfter.str());
}

// Three nested conditions and four accessed elements (A[1]..A[4]), three of
// which should be registerized at different levels of the IR: A[4] at the top
// level, A[2] inside the x>3 condition and A[1] inside the x>4 condition.
// A[3] is read and written only once, so it is left as a direct buffer access.
TEST(Registerizer, RegisterizerNestedConditionThreeDeep) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {Store::make(a, {4}, 0),
       Cond::make(
           CompareSelect::make(x, 2, CompareSelectOperation::kGT),
           Cond::make(
               CompareSelect::make(x, 3, CompareSelectOperation::kGT),
               Block::make({
                   Cond::make(
                       CompareSelect::make(x, 4, CompareSelectOperation::kGT),
                       Block::make({
                           Store::make(
                               a, {1}, Add::make(Load::make(a, {1}), 1)),
                           Store::make(
                               a, {2}, Add::make(Load::make(a, {2}), 1)),
                           Store::make(
                               a, {3}, Add::make(Load::make(a, {3}), 1)),
                           Store::make(
                               a, {4}, Add::make(Load::make(a, {4}), 1)),
                           Store::make(
                               a, {1}, Add::make(Load::make(a, {1}), 1)),
                       }),
                       nullptr),
                   Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)),
               }),
               nullptr),
           nullptr)});

  /*
   * A[4] = 0;
   * if (x>2 ? 1 : 0) {
   *   if (x>3 ? 1 : 0) {
   *     if (x>4 ? 1 : 0) {
   *       A[1] = (A[1]) + 1;
   *       A[2] = (A[2]) + 1;
   *       A[3] = (A[3]) + 1;
   *       A[4] = (A[4]) + 1;
   *       A[1] = (A[1]) + 1;
   *     }
   *     A[2] = (A[2]) + 1;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * if (x>2 ? 1 : 0) {
   *   if (x>3 ? 1 : 0) {
   *     int A_3 = A[2];
   *     if (x>4 ? 1 : 0) {
   *       int A_2 = A[1];
   *       A_2 = A_2 + 1;
   *       A_3 = A_3 + 1;
   *       A[3] = (A[3]) + 1;
   *       A_1 = A_1 + 1;
   *       A_2 = A_2 + 1;
   *       A[1] = A_2;
   *     }
   *     A_3 = A_3 + 1;
   *     A[2] = A_3;
   *   }
   * }
   * A[4] = A_1;
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: if (x>2 ? 1 : 0) {
# CHECK:   if (x>3 ? 1 : 0) {
# CHECK:     int A_3 = A[2];
# CHECK:     if (x>4 ? 1 : 0) {
# CHECK:       int A_2 = A[1];
# CHECK:       A_2 = A_2 + 1;
# CHECK:       A_3 = A_3 + 1;
# CHECK:       A[3] = (A[3]) + 1;
# CHECK:       A_1 = A_1 + 1;
# CHECK:       A_2 = A_2 + 1;
# CHECK:       A[1] = A_2;
# CHECK:     }
# CHECK:     A_3 = A_3 + 1;
# CHECK:     A[2] = A_3;
# CHECK:   }
# CHECK: }
# CHECK: A[4] = A_1;)IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Can replace a simple scalar access with a local variable even when the
// access is indexed by an outer loop var: the A[y] access is hoisted out of the
// inner x loop but stays inside the y loop.
TEST(Registerizer, RegisterizerNestedLoopSimple) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  StmtPtr stmt = Block::make({For::make(
      y,
      0,
      10,
      For::make(
          x,
          0,
          10,
          Block::make(
              {Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))});

  /*
   * for (int y = 0; y < 10; y++) {
   *   for (int x = 0; x < 10; x++) {
   *     A[y] = (A[y]) + x;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * for (int y = 0; y < 10; y++) {
   *   int A_1 = A[y];
   *   for (int x = 0; x < 10; x++) {
   *     A_1 = A_1 + x;
   *   }
   *   A[y] = A_1;
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int y
# CHECK:   int A_1 = A[y];
# CHECK:   for (int x
# CHECK:     A_1 = A_1 + x;
# CHECK:   }
# CHECK:   A[y] = A_1;
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Positive case for the hiddenAccess split: an access that sits under a
// condition inside a loop is hoisted up through that loop, meets the A[0]
// initialization at the outer scope, and the two are registerized together.
TEST(Registerizer, RegisterizerHiddenAccessYes) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Innermost: for y, A[0] += 1.
  StmtPtr innerLoop = For::make(
      y, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)));
  // The conditional that hides the A[0] access.
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  StmtPtr hiddenAccess = Cond::make(
      CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
      innerLoop,
      nullptr);
  StmtPtr outerLoop = For::make(
      x, 0, 10, Block::make({Store::make(b, {x}, 0), hiddenAccess}));
  StmtPtr guardedRegion = Block::make({Store::make(a, {0}, 0), outerLoop});
  StmtPtr stmt = Block::make({Cond::make(
      CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
      guardedRegion,
      nullptr)});

  /*
   * if (x==2 ? 1 : 0) {
   *   A[0] = 0;
   *   for (int x = 0; x < 10; x++) {
   *     B[x] = 0;
   *     if (x==3 ? 1 : 0) {
   *       for (int y = 0; y < 10; y++) {
   *         A[0] = (A[0]) + 1;
   *       }
   *     }
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * if (x==2 ? 1 : 0) {
   *   int A_1 = 0;
   *   for (int x = 0; x < 10; x++) {
   *     B[x] = 0;
   *     if (x==3 ? 1 : 0) {
   *       for (int y = 0; y < 10; y++) {
   *         A_1 = A_1 + 1;
   *       }
   *     }
   *   }
   *   A[0] = A_1;
   * }
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: if (x==2
# CHECK:   int A_1 = 0;
# CHECK:   for (int x
# CHECK:     B[x] = 0;
# CHECK:     if (x==3
# CHECK:       for (int y
# CHECK:         A_1 = A_1 + 1;
# CHECK:       }
# CHECK:     }
# CHECK:   }
# CHECK:  A[0] = A_1;
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}

// Test the negative case of the hiddenAccess split, where the hoisted access is
// never unhidden at a higher scope (there is no outer A[0] access to merge
// with), so registerization happens at the lower scope instead.
TEST(Registerizer, RegisterizerHiddenAccessNo) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  StmtPtr stmt = Block::make({Cond::make(
      CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
      Block::make({For::make(
          x,
          0,
          10,
          Block::make(
              {Store::make(b, {x}, 0),
               // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
               Cond::make(
                   CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
                   For::make(
                       y,
                       0,
                       10,
                       Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
                   nullptr)}))}),
      nullptr)});

  /*
   * Note: unlike RegisterizerHiddenAccessYes there is no `A[0] = 0;` before
   * the x loop.
   *
   * if (x==2 ? 1 : 0) {
   *   for (int x = 0; x < 10; x++) {
   *     B[x] = 0;
   *     if (x==3 ? 1 : 0) {
   *       for (int y = 0; y < 10; y++) {
   *         A[0] = (A[0]) + 1;
   *       }
   *     }
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * if (x==2 ? 1 : 0) {
   *   for (int x = 0; x < 10; x++) {
   *     B[x] = 0;
   *     if (x==3 ? 1 : 0) {
   *       int A_1 = A[0];
   *       for (int y = 0; y < 10; y++) {
   *         A_1 = A_1 + 1;
   *       }
   *       A[0] = A_1;
   *     }
   *   }
   * }
   */

  std::ostringstream oss;
  oss << *stmt;

  const std::string& verification_pattern =
      R"IR(
# CHECK: if (x==2
# CHECK:   for (int x
# CHECK:     B[x] = 0;
# CHECK:     if (x==3
# CHECK:       int A_1 = A[0];
# CHECK:       for (int y
# CHECK:         A_1 = A_1 + 1;
# CHECK:       }
# CHECK:       A[0] = A_1;
# CHECK:     }
# CHECK:   }
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Here the conditional accesses must be hoisted through two loops. There are
// two of them: A[0] is unhidden by the outer A[0] initialization and can be
// registerized, while B[0] is never unhidden and cannot.
TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Both increments live under the y==3 condition.
  StmtPtr bothIncrements = Block::make(
      {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
       Store::make(b, {0}, Add::make(Load::make(b, {0}), 1))});
  StmtPtr guarded = Cond::make(
      CompareSelect::make(y, 3, CompareSelectOperation::kEQ),
      bothIncrements,
      nullptr);
  StmtPtr loopNest =
      For::make(x, 0, 10, For::make(y, 0, 10, Block::make({guarded})));
  StmtPtr stmt = Block::make({Cond::make(
      CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
      Block::make({Store::make(a, {0}, 0), loopNest}),
      nullptr)});

  /*
   * if (x==2 ? 1 : 0) {
   *   A[0] = 0;
   *   for (int x = 0; x < 10; x++) {
   *     for (int y = 0; y < 10; y++) {
   *       if (y==3 ? 1 : 0) {
   *         A[0] = (A[0]) + 1;
   *         B[0] = (B[0]) + 1;
   *       }
   *     }
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * if (x==2 ? 1 : 0) {
   *   int A_1 = 0;
   *   for (int x = 0; x < 10; x++) {
   *     for (int y = 0; y < 10; y++) {
   *       if (y==3 ? 1 : 0) {
   *         A_1 = A_1 + 1;
   *         B[0] = (B[0]) + 1;
   *       }
   *     }
   *   }
   *   A[0] = A_1;
   * }
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: if (x==2
# CHECK:   int A_1 = 0;
# CHECK:   for (int x
# CHECK:     for (int y
# CHECK:       if (y==3
# CHECK:         A_1 = A_1 + 1;
# CHECK:         B[0] = (B[0]) + 1;
# CHECK:       }
# CHECK:     }
# CHECK:   }
# CHECK:  A[0] = A_1;
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}

// Each of two conditions wraps a loop that increments A[0]. The accesses are
// registerized inside each condition separately, even though the immediate
// parent of each access is a loop rather than the condition itself.
TEST(Registerizer, RegisterizerTwoConditionalLoops) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Produces "if (x <op> 5) { for x in [0, 10): A[0] = (A[0]) + 1 }".
  auto guardedLoop = [&](CompareSelectOperation op) -> StmtPtr {
    StmtPtr increment =
        Store::make(a, {0}, Add::make(Load::make(a, {0}), 1));
    return Cond::make(
        CompareSelect::make(x, 5, op), For::make(x, 0, 10, increment), nullptr);
  };

  StmtPtr stmt = Block::make(
      {guardedLoop(CompareSelectOperation::kLT),
       guardedLoop(CompareSelectOperation::kGT)});

  /*
   * if (x<5 ? 1 : 0) {
   *   for (int x = 0; x < 10; x++) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   * if (x>5 ? 1 : 0) {
   *   for (int x = 0; x < 10; x++) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * if (x<5 ? 1 : 0) {
   *   int A_1 = A[0];
   *   for (int x = 0; x < 10; x++) {
   *     A_1 = A_1 + 1;
   *   }
   *   A[0] = A_1;
   * }
   * if (x>5 ? 1 : 0) {
   *   int A_2 = A[0];
   *   for (int x = 0; x < 10; x++) {
   *     A_2 = A_2 + 1;
   *   }
   *   A[0] = A_2;
   * }
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: if (x<5
# CHECK:   int A_1 = A[0];
# CHECK:   for (int x
# CHECK:     A_1 = A_1 + 1;
# CHECK:   }
# CHECK:   A[0] = A_1;
# CHECK: }
# CHECK: if (x>5
# CHECK:   int A_2 = A[0];
# CHECK:   for (int x
# CHECK:     A_2 = A_2 + 1;
# CHECK:   }
# CHECK:   A[0] = A_2;
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}

// Like RegisterizerTwoConditionalLoops, but a loop writing A[x] sits between
// the two conditions and cuts them apart. Each condition still gets its own
// scalar.
TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Produces "if (x <op> 5) { for x in [0, 10): A[0] = (A[0]) + 1 }".
  auto guardedLoop = [&](CompareSelectOperation op) -> StmtPtr {
    StmtPtr increment =
        Store::make(a, {0}, Add::make(Load::make(a, {0}), 1));
    return Cond::make(
        CompareSelect::make(x, 5, op), For::make(x, 0, 10, increment), nullptr);
  };

  StmtPtr first = guardedLoop(CompareSelectOperation::kLT);
  StmtPtr cut = For::make(x, 0, 10, Store::make(a, {x}, 1));
  StmtPtr second = guardedLoop(CompareSelectOperation::kGT);
  StmtPtr stmt = Block::make({first, cut, second});

  /*
   * if (x<5 ? 1 : 0) {
   *   for (int x = 0; x < 10; x++) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[x] = 1;
   * }
   * if (x>5 ? 1 : 0) {
   *   for (int x = 0; x < 10; x++) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  stmt = registerize(stmt);

  /*
   * if (x<5 ? 1 : 0) {
   *   int A_1 = A[0];
   *   for (int x = 0; x < 10; x++) {
   *     A_1 = A_1 + 1;
   *   }
   *   A[0] = A_1;
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[x] = 1;
   * }
   * if (x>5 ? 1 : 0) {
   *   int A_2 = A[0];
   *   for (int x = 0; x < 10; x++) {
   *     A_2 = A_2 + 1;
   *   }
   *   A[0] = A_2;
   * }
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: if (x<5
# CHECK:   int A_1 = A[0];
# CHECK:   for (int x
# CHECK:     A_1 = A_1 + 1;
# CHECK:   }
# CHECK:   A[0] = A_1;
# CHECK: }
# CHECK: for (int x
# CHECK:  A[x] = 1;
# CHECK: if (x>5
# CHECK:   int A_2 = A[0];
# CHECK:   for (int x
# CHECK:     A_2 = A_2 + 1;
# CHECK:   }
# CHECK:   A[0] = A_2;
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}

// The access index is a Let var declared inside the loop body, so the access
// cannot be hoisted out of the loop. The test asserts the IR is unchanged.
TEST(Registerizer, RegisterizerLoopLetVar) {
  BufHandle a("A", {10}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  StmtPtr loopBody = Block::make(
      {Let::make(y, 30),
       Store::make(a, {y}, Add::make(x, Load::make(a, {y})))});
  StmtPtr stmt =
      IRSimplifier::simplify(Block::make({For::make(x, 0, 10, loopBody)}));

  /*
   * for (int x = 0; x < 10; x++) {
   *   int y = 30;
   *   A[y] = x + (A[y]);
   * }
   */

  std::ostringstream original;
  original << *stmt;

  // Expected to be a no-op.
  stmt = registerize(stmt);

  std::ostringstream transformed;
  transformed << *stmt;

  ASSERT_EQ(original.str(), transformed.str());
}

// The access index is a Let var declared in the outer scope, so it does not
// prevent hoisting the initializer above the loop.
TEST(Registerizer, RegisterizerLoopLetVarOuter) {
  BufHandle a("A", {10}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  StmtPtr accumulate = Store::make(a, {y}, Add::make(x, Load::make(a, {y})));
  StmtPtr loop = For::make(x, 0, 10, Block::make({accumulate}));
  StmtPtr stmt = Block::make({Let::make(y, 30), loop});

  /*
   * int y = 30;
   * for (int x = 0; x < 10; x++) {
   *   A[y] = x + (A[y]);
   * }
   */

  stmt = registerize(stmt);

  /*
   * int y = 30;
   * int A_1 = A[y];
   * for (int x = 0; x < 10; x++) {
   *   A_1 = A_1 + x;
   * }
   * A[y] = A_1;
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: int y = 30;
# CHECK: int A_1 = A[y];
# CHECK: for (int x
# CHECK:   A_1 = A_1 + x;
# CHECK: A[y] = A_1;)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}

// The registerizer normally runs after index flattening, but it should still
// handle multi-dimensional indices: a fixed A[0, 1, 2] element is replaced by a
// scalar.
TEST(Registerizer, RegisterizerMultiDim) {
  BufHandle a("A", {3, 4, 5}, kInt);
  VarHandle x("x", kInt);

  StmtPtr init = Store::make(a, {0, 1, 2}, 0);
  StmtPtr update =
      Store::make(a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x));
  StmtPtr stmt =
      Block::make({init, For::make(x, 0, 10, Block::make({update}))});

  /*
   * A[0, 1, 2] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0, 1, 2] = (A[0, 1, 2]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_1 = x + A_1;
   * }
   * A[0, 1, 2] = A_1;
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: int A_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK:   A_1 =
# CHECK: A[0, 1, 2] = A_1;)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}

// Won't merge accesses whose indices match in only some dimensions, but each
// distinct element can still be registerized on its own.
TEST(Registerizer, RegisterizerMultiDimPartial) {
  BufHandle a("A", {3, 4, 5}, kInt);
  VarHandle x("x", kInt);

  StmtPtr init = Store::make(a, {0, 1, 2}, 0);
  StmtPtr update =
      Store::make(a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x));
  StmtPtr stmt =
      Block::make({init, For::make(x, 0, 10, Block::make({update}))});

  /*
   * A[0, 1, 2] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0, 2, 2] = (A[0, 1, 4]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * A[0, 1, 2] = 0;
   * int A_1 = A[0, 1, 4];
   * int A_2 = A[0, 2, 2];
   * for (int x = 0; x < 10; x++) {
   *   A_2 = A_1 + x;
   * }
   * A[0, 2, 2] = A_2;
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: A[0, 1, 2] = 0;
# CHECK: int A_1 = A[0, 1, 4];
# CHECK: int A_2 = A[0, 2, 2];
# CHECK: for (
# CHECK:   A_2 = A_1 + x;
# CHECK: A[0, 2, 2] = A_2;)IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}

// When the accesses could overlap in every dimension, nothing can be
// registerized and the (simplified) IR must come back unchanged.
TEST(Registerizer, RegisterizerMultiDimOverlap) {
  BufHandle a("A", {3, 4, 5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  StmtPtr update =
      Store::make(a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x));
  StmtPtr stmt = IRSimplifier::simplify(Block::make(
      {Store::make(a, {0, 1, 2}, 0),
       For::make(x, 0, 10, Block::make({update}))}));

  /*
   * A[0, 1, 2] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0, x, 2] = (A[y, 2, 2]) + x;
   * }
   */

  std::ostringstream original;
  original << *stmt;

  // Expected to be a no-op.
  stmt = registerize(stmt);

  std::ostringstream transformed;
  transformed << *stmt;

  ASSERT_EQ(original.str(), transformed.str());
}

// However, if one dimension is known to differ, the accesses do not overlap:
// the A[y, 2, 4] load is hoisted into a scalar while the A[0, x, 2] store is
// left alone.
TEST(Registerizer, RegisterizerMultiDimPartialOverlap) {
  BufHandle a("A", {3, 4, 5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  StmtPtr init = Store::make(a, {0, 1, 2}, 0);
  StmtPtr update =
      Store::make(a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x));
  StmtPtr stmt =
      Block::make({init, For::make(x, 0, 10, Block::make({update}))});

  /*
   * A[0, 1, 2] = 0;                     <---- 2nd dim overlaps with store.
   * for (int x = 0; x < 10; x++) {
   *   A[0, x, 2] = (A[y, 2, 4]) + x;    <---- 3rd dim has constant diff.
   * }
   */

  stmt = registerize(stmt);

  /*
   * A[0, 1, 2] = 0;
   * int A_1 = A[y, 2, 4];
   * for (int x = 0; x < 10; x++) {
   *   A[0, x, 2] = A_1 + x;
   * }
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: A[0, 1, 2] = 0;
# CHECK: int A_1 = A[y, 2, 4];
# CHECK: for (
# CHECK:   A[0, x, 2] = A_1 + x;
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}

// A 3D reduction whose inputs have different dimensionality (1D A, 2D B,
// 3D C).
TEST(Registerizer, RegisterizerMultiDim3DReduction1) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10, 10}, kInt);
  BufHandle c("C", {10, 10, 10}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);

  ExprHandle product = Mul::make(Load::make(b, {x, y}), Load::make(a, {x}));
  StmtPtr reduce =
      Store::make(c, {x, y, z}, Add::make(Load::make(c, {x, y, z}), product));
  StmtPtr zLoop = For::make(z, 0, 10, reduce);
  StmtPtr yLoop = For::make(y, 0, 10, zLoop);
  StmtPtr stmt = For::make(x, 0, 10, yLoop);

  /*
   * for (int x = 0; x < 10; x++) {
   *   for (int y = 0; y < 10; y++) {
   *     for (int z = 0; z < 10; z++) {
   *       C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]);
   *     }
   *   }
   * }
   */

  // A and B can be registerized because each can be hoisted until it reaches
  // a loop whose var it depends on; C depends on every loop var.

  stmt = registerize(stmt);

  /*
   * for (int x = 0; x < 10; x++) {
   *   int A_1 = A[x];
   *   for (int y = 0; y < 10; y++) {
   *     int B_1 = B[x, y];
   *     for (int z = 0; z < 10; z++) {
   *       C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
   *     }
   *   }
   * }
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int x
# CHECK:   int A_1 = A[x];
# CHECK:   for (int y
# CHECK:     int B_1 = B[x, y];
# CHECK:       for (int z
# CHECK:         C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}

// A 3D reduction where every input has the same smaller (1D) dimensionality
// but is indexed by a different loop var.
TEST(Registerizer, RegisterizerMultiDim3DReduction2) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  BufHandle c("C", {10}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);

  ExprHandle product = Mul::make(Load::make(b, {y}), Load::make(a, {x}));
  StmtPtr reduce = Store::make(c, {x}, Add::make(Load::make(c, {x}), product));
  StmtPtr zLoop = For::make(z, 0, 10, reduce);
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  StmtPtr yLoop = For::make(y, 0, 10, zLoop);
  StmtPtr stmt = For::make(x, 0, 10, yLoop);

  /*
   * for (int x = 0; x < 10; x++) {
   *   for (int y = 0; y < 10; y++) {
   *     for (int z = 0; z < 10; z++) {
   *       C[x] = (C[x]) + (B[y]) * (A[x]);
   *     }
   *   }
   * }
   */

  // Every access can be registerized: A and C depend only on x, so they are
  // hoisted to the outer loop, while B depends on y and can only be raised to
  // the y loop.

  stmt = registerize(stmt);

  /*
   * for (int x = 0; x < 10; x++) {
   *   int A_1 = A[x];
   *   int C_1 = C[x];
   *   for (int y = 0; y < 10; y++) {
   *     int B_1 = B[y];
   *     for (int z = 0; z < 10; z++) {
   *       C_1 = A_1 * B_1 + C_1;
   *     }
   *   }
   *   C[x] = C_1;
   * }
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected_ir =
      R"IR(
# CHECK: for (int x
# CHECK:   int A_1 = A[x];
# CHECK:   int C_1 = C[x];
# CHECK:   for (int y
# CHECK:     int B_1 = B[y];
# CHECK:       for (int z
# CHECK:         C_1 = A_1 * B_1 + C_1;
# CHECK:       }
# CHECK:     }
# CHECK:   C[x] = C_1;
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(expected_ir, printed.str());
}

} // namespace jit
} // namespace torch
