#include <ATen/core/boxing/impl/test_helpers.h>
#include <gtest/gtest.h>

#include <ATen/core/op_registration/op_registration.h>
#include <torch/torch.h>

#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/functions/basic_ops.h>

#include <test/cpp/api/support.h>

using namespace torch::autograd;
using namespace torch::test;

// Tensor comparison helpers for this file: both check `torch::allclose(a, b)`
// (called without explicit tolerances). The ASSERT_ form wraps gtest's
// ASSERT_TRUE, the EXPECT_ form wraps EXPECT_TRUE.
#define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b)))
#define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b)))

// Renders the autograd graph reachable from `node` as a nested string.
//
// A null node is rendered as "None". Any other node is rendered as its
// `name()` followed by a parenthesised concatenation of the descriptions of
// the functions its outgoing edges point to, visited recursively in edge
// order (e.g. "A(B(None)C())").
std::string graph_desc(std::shared_ptr<Node> node) {
  if (!node) {
    return "None";
  }
  auto result = node->name() + "(";
  // Iterate the edge list in place: it is only read here, so copying the
  // whole container into a local first would be wasted work.
  for (const auto& edge : node->next_edges()) {
    result += graph_desc(edge.function);
  }
  result += ")";
  return result;
}

// Reference function shared by the gradient tests below:
//   f(x, y) = x + 2 * y + x * y
Variable simple_fn(const Variable& x, const Variable& y) {
  const Variable linear_part = x + 2 * y;
  const Variable product_part = x * y;
  return linear_part + product_part;
}

TEST(AutogradAPITests, RegisterHookVoidReturnAcceptsUndefinedTensor) {
  // Register a hook that returns nothing on a scalar leaf, then run backward
  // through an UndefinedGrad node. The test passes if nothing throws.
  auto leaf = at::zeros({}, at::kCPU);
  leaf.requires_grad_();
  leaf.register_hook([](at::TensorBase /*grad*/) {});
  auto outputs = torch::autograd::UndefinedGrad().apply({leaf});
  outputs[0].backward();
}

TEST(AutogradAPITests, RegisterHookTensorReturnAcceptsUndefinedTensor) {
  // Same scenario as the void-returning variant, but the hook hands the
  // gradient it received straight back. The test passes if nothing throws.
  auto leaf = at::zeros({}, at::kCPU);
  leaf.requires_grad_();
  leaf.register_hook([](at::Tensor grad) -> at::Tensor { return grad; });
  auto outputs = torch::autograd::UndefinedGrad().apply({leaf});
  outputs[0].backward();
}

TEST(AutogradAPITests, BackwardSimpleTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());

  // Reduce to a scalar and call backward with an empty grad list.
  auto loss = simple_fn(x, y).sum();
  backward({loss}, {});

  // For f = x + 2y + xy:  df/dx = 1 + y  and  df/dy = 2 + x.
  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2);
}

TEST(AutogradAPITests, BackwardTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto out = simple_fn(x, y);

  // First pass: third argument left defaulted, fourth set to true.
  backward({out}, {torch::ones({2, 2})}, {}, true);
  // Second pass over the same output with all trailing arguments defaulted.
  backward({out}, {torch::ones({2, 2})});

  // Gradients accumulate, so two passes give twice the single-pass values.
  ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
  ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2));
}

TEST(AutogradAPITests, GradSimpleTest) {
  // Basic use of grad(): gradients are returned, one per requested input.
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto out = simple_fn(x, y);
  auto grads = grad({out}, {x, y}, {torch::ones({2, 2})});

  // grads[0] is d(out)/dx = 1 + y, grads[1] is d(out)/dy = 2 + x.
  ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({2, 2}) * 2);
}

TEST(AutogradAPITests, GradTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto out = simple_fn(x, y);
  out.backward(torch::ones({2, 2}), false, true);

  // First-order gradients of f = x + 2y + xy.
  Variable expected_x_grad = y + torch::ones({2, 2});
  Variable expected_y_grad = x + torch::ones({2, 2}) * 2;
  ASSERT_VARIABLE_EQ(x.grad(), expected_x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), expected_y_grad);

  // Differentiate a combination of the first-order gradients w.r.t. x.
  // 2 * (1 + y) + (2 + x) has derivative 1 with respect to x.
  Variable combined = 2 * x.grad() + y.grad();
  auto second_order = grad({combined}, {x}, {torch::ones({2, 2})}, {}, true);

  ASSERT_VARIABLE_EQ(second_order[0], torch::ones({2, 2}));
  // The grad() call must leave the accumulated .grad() values unchanged.
  ASSERT_VARIABLE_EQ(x.grad(), expected_x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), expected_y_grad);
}

TEST(AutogradAPITests, GradNonLeafTest) {
  Variable x_init = torch::randn({2, 2}, torch::requires_grad());
  Variable x = x_init;
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable grad_output = torch::ones({2, 2});

  // Five steps in the direction of the gradient. After the first step `x` is
  // the result of an expression rather than the original leaf.
  for (int step = 0; step < 5; ++step) {
    auto out = simple_fn(x, y);
    auto step_grads = grad({out}, {x}, {grad_output}, {}, true);

    Variable expected_grad_x = y + torch::ones({2, 2});
    ASSERT_VARIABLE_EQ(step_grads[0], expected_grad_x);
    // grad() must not populate .grad() on either tensor.
    ASSERT_FALSE(x.grad().defined());
    ASSERT_FALSE(y.grad().defined());
    x = x + 0.05 * step_grads[0];
  }

  // Moving along the gradient must have increased the summed function value.
  float val_init = simple_fn(x_init, y).sum().item().toFloat();
  float val_final = simple_fn(x, y).sum().item().toFloat();
  ASSERT_TRUE(val_final > val_init);

  // Backward through the accumulated updates reaches both leaves.
  x.backward(grad_output, false, true);
  ASSERT_TRUE(x_init.grad().defined());
  ASSERT_TRUE(y.grad().defined());
}

TEST(AutogradAPITests, GradUnreachableTest) {
  Variable x = torch::ones({1}, torch::requires_grad());
  Variable y = torch::ones({1}, torch::requires_grad());

  Variable z = x * 2;
  // `w` is never read; NOTE(review): presumably `y * 2` is what gives `y` a
  // grad accumulator for the first case below -- confirm before removing.
  Variable w = y * 2;

  // `y` does not contribute to `x * 2`; with the last argument (allow_unused,
  // per the comment further down) set to true its gradient comes back
  // undefined. `x` holds ones, so `x * 2` equals the gradient value 2.
  auto grads = grad({x * 2}, {x, y}, {}, {}, false, true);
  ASSERT_VARIABLE_EQ(grads[0], x * 2);
  ASSERT_FALSE(grads[1].defined());

  // This is slightly different than the case above, because z doesn't even
  // have a grad accumulator allocated.
  z = torch::ones({1}, torch::requires_grad());
  grads = grad({x * 2}, {x, z}, {}, {}, false, true);

  ASSERT_VARIABLE_EQ(grads[0], x * 2);
  ASSERT_FALSE(grads[1].defined());

  // allow_unused=False, but grads contains None inside, should throw
  ASSERT_THROWS_WITH(
      grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True");
}

TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) {
  // Regression test for #39784: nodes must not be executed when the requested
  // input cannot be reached from the output.
  struct MyFunction : public Function<MyFunction> {
    // Identity in the forward direction.
    static Variable forward(AutogradContext* /*ctx*/, Variable var) {
      return var;
    }

    // Reaching this backward at all is the failure this test guards against.
    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      ADD_FAILURE() << "This node should not be executed!";
      return grad_output;
    }
  };

  auto reachable_leaf = torch::randn(1, torch::requires_grad());
  auto plain_tensor = torch::randn(1);
  auto output = MyFunction::apply(reachable_leaf + plain_tensor);

  // `unrelated_leaf` takes no part in computing `output`.
  auto unrelated_leaf = torch::randn(1, torch::requires_grad());
  auto grads = torch::autograd::grad(
      {output}, {unrelated_leaf}, {}, {}, false, true);
  ASSERT_FALSE(grads[0].defined());
}

TEST(AutogradAPITests, EmptyInput) {
  // grad() with an empty `inputs` list is expected to be rejected.
  Variable leaf = torch::ones({1}, torch::requires_grad());
  ASSERT_THROWS_WITH(
      grad({leaf * 2}, /*inputs=*/{}, {leaf}),
      "grad requires non-empty inputs.");
}

TEST(AutogradAPITests, RetainGrad) {
  auto input = torch::rand({1, 3}, torch::requires_grad());
  auto hidden = input * 3;
  auto out = (hidden * hidden).sum();

  {
    // Reading .grad() on a non-leaf without retain_grad() must warn.
    WarningCapture warnings;
    ASSERT_FALSE(hidden.grad().defined());
    ASSERT_TRUE(warnings.str().find("is not a leaf") != std::string::npos);
  }
  // Calling retain_grad() repeatedly must be allowed.
  hidden.retain_grad();
  hidden.retain_grad();
  {
    // Once retain_grad() has been requested, reading .grad() on the non-leaf
    // must no longer produce the warning.
    WarningCapture warnings;
    ASSERT_FALSE(hidden.grad().defined());
    ASSERT_FALSE(warnings.str().find("is not a leaf") != std::string::npos);
  }

  // The retained gradient accumulates across backward passes:
  // d(sum(h * h))/dh = 2h after one pass, 4h after two.
  // NOLINTNEXTLINE(bugprone-argument-comment)
  out.backward({}, /*keep_graph=*/true);
  ASSERT_VARIABLE_EQ(hidden * 2, hidden.grad());
  // NOLINTNEXTLINE(bugprone-argument-comment)
  out.backward({}, /*keep_graph=*/true);
  ASSERT_VARIABLE_EQ(hidden * 4, hidden.grad());

  {
    // Reset the leaf gradient without recording the in-place op.
    torch::NoGradGuard no_grad;
    input.grad().zero_();
  }
  // On a leaf, retain_grad() should be a no-op.
  input.retain_grad();
  input.retain_grad();
  out.backward();
  // out = sum((3 * input)^2) = 9 * sum(input^2), so d(out)/d(input) is
  // 18 * input.
  ASSERT_VARIABLE_EQ(input * 18, input.grad());
}

TEST(AutogradAPITests, AnomalyMode) {
  // With anomaly detection on, a failing backward is expected to emit the
  // forward traceback as a warning and then throw.
  torch::autograd::DetectAnomalyGuard detect_anomaly;
  {
    WarningCapture warnings;
    auto x = torch::tensor({5.0}, torch::requires_grad());
    auto y = x * x;
    auto z = y * y;
    // Modify `y` in place after it has been used to compute `z`.
    y += 1;
    ASSERT_THROWS_WITH(z.backward(), "inplace");
    ASSERT_TRUE(
        warnings.str().find("Traceback of forward") != std::string::npos);
  }

  // Runs a double backward through x.pow(1.5) at x == 0. When `should_throw`
  // is true the second grad() call is expected to fail with "returned nan"
  // and to emit exactly two traceback warnings; otherwise it is expected to
  // complete without throwing.
  auto double_backward_produce_nan = [](bool should_throw) {
    auto x = torch::tensor({0.0}, torch::requires_grad());
    auto y = x.pow(1.5);
    auto gr =
        // NOLINTNEXTLINE(bugprone-argument-comment)
        grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
    if (should_throw) {
      WarningCapture warnings;
      // The statement is passed without a trailing `;`, matching every other
      // ASSERT_THROWS_WITH use in this file.
      ASSERT_THROWS_WITH(
          grad({gr[0]}, {x}, {torch::tensor({0.0})}), "returned nan");
      auto msgs = warnings.messages();
      ASSERT_EQ(msgs.size(), 2);
      ASSERT_TRUE(
          msgs[0].find("Traceback of forward call that caused the error") !=
          std::string::npos);
      ASSERT_TRUE(
          msgs[1].find(
              "Traceback of forward call that induced the previous calculation") !=
          std::string::npos);
    } else {
      grad({gr[0]}, {x}, {torch::tensor({0.0})});
    }
  };

  double_backward_produce_nan(true);
  {
    // Nested guards get their own names so they do not shadow the outer one.
    torch::autograd::DetectAnomalyGuard no_nan_check_guard(
        /*check_nan=*/false);
    double_backward_produce_nan(false);
    {
      torch::autograd::DetectAnomalyGuard nan_check_guard(/*check_nan=*/true);
      double_backward_produce_nan(true);
    }
  }
  // After the nested guards are gone the outer guard's behaviour applies
  // again.
  double_backward_produce_nan(true);
}

TEST(CustomAutogradTest, CustomFunctionReturnInputAsIsAndSavesIt) {
  // A custom function that saves both inputs for backward and returns its
  // first input unchanged. The test passes if apply() completes.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        Variable var2) {
      ctx->save_for_backward({var1, var2});
      // Return the input as-is. This used to be spelled
      // `return var1 * var2, var1;`, where the comma operator discarded the
      // product and yielded `var1`; the intent is now explicit.
      return var1;
    }

    // Not exercised here: the test never runs backward.
    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list /*grad_output*/) {
      return {};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  MyFunction::apply(x, y);
}

TEST(CustomAutogradTest, CustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    // Computes var1 + mul * var2 + var1 * var2, stashing `mul` and both
    // tensors on the context for backward.
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        int mul,
        Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    // Returns one gradient slot per forward argument; the slot for the
    // integer `mul` is an undefined Variable.
    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const int mul = ctx->saved_data["mul"].toInt();
      const auto saved = ctx->get_saved_variables();
      const auto& var1 = saved[0];
      const auto& var2 = saved[1];
      const auto& grad = grad_output[0];
      return {grad + grad * var2, Variable(), grad * mul + grad * var1};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto result = MyFunction::apply(x, 2, y);
  auto upstream = torch::ones({}, torch::requires_grad());
  result.sum().backward(upstream, false, true);

  // With mul == 2: d/dx = 1 + y and d/dy = 2 + x.
  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}

TEST(CustomAutogradTest, CustomFunctionWithTensorList) {
  struct MyFunction : public Function<MyFunction> {
    // Takes its operands as a tensor list, saves every element for backward
    // and computes t0 + t1 + t0 * t1 from the first two.
    static Variable forward(AutogradContext* ctx, at::TensorList tensors) {
      torch::autograd::variable_list to_save(tensors.begin(), tensors.end());
      ctx->save_for_backward(to_save);
      return tensors[0] + tensors[1] + tensors[0] * tensors[1];
    }

    // One gradient per list element: g + g * other.
    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const auto saved = ctx->get_saved_variables();
      const auto& first = saved[0];
      const auto& second = saved[1];
      const auto& grad = grad_output[0];
      return {grad + grad * second, grad + grad * first};
    }
  };

  at::Tensor x = torch::randn({5, 5}, torch::requires_grad());
  at::Tensor y = torch::randn({5, 5}, torch::requires_grad());
  // `variables` owns the storage that the TensorList view refers to.
  torch::autograd::variable_list variables = {x, y};
  at::TensorList tensors = variables;
  auto result = MyFunction::apply(tensors);
  auto upstream = torch::ones({}, torch::requires_grad());
  result.sum().backward(upstream, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}));
}

TEST(CustomAutogradTest, GraphTaskTrimEdges) {
  struct MyFunction : public Function<MyFunction> {
    // Computes var1 + mul * var2 + var1 * var2. The two bool flags are not
    // used in the computation; they record what backward should observe from
    // `needs_input_grad` for each tensor input.
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        Variable var2,
        int mul,
        bool needs_input1_grad,
        bool needs_input2_grad) {
      ctx->saved_data["needs_input1_grad"] = needs_input1_grad;
      ctx->saved_data["needs_input2_grad"] = needs_input2_grad;

      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      // `needs_input_grad` can only be checked from inside backward, so the
      // expectations recorded in forward are verified here.
      const bool expect_grad1 = ctx->saved_data["needs_input1_grad"].toBool();
      const bool expect_grad2 = ctx->saved_data["needs_input2_grad"].toBool();
      const IndexRange range1 = {0, 1};
      const IndexRange range2 = {1, 2};
      EXPECT_EQ(ctx->needs_input_grad(0), expect_grad1);
      EXPECT_EQ(ctx->needs_input_grad(1), expect_grad2);
      EXPECT_EQ(ctx->needs_input_grad({range1}), expect_grad1);
      EXPECT_EQ(ctx->needs_input_grad({range2}), expect_grad2);
      EXPECT_EQ(
          ctx->needs_input_grad({range1, range2}),
          expect_grad1 || expect_grad2);

      // Compute only the gradients that were actually requested.
      const int mul = ctx->saved_data["mul"].toInt();
      const auto saved = ctx->get_saved_variables();
      const auto& var1 = saved[0];
      const auto& var2 = saved[1];
      const auto& grad = grad_output[0];

      Variable grad_var1;
      Variable grad_var2;
      if (ctx->needs_input_grad(0)) {
        grad_var1 = grad + grad * var2;
      }
      if (ctx->needs_input_grad(1)) {
        grad_var2 = grad * mul + grad * var1;
      }
      // One slot per forward argument; the three non-tensor arguments get
      // undefined Variables.
      return {grad_var1, grad_var2, Variable(), Variable(), Variable()};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto go = torch::ones_like(x);

  {
    // Only the gradient w.r.t. x is requested.
    Variable out = MyFunction::apply(
        x,
        y,
        2,
        /* needs_input1_grad= */ true,
        /* needs_input2_grad= */ false);
    auto grad_x = torch::autograd::grad({out}, {x}, {go})[0];
    ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5}));
  }

  {
    // Only the gradient w.r.t. y is requested.
    Variable out = MyFunction::apply(
        x,
        y,
        2,
        /* needs_input1_grad= */ false,
        /* needs_input2_grad= */ true);
    auto grad_y = torch::autograd::grad({out}, {y}, {go})[0];
    ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2);
  }

  {
    // Both gradients are requested.
    Variable out = MyFunction::apply(
        x,
        y,
        2,
        /* needs_input1_grad= */ true,
        /* needs_input2_grad= */ true);
    auto grads = torch::autograd::grad({out}, {x, y}, {go});
    ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({5, 5}));
    ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({5, 5}) * 2);
  }
}

TEST(CustomAutogradTest, FunctionReturnsInput) {
  struct MyFunction : public Function<MyFunction> {
    // Hands the input back unchanged.
    static Variable forward(AutogradContext* /*ctx*/, Variable var1) {
      return var1;
    }

    // Doubles the incoming gradient so the custom backward is observable.
    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      return {grad_output[0] * 2};
    }
  };

  Variable x(torch::ones(1, torch::requires_grad()));
  auto passthrough = MyFunction::apply(x);
  passthrough.backward(torch::ones(1), true, true);
  // A gradient of ones doubled by the custom backward gives 2.
  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.));
}

TEST(CustomAutogradTest, FunctionReturnsUndefined) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* /*ctx*/, Variable var) {
      return var * 2;
    }

    // Always reports an undefined gradient for its single input.
    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list /*grad_output*/) {
      return {at::Tensor()};
    }
  };

  auto x = torch::ones(1, torch::requires_grad());

  // Directly on the leaf.
  MyFunction::apply(x).backward();
  ASSERT_FALSE(x.grad().defined());

  // With an op between the leaf and the custom function.
  MyFunction::apply(x.pow(2)).backward();
  ASSERT_FALSE(x.grad().defined());

  // With an op after the custom function.
  MyFunction::apply(x).sum().backward();
  ASSERT_FALSE(x.grad().defined());

  // Through grad(), with its last argument set to true.
  auto grads = torch::autograd::grad(
      {MyFunction::apply(x)}, {x}, {}, false, false, true);
  ASSERT_FALSE(grads[0].defined());
}

TEST(CustomAutogradTest, MaterializeGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* /*ctx*/, Variable var) {
      return var;
    }

    // Without any opt-out in forward, the gradient arriving here is expected
    // to compare equal to zeros(1).
    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1));
      return grad_output;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  // Route the custom function's output through an UndefinedGrad node.
  auto routed = UndefinedGrad().apply({MyFunction::apply(x)});
  routed[0].backward();
}

TEST(CustomAutogradTest, DontMaterializeGrads) {
  struct MyFunction : public Function<MyFunction> {
    // Opts out of gradient materialization before returning the input.
    static Variable forward(AutogradContext* ctx, Variable var) {
      ctx->set_materialize_grads(false);
      return var;
    }

    // With materialization disabled the incoming gradient is expected to
    // arrive undefined.
    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      EXPECT_FALSE(grad_output[0].defined());
      return grad_output;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  // Route the custom function's output through an UndefinedGrad node.
  auto routed = UndefinedGrad().apply({MyFunction::apply(x)});
  routed[0].backward();
}

TEST(CustomAutogradTest, NoGradCustomFunction) {
  // A custom Function applied under NoGradGuard must produce an output that
  // does not require grad, even when its input does.
  struct MyOp : public Function<MyOp> {
    static Variable forward(AutogradContext* /*ctx*/, Variable x) {
      return x + 1;
    }

    static variable_list backward(AutogradContext* /*ctx*/, variable_list dy) {
      return dy;
    }
  };

  auto input = torch::ones({5, 5}, torch::requires_grad());
  {
    at::NoGradGuard no_grad;
    auto output = MyOp::apply(input);
    ASSERT_FALSE(output.requires_grad());
  }
}

TEST(CustomAutogradTest, MarkDirty) {
  struct MyFunction : public Function<MyFunction> {
    // Overwrites the first element through the raw data pointer, then tells
    // autograd about the in-place change via mark_dirty.
    static Variable forward(AutogradContext* ctx, Variable v) {
      v.data_ptr<float>()[0] = 2;
      ctx->mark_dirty({v});
      return v;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      return {(grad_output[0] * 2.0)};
    }
  };

  // Clone because leaves that require grad may not be modified in place.
  auto x = torch::randn({5, 5}, torch::requires_grad()).clone();
  const auto version_before = x._version();
  auto out = MyFunction::apply(x);
  const auto version_after = x._version();
  // Applying the function must have bumped the version counter at least once.
  ASSERT_TRUE(version_after >= (version_before + 1));
  out.sum().backward();
}

TEST(CustomAutogradTest, MarkNonDifferentiable) {
  struct MyFunction : public Function<MyFunction> {
    // Produces the mask `v > 0` and flags it as non-differentiable.
    static Variable forward(AutogradContext* ctx, Variable v) {
      Variable mask = v > 0;
      ctx->mark_non_differentiable({mask});
      return mask;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      return {(grad_output[0] * 0.0)};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto mask = MyFunction::apply(x);
  ASSERT_FALSE(mask.requires_grad());
  // Using the mask in a differentiable op must still allow backward to run.
  auto masked = x.masked_fill(mask, 0);
  masked.sum().backward();
}

TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
  struct MyFunction : public Function<MyFunction> {
    // Returns {input + 1, input + 2}; only the first is flagged as
    // non-differentiable.
    static variable_list forward(AutogradContext* ctx, Variable input) {
      Variable a = input + 1;
      Variable b = input + 2;
      ctx->mark_non_differentiable({a});
      return {a, b};
    }

    // Expects a zero gradient for the non-differentiable output and ones for
    // the other, and forwards only the latter to the input.
    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      const Variable& grad_a = grad_output[0];
      const Variable& grad_b = grad_output[1];
      EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5}));
      EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5}));
      return {grad_b};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto outputs = MyFunction::apply(x);

  ASSERT_FALSE(outputs[0].requires_grad());
  ASSERT_TRUE(outputs[1].requires_grad());
  outputs[1].sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
  struct MyFunction : public Function<MyFunction> {
    // Returns a clone of the input, flagged as non-differentiable.
    static Variable forward(AutogradContext* ctx, Variable input) {
      auto copy = input.clone();
      ctx->mark_non_differentiable({copy});
      return copy;
    }

    // Returns no gradients at all.
    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list /*grad_outputs*/) {
      return {};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto non_diff = MyFunction::apply(x * x);
  // Backward through a product that involves the non-differentiable output
  // must complete without throwing.
  (non_diff * x).sum().backward();
}

TEST(CustomAutogradTest, ReturnLeafInplace) {
  struct Inplace : public Function<Inplace> {
    // Adds `b` into `a` in place (after marking `a` dirty) and returns
    // {a, b + 2}.
    static variable_list forward(AutogradContext* ctx, Variable a, Variable b) {
      ctx->mark_dirty({a});
      a.add_(b);
      return {a, b + 2};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      return {grad_output[0], grad_output[0] + grad_output[1]};
    }
  };

  // `x` is created without requires_grad; `y` requires grad.
  Variable x = torch::randn({5, 5});
  Variable y = torch::randn({5, 5}, torch::requires_grad());

  auto outputs = Inplace::apply(x, y);
  auto& modified = outputs[0];
  // The first output must hold the same values as the modified `x` and must
  // now require grad.
  ASSERT_TRUE(torch::equal(modified, x));
  ASSERT_TRUE(modified.requires_grad());
  modified.sum().backward();
  ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, ReturnDuplicateInplace) {
  struct DoubleInplace : public Function<DoubleInplace> {
    // Doubles `x` in place and returns the same tensor twice.
    static variable_list forward(AutogradContext* ctx, Variable x) {
      x.mul_(2);
      ctx->mark_dirty({x});
      return {x, x};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_outputs) {
      return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
    }
  };

  auto leaf = torch::randn({5, 5}, torch::requires_grad());

  // In-place modification of a leaf that requires grad must be rejected.
  ASSERT_THROWS_WITH(
      DoubleInplace::apply(leaf), "leaf Variable that requires grad");
  // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one
  // output");

  // On a clone (a non-leaf) it succeeds and both outputs compare equal.
  auto outputs = DoubleInplace::apply(leaf.clone());
  ASSERT_TRUE(torch::equal(outputs[0], outputs[1]));
}

TEST(CustomAutogradTest, ReturnDuplicate) {
  struct DoubleDuplicate : public Function<DoubleDuplicate> {
    // Computes x * 2 once and returns that same tensor in both output slots.
    static variable_list forward(AutogradContext* /*ctx*/, Variable x) {
      auto doubled = x * 2;
      return {doubled, doubled};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_outputs) {
      return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
    }
  };

  auto input = torch::randn({5, 5}, torch::requires_grad());
  auto outputs = DoubleDuplicate::apply(input);
  ASSERT_TRUE(torch::equal(outputs[0], outputs[1]));
}

TEST(CustomAutogradTest, SaveEmptyForBackward) {
  struct MyFunction : public Function<MyFunction> {
    // Saves the input between two undefined Variables and returns input^2.
    static Variable forward(AutogradContext* ctx, Variable input) {
      ctx->save_for_backward({Variable(), input, Variable()});
      return input * input;
    }

    // The undefined entries must come back undefined, in their original
    // positions; the gradient is 2 * input * grad.
    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const auto saved = ctx->get_saved_variables();
      EXPECT_FALSE(saved[0].defined());
      EXPECT_FALSE(saved[2].defined());
      return {saved[1] * 2 * grad_output[0]};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  auto squared = MyFunction::apply(x);
  squared.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
}

TEST(CustomAutogradTest, InvalidGradients) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* /*ctx*/, Variable x) {
      return x * 2;
    }

    // Deliberately returns a gradient of size 10, which cannot match the
    // {5, 5} input used below.
    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list /*grad_outputs*/) {
      return {
          torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
    }
  };

  auto input1 =
      torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
  // The mismatching gradient must be rejected with a shape error.
  ASSERT_THROWS_WITH(
      MyFunction::apply(input1).sum().backward(), "expected shape");
  // An unused `input2` tensor (kDouble, size 10) used to be created here; it
  // was never passed to anything, so it has been dropped.
}

TEST(CustomAutogradTest, NoGradInput) {
  struct MyFunction : public Function<MyFunction> {
    // Identity in both directions.
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(
        AutogradContext*,
        variable_list grad_outputs) {
      return grad_outputs;
    }
  };

  Variable input = torch::randn({5, 5}, torch::requires_grad());
  Variable output;
  {
    // Apply the function while gradient recording is disabled.
    at::NoGradGuard no_grad;
    output = MyFunction::apply(input);
  }

  // The input still requires grad, but the output got no grad_fn.
  ASSERT_TRUE(input.requires_grad());
  ASSERT_FALSE(output.grad_fn());
}

TEST(CustomAutogradTest, TooManyGrads) {
  // NOTE(review): this test only declares the function type; nothing in the
  // body applies it or runs backward. Confirm whether that is intentional.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable input) {
      return input;
    }

    // Returns the incoming gradients followed by two extra undefined ones,
    // i.e. more gradients than forward has inputs.
    static variable_list backward(AutogradContext*, variable_list grad_output) {
      grad_output.push_back(Variable());
      grad_output.push_back(Variable());
      return grad_output;
    }
  };
}

TEST(CustomAutogradTest, DepNoGrad) {
  struct F1 : public Function<F1> {
    // Returns the input together with a fresh random tensor of the same size
    // that is flagged as non-differentiable.
    static variable_list forward(AutogradContext* ctx, Variable input) {
      auto noise = torch::randn(input.sizes());
      ctx->mark_non_differentiable({noise});
      return {input, noise};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      return {grad_output[0]};
    }
  };

  struct F2 : public Function<F2> {
    // Returns `input`; the second argument takes no part in the result.
    static Variable forward(AutogradContext*, Variable input, Variable ignore) {
      return input;
    }

    // Passes the gradient to `input` and reports none for `ignore`.
    static variable_list backward(AutogradContext*, variable_list grad_output) {
      return {grad_output[0], Variable()};
    }
  };

  auto x = torch::randn(5, torch::requires_grad());
  auto f1_out = F1::apply(x);
  Variable& a = f1_out[0];
  Variable& b = f1_out[1];
  b = b + 1; // Separate F1 and F2 by another operation
  ASSERT_TRUE(a.requires_grad());
  ASSERT_FALSE(b.requires_grad());

  auto c = F2::apply(a, b);
  c.backward(torch::ones(c.sizes()), false, false);
  // The gradient must flow back to x through both identity paths unchanged.
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}

TEST(CustomAutogradTest, Reentrant) {
  static Variable y_data = torch::randn({2, 2});
  struct Reenter : public Function<Reenter> {
    // Builds a separate graph `x * y` with grad mode enabled, from fresh
    // variables over the input data and `y_data`. That graph is stored on the
    // context and only a detached copy of its output is returned.
    static Variable forward(AutogradContext* ctx, Variable input) {
      Variable inner_output;
      {
        at::AutoGradMode enable_grad(true);
        auto inner_x = make_variable(input.tensor_data(), true);
        auto inner_y = make_variable(y_data.tensor_data(), true);
        inner_output = inner_x * inner_y;

        ctx->saved_data["x"] = inner_x;
        ctx->saved_data["y"] = inner_y;
        ctx->saved_data["output_var"] = inner_output;
      }
      return inner_output.detach();
    }

    // Runs backward on the stored graph from inside this backward, then
    // scales the stored `x` gradient by the incoming gradient.
    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      {
        at::AutoGradMode enable_grad(true);
        auto stored_output = ctx->saved_data["output_var"].toTensor();
        stored_output.sum().backward();
      }
      return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]};
    }
  };

  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto result = Reenter::apply(x);
  result.sum().backward();
  // d(sum(x * y))/dx = y.
  ASSERT_VARIABLE_EQ(x.grad(), y_data);
}

// NOTE: If this fails for apparently unrelated reasons in TSAN be aware of
// the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950
TEST(CustomAutogradTest, DeepReentrant) {
  // Each forward() stores `x - 1`; each backward() re-applies the function to
  // that stored value and runs a nested backward until the value reaches zero.
  struct DeepReenter : public Function<DeepReenter> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      auto& slot = ctx->saved_data["x"];
      {
        at::AutoGradMode grad_on(true);
        slot = make_variable(x.tensor_data(), true) - 1;
      }
      return slot.toTensor().detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const bool keep_recursing =
          at::native::is_nonzero(ctx->saved_data["x"].toTensor());
      if (keep_recursing) {
        // Nested backward on a fresh application of this same function.
        at::AutoGradMode grad_on(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
      }
      // The incoming gradient is passed through unchanged in either case.
      return grad_output;
    }
  };

  // This should not stack overflow
  auto start = torch::tensor(
      {8193}, torch::dtype(torch::kFloat).requires_grad(true));
  DeepReenter::apply(start).sum().backward();
}

TEST(CustomAutogradTest, ReentrantPriority) {
  // Execution log of backward() calls: 0 is pushed by MyFunction, 1 by every
  // (nested) Reenter invocation.
  static std::vector<int> order;
  // `order` is static, so reset it before use. The ASSERT_* macros below
  // return early on failure, which would otherwise leave stale entries behind
  // and make every later run of this test (e.g. under --gtest_repeat) fail.
  order.clear();

  // Identity function that only records when its backward runs.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad) {
      order.push_back(0);
      return grad;
    }
  };

  // Stores `x - 1` in forward(); backward() re-applies itself to the stored
  // value and runs a nested backward until that value is zero.
  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      order.push_back(1);
      if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  auto a = MyFunction::apply(
      torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto b = Reenter::apply(
      torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto v = a * b;
  v.backward();

  // All the reentrant tasks should be prioritized over the MyFunction backward
  // task. Starting from 9, the saved values are 8, 7, ..., 0, i.e. nine Reenter
  // backward calls, followed by the single MyFunction call.
  ASSERT_EQ(order.size(), 10u);
  ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
  ASSERT_EQ(order.back(), 0);
  // Clear the static variable in case the test gets executed in a loop.
  order.clear();
}

// Exercises register_hook / remove_hook on tensors: hooks that only observe
// the gradient (void return) and hooks that replace it (tensor return).
TEST(CustomAutogradTest, Hooks) {
  Variable x = torch::ones({5, 5}, torch::requires_grad());
  Variable y = torch::ones({5, 5}) * 4;
  y.set_requires_grad(true);

  // Sum of the `inc` values of all observer hooks that have fired so far.
  int counter = 0;

  std::function<void(int, Variable)> bw_hook(
      [&counter](int inc, Variable grad) { counter += inc; });

  Variable z = x * x + x * 2 + x * y + y;
  // The hook on x adds 0, so it never changes the counter.
  x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); });
  auto hook_1 =
      z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); });
  z.backward(torch::ones({5, 5}), true, true);
  // Only hook_1 contributes: 0 + 1.
  ASSERT_EQ(counter, 1);

  auto hook_2 =
      z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); });
  z.backward(torch::ones({5, 5}), true, true);
  // hook_1 and hook_2 both fire: 1 + (1 + 2).
  ASSERT_EQ(counter, 4);

  z.remove_hook(hook_2);
  z.backward(torch::ones({5, 5}), true, true);
  // hook_2 was removed, so only hook_1 fires: 4 + 1.
  ASSERT_EQ(counter, 5);

  // Hook that returns a tensor: the returned value (2 * grad) replaces the
  // gradient, as the expected values below show.
  std::function<Variable(Variable)> bw_hook_modify(
      [](Variable grad) { return grad.mul(2); });

  z.remove_hook(hook_1);
  z.register_hook(bw_hook_modify);
  // Reset y's accumulated gradient from the previous backward calls.
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), true, false);
  // dz/dy = x + 1, doubled once by the hook on z.
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);

  y.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), false, false);
  // Doubled by the hook on z and again by the hook on y.
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);

  // Removing a hook index that was never registered must be rejected.
  ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}

TEST(CustomAutogradTest, HooksInplace) {
  auto tensor = torch::ones({5, 5}, torch::requires_grad()).clone();

  // Registered before the in-place multiply; expects a gradient of 2s.
  int before_mul_calls = 0;
  auto before_mul_hook = [&before_mul_calls](Variable grad) {
    ++before_mul_calls;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  };

  // Registered after the in-place multiply; expects a gradient of 1s.
  int after_mul_calls = 0;
  auto after_mul_hook = [&after_mul_calls](Variable grad) {
    ++after_mul_calls;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  };

  tensor.register_hook(before_mul_hook);
  tensor.mul_(2);
  tensor.register_hook(after_mul_hook);

  (tensor + 1).sum().backward();

  // Each hook must have fired exactly once.
  ASSERT_EQ(before_mul_calls, 1);
  ASSERT_EQ(after_mul_calls, 1);
}

TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) {
  auto tensor = torch::ones({5, 5}, torch::requires_grad()).clone();

  // Registered before retain_grad() and before the in-place multiply.
  int first_calls = 0;
  auto first_hook = [&first_calls](Variable grad) {
    ++first_calls;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  };

  // Registered after retain_grad() but still before the in-place multiply.
  int second_calls = 0;
  auto second_hook = [&second_calls](Variable grad) {
    ++second_calls;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
  };

  // Registered after the in-place multiply.
  int third_calls = 0;
  auto third_hook = [&third_calls](Variable grad) {
    ++third_calls;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  };

  tensor.register_hook(first_hook);
  tensor.retain_grad();
  tensor.register_hook(second_hook);

  tensor.mul_(2);
  tensor.register_hook(third_hook);

  (tensor + 1).sum().backward();

  // Every hook fires exactly once.
  ASSERT_EQ(first_calls, 1);
  ASSERT_EQ(second_calls, 1);
  ASSERT_EQ(third_calls, 1);

  // retain_grad() is still in effect after the in-place op, and the retained
  // gradient is the one seen after the multiply (all ones).
  ASSERT_TRUE(tensor.retains_grad());
  ASSERT_VARIABLE_EQ(tensor.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) {
  auto tensor = torch::ones({5, 5}, torch::requires_grad()).clone();

  // Registered before retain_grad() and before both in-place multiplies.
  int first_calls = 0;
  auto first_hook = [&first_calls](Variable grad) {
    ++first_calls;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
  };

  // Registered after retain_grad() but before both in-place multiplies.
  int second_calls = 0;
  auto second_hook = [&second_calls](Variable grad) {
    ++second_calls;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
  };

  // Registered after both in-place multiplies.
  int third_calls = 0;
  auto third_hook = [&third_calls](Variable grad) {
    ++third_calls;
    ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
  };

  tensor.register_hook(first_hook);
  tensor.retain_grad();
  tensor.register_hook(second_hook);

  tensor.mul_(2);
  tensor.mul_(2);
  tensor.register_hook(third_hook);

  (tensor + 1).sum().backward();

  // Every hook fires exactly once.
  ASSERT_EQ(first_calls, 1);
  ASSERT_EQ(second_calls, 1);
  ASSERT_EQ(third_calls, 1);

  // retain_grad() is still in effect, and the retained gradient is the one
  // seen after both multiplies (all ones).
  ASSERT_TRUE(tensor.retains_grad());
  ASSERT_VARIABLE_EQ(tensor.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, HookNone) {
  // Passes both inputs through unchanged and returns an undefined gradient
  // for the second input from backward().
  struct NoneGradientFunction : public Function<NoneGradientFunction> {
    static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {
      return {x, y};
    }

    static variable_list backward(AutogradContext* ctx, variable_list grad) {
      return {grad[0], Variable()};
    }
  };

  bool was_called = false;

  // A hook must only ever be invoked with a defined gradient.
  auto hook = ([&was_called](Variable grad) {
    ASSERT_TRUE(grad.defined());
    was_called = true;
  });

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = torch::randn({5, 5});

  auto out = NoneGradientFunction::apply(x, y);
  // Hook the *outputs* of the custom function. Indexing the input `x` here
  // would bypass NoneGradientFunction entirely and leave `out` unused.
  Variable rx = out[0], ry = out[1];

  rx.register_hook(hook);
  ry.register_hook(hook);
  (rx + ry).sum().backward();
  ASSERT_TRUE(was_called);
}

// backward() with an explicit `inputs` list must accumulate gradients only
// into the listed tensors.
TEST(CustomAutogradTest, BackwardWithInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  // dz/dx = 2x + y
  Variable x_grad_expected = 2 * x + y;

  // Only `x` is listed, so `y` must not receive a gradient.
  z.backward(torch::ones({5, 5}), false, false, {x});

  ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
  ASSERT_FALSE(y.grad().defined());
}

// backward() must reject an explicitly provided but empty `inputs` list.
TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  ASSERT_THROWS_WITH(
      z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}),
      "cannot be empty");
}

// backward() with `inputs` containing a non-leaf tensor (`squared`) must
// accumulate into exactly the listed tensors and leave the others untouched.
TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable squared = x * x; // non-leaf intermediate
  Variable result = y * squared + x * y + y * y;

  // d(result)/dx = 2xy + y, d(result)/d(squared) = y
  Variable expected_x_grad = 2 * x * y + y;
  Variable expected_squared_grad = y;

  result.backward(
      torch::ones({5, 5}), false, false, std::vector<Variable>{x, squared});

  ASSERT_VARIABLE_EQ(x.grad(), expected_x_grad);
  ASSERT_VARIABLE_EQ(squared.grad(), expected_squared_grad);
  // `y` was not listed, so it must not have received a gradient.
  ASSERT_FALSE(y.grad().defined());
}

// Both the Tensor::backward method and the free torch::autograd::backward
// function are expected to warn when called with create_graph=true.
TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
  c10::WarningUtils::WarnAlways guard(true);

  const std::string expected_warning =
      "Using backward() with create_graph=True";

  torch::Tensor input = torch::randn({5, 5}).set_requires_grad(true);
  auto squared = input * input;

  // Method form.
  {
    WarningCapture captured;
    squared.backward(torch::ones({5, 5}), std::nullopt, true);
    const bool warned =
        captured.str().find(expected_warning) != std::string::npos;
    ASSERT_TRUE(warned);
  }

  // Free-function form.
  {
    WarningCapture captured;
    torch::autograd::backward(
        {squared}, {torch::ones({5, 5})}, std::nullopt, true);
    const bool warned =
        captured.str().find(expected_warning) != std::string::npos;
    ASSERT_TRUE(warned);
  }
}

/**
 * Tests for AutogradNotImplementedFallback
 * - Check that we created the NotImplemented kernel when inputs require grad
 *   but when no inputs require grad, we should not create this node
 * - check_inplace logic
 * - view ops
 * - TODO: Tests for debug-only checks? Don't need for now because CI doesn't
 * test non-NDEBUG builds.
 * - tensorlist input and output
 * - multiple outputs / non-tensor output
 * - rebase_history vs set_history
 */
namespace {

// In-place kernel: adds `other` into `self` and returns `self`.
torch::Tensor inplace_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  self.add_(other);
  return self;
}

// In-place kernel that mutates both arguments: first `other += self`, then
// `self += other`, returning both.
std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  other.add_(self);
  self.add_(other);
  return std::make_tuple(self, other);
}

// Returns both inputs untouched.
std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  // This is not allowed. We test below that calling into the boxed kernel
  // with this schema raises an error.
  return std::make_tuple(self, other);
}

// Returns a clone of `self` together with `other` itself.
std::tuple<torch::Tensor, torch::Tensor> non_first_view_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  // This is not allowed. We test below that calling into the boxed kernel
  // with this schema raises an error.
  return std::make_tuple(self.clone(), other);
}

// Kernel whose only output is a plain integer; both tensors are ignored.
int64_t ret_single_non_tensor(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  constexpr int64_t kFixedResult = 12;
  return kFixedResult;
}

// Returns `self + other` when `other` is present, otherwise a clone of `self`.
torch::Tensor opt_op(
    const torch::Tensor& self,
    const std::optional<at::Tensor>& other) {
  if (!other.has_value()) {
    return self.clone();
  }
  return self + *other;
}

// Plain out-of-place addition.
torch::Tensor my_custom_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return self.add(other);
}

// Returns (self - other, self + other, 12): two tensors plus a non-tensor.
std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  torch::Tensor difference = self - other;
  torch::Tensor sum = self + other;
  return std::make_tuple(std::move(difference), std::move(sum), int64_t{12});
}

// Single-argument kernel returning an alias of `self`.
torch::Tensor view_op(const torch::Tensor& self) {
  return self.alias();
}

// Returns an alias of `self`; `other` is accepted but not used.
torch::Tensor view_op_with_extra_arg(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return self.alias();
}

// Returns two separate aliases of `self`; `other` is accepted but not used.
std::vector<torch::Tensor> ret_tensor_vector_view(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  std::vector<torch::Tensor> views;
  views.reserve(2);
  views.push_back(self.alias());
  views.push_back(self.alias());
  return views;
}

// Returns {self + other, self - other}.
std::vector<at::Tensor> ret_tensor_vector(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  return {self + other, self - other};
}

// Returns a clone of `self` with every tensor of `other` added into it.
torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
  auto accumulated = self.clone();
  for (const auto& item : other) {
    accumulated.add_(item);
  }
  return accumulated;
}

// Defines `schema` in the `_test` library and registers three
// implementations for the operator `name`: `fn` for the CPU key,
// autogradNotImplementedFallback() for the Autograd key, and
// autogradNotImplementedInplaceOrViewFallback() for the ADInplaceOrView key.
// The macro expands to plain declarations of `m`, `m_autograd`, `m_cpu` and
// `m_inplaceorview`, so it can be used at most once per scope.
// NOTE(review): the registrations presumably last only as long as those
// local objects; confirm against MAKE_TORCH_LIBRARY in test_helpers.h.
#define REGISTER_TEST_OP(name, schema, fn)                                 \
  auto m = MAKE_TORCH_LIBRARY(_test);                                      \
  m.def(schema);                                                           \
  auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd);              \
  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);                        \
  auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView);  \
  m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn));                   \
  m_autograd.impl(                                                         \
      name, c10::DispatchKey::Autograd, autogradNotImplementedFallback()); \
  m_inplaceorview.impl(                                                    \
      name,                                                                \
      c10::DispatchKey::ADInplaceOrView,                                   \
      autogradNotImplementedInplaceOrViewFallback());

// Shared checks for a two-tensor op `op` whose result is a tensor:
// backward through it must report "is not implemented" when an input requires
// grad, and the result must have no grad_fn when no input requires grad.
template <typename F>
void assertBasicChecks(F op) {
  auto needs_grad =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto plain_lhs = torch::tensor({1.}, {torch::kFloat32});
  auto plain_rhs = torch::tensor({1.}, {torch::kFloat32});

  // If any input requires grad, backward must hit the not-implemented error.
  auto tracked_out = op(needs_grad, plain_lhs);
  ASSERT_THROWS_WITH(tracked_out.backward(), "is not implemented");

  // Should not have a grad_fn if no input requires grad.
  auto untracked_out = op(plain_lhs, plain_rhs);
  ASSERT_THROWS_WITH(
      untracked_out.backward(),
      "element 0 of tensors does not require grad and does not have a grad_fn");

  // TODO: Forward AD Tests?
}

} // namespace

// An op returning only a non-tensor value must give the same result through
// the dispatcher as the raw kernel, even when an input requires grad.
TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
  REGISTER_TEST_OP(
      "ret_single_non_tensor",
      "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int",
      ret_single_non_tensor);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_single_non_tensor", "");
  // Invoke the op through its dispatcher handle instead of calling the kernel.
  auto dispatched = [&handle](
                        const torch::Tensor& lhs, const torch::Tensor& rhs) {
    return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>(
        handle, lhs, rhs);
  };

  auto with_grad =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto without_grad = torch::tensor({1.}, {torch::kFloat32});

  ASSERT_EQ(
      dispatched(with_grad, without_grad),
      ret_single_non_tensor(with_grad, without_grad));
}

// Checks the fallback's handling of an in-place op: leaf protection, result
// value, and behaviour on views created with and without grad mode.
TEST(TestAutogradNotImplementedFallback, InplaceOp) {
  REGISTER_TEST_OP(
      "inplace_op",
      "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)",
      inplace_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
  // Invoke the op through its dispatcher handle instead of calling the kernel.
  auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto b = torch::tensor({1.}, {torch::kFloat32});

  // Check in-place
  ASSERT_THROWS_WITH(
      op(a, b),
      "a leaf Variable that requires grad is being used in an in-place operation");
  op(b, a);
  a = a.clone();
  b = b.clone();
  // Compute the reference result on an independent, detached copy *before*
  // running the op. The op returns its (mutated) first argument, so comparing
  // its result against inplace_op(a, b) afterwards would compare `a` with
  // itself and always pass.
  const auto expected = inplace_op(a.detach().clone(), b.detach());
  auto c = op(a, b);
  ASSERT_TRUE(torch::allclose(c, expected));

  // Test in-place on view
  auto base =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto view = base.view(-1);
  auto t = torch::tensor({1.}, {torch::kFloat32});

  torch::Tensor v_nograd;
  {
    // Creating and modifying the view while grad mode is off is accepted.
    c10::NoGradGuard guard;
    v_nograd = base.view(-1);
    op(v_nograd, t);
  }

  // Modifying that same view once grad mode is back on must be rejected.
  ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
  // The op returns the very tensor it was given.
  ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());
  ASSERT_THAT(
      op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
}

// An op that mutates both of its arguments: leaf protection must apply to
// either position, and both version counters must change on success.
TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
  REGISTER_TEST_OP(
      "two_arg_inplace_op",
      "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))",
      two_arg_inplace_op);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::two_arg_inplace_op", "");
  // Invoke the op through its dispatcher handle instead of calling the kernel.
  auto dispatched = [&handle](
                        const torch::Tensor& lhs, const torch::Tensor& rhs) {
    using Pair = std::tuple<torch::Tensor, torch::Tensor>;
    return callOpUnboxed<Pair, const torch::Tensor&, const torch::Tensor&>(
        handle, lhs, rhs);
  };
  auto grad_leaf =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto plain = torch::tensor({1.}, {torch::kFloat32});

  // Both are modified in-place, so the grad-requiring leaf is rejected in
  // either argument position.
  ASSERT_THROWS_WITH(
      dispatched(grad_leaf, plain),
      "a leaf Variable that requires grad is being used in an in-place operation");
  ASSERT_THROWS_WITH(
      dispatched(plain, grad_leaf),
      "a leaf Variable that requires grad is being used in an in-place operation");

  // Clones of grad-requiring leaves are not leaves themselves.
  auto first =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto second =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();

  const auto first_version_before = first._version();
  const auto second_version_before = second._version();
  dispatched(first, second);
  ASSERT_NE(first._version(), first_version_before);
  ASSERT_NE(second._version(), second_version_before);
}

// An op with an optional tensor argument must match the raw kernel both when
// the optional is populated and when it is empty.
TEST(TestAutogradNotImplementedFallback, OptOp) {
  REGISTER_TEST_OP(
      "opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
  auto handle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
  // Invoke the op through its dispatcher handle instead of calling the kernel.
  auto dispatched = [&handle](
                        const torch::Tensor& self,
                        const std::optional<torch::Tensor>& maybe_other) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const std::optional<torch::Tensor>&>(handle, self, maybe_other);
  };

  auto with_grad =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto without_grad = torch::tensor({1.}, {torch::kFloat32});

  // Optional argument present.
  ASSERT_TRUE(torch::allclose(
      dispatched(with_grad, without_grad), opt_op(with_grad, without_grad)));
  // Optional argument absent.
  ASSERT_TRUE(torch::allclose(
      dispatched(with_grad, std::nullopt), opt_op(with_grad, std::nullopt)));
}

// Runs the shared basic checks against a plain out-of-place op.
TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
  REGISTER_TEST_OP(
      "my_custom_op",
      "_test::my_custom_op(Tensor self, Tensor other) -> Tensor",
      my_custom_op);
  auto handle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");
  // Invoke the op through its dispatcher handle instead of calling the kernel.
  auto dispatched = [&handle](
                        const torch::Tensor& lhs, const torch::Tensor& rhs) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(handle, lhs, rhs);
  };

  assertBasicChecks(dispatched);
}

// Runs the shared basic checks against the first tensor output of an op that
// returns two tensors and an integer.
TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
  REGISTER_TEST_OP(
      "ret_tuple_non_tensor",
      "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)",
      ret_tuple_non_tensor);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tuple_non_tensor", "");
  // Invoke the op through the dispatcher and keep only the first output.
  auto first_output = [&handle](
                          const torch::Tensor& lhs, const torch::Tensor& rhs) {
    using Result = std::tuple<torch::Tensor, torch::Tensor, int64_t>;
    Result outputs =
        callOpUnboxed<Result, const torch::Tensor&, const torch::Tensor&>(
            handle, lhs, rhs);
    torch::Tensor first = std::get<0>(outputs);
    return first;
  };

  assertBasicChecks(first_output);
}

// Checks that an op whose schema aliases its input (Tensor(a) -> Tensor(a))
// yields outputs registered as views of that input, and that in-place
// modification of such a view is rejected.
TEST(TestAutogradNotImplementedFallback, ViewOp) {
  REGISTER_TEST_OP(
      "view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op);
  auto opHandle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
  // Invoke the op through its dispatcher handle instead of calling the kernel.
  auto op = [&](const torch::Tensor& _1) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&>(opHandle, _1);
  };
  // Input that does not require grad: the output is still a view of it.
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto v = op(b);
  ASSERT_TRUE(v.is_view());
  ASSERT_EQ(v._base().unsafeGetTensorImpl(), b.unsafeGetTensorImpl());

  // Same check with a (non-leaf) input that requires grad.
  auto b1 =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto v1 = op(b1);
  ASSERT_TRUE(v1.is_view());
  ASSERT_EQ(v1._base().unsafeGetTensorImpl(), b1.unsafeGetTensorImpl());

  // Test inplace on view
  auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);

  // raise on rebase_history when it refreshes grad_fn
  ASSERT_THROWS_WITH(
      v1.add_(t), "which does not have a derivative implemented is forbidden");
  // base should not be aware of the views, so this is still okay
  b1.add_(t);
  // After the base was modified, even querying the view's grad_fn must raise.
  ASSERT_THROWS_WITH(
      v1.grad_fn(),
      "which does not have a derivative implemented is forbidden");
}

// A view op with an additional, non-aliased tensor argument: runs the shared
// basic checks and verifies the output is a view of the first argument.
TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
  REGISTER_TEST_OP(
      "view_op_with_extra_arg",
      "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
      view_op_with_extra_arg);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::view_op_with_extra_arg", "");
  // Invoke the op through its dispatcher handle instead of calling the kernel.
  auto dispatched = [&handle](
                        const torch::Tensor& lhs, const torch::Tensor& rhs) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(handle, lhs, rhs);
  };
  assertBasicChecks(dispatched);

  auto aliased = torch::tensor({1.}, {torch::kFloat32});
  auto extra = torch::tensor({2.}, {torch::kFloat32});
  auto view = dispatched(aliased, extra);
  ASSERT_TRUE(view.is_view());
  ASSERT_EQ(view._base().unsafeGetTensorImpl(), aliased.unsafeGetTensorImpl());
}

// An op returning a list of tensors that alias the first argument: every
// element must be a view whose base is that argument.
TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {
  REGISTER_TEST_OP(
      "ret_tensor_vector_view",
      "_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)",
      ret_tensor_vector_view);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tensor_vector_view", "");
  // Invoke the op through its dispatcher handle instead of calling the kernel.
  auto dispatched = [&handle](
                        const torch::Tensor& lhs, const torch::Tensor& rhs) {
    return callOpUnboxed<
        std::vector<at::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(handle, lhs, rhs);
  };
  auto aliased = torch::tensor({1.}, {torch::kFloat32});
  auto extra = torch::tensor({1.}, {torch::kFloat32});
  auto views = dispatched(aliased, extra);

  const auto* aliased_impl = aliased.unsafeGetTensorImpl();
  ASSERT_TRUE(views[0].is_view());
  ASSERT_EQ(views[0]._base().unsafeGetTensorImpl(), aliased_impl);
  ASSERT_TRUE(views[1].is_view());
  ASSERT_EQ(views[1]._base().unsafeGetTensorImpl(), aliased_impl);
}

// A schema with two aliased (non-write) outputs must be rejected when called.
TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
  REGISTER_TEST_OP(
      "two_pairs_of_view_op",
      "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))",
      two_pairs_of_view_op);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::two_pairs_of_view_op", "");
  // Invoke the op through its dispatcher handle instead of calling the kernel.
  auto dispatched = [&handle](
                        const torch::Tensor& lhs, const torch::Tensor& rhs) {
    using Pair = std::tuple<torch::Tensor, torch::Tensor>;
    return callOpUnboxed<Pair, const torch::Tensor&, const torch::Tensor&>(
        handle, lhs, rhs);
  };
  auto with_grad =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto without_grad = torch::tensor({1.}, {torch::kFloat32});
  ASSERT_THROWS_WITH(
      dispatched(with_grad, without_grad),
      "Expected only a single output in the operator schema to have a non-write alias annotation");
}

// A schema whose aliased input/output pair is not in the first position must
// be rejected when called.
TEST(TestAutogradNotImplementedFallback, NonFirstViewOP) {
  REGISTER_TEST_OP(
      "non_first_view_op",
      "_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))",
      non_first_view_op);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::non_first_view_op", "");
  // Invoke the op through its dispatcher handle instead of calling the kernel.
  auto dispatched = [&handle](
                        const torch::Tensor& lhs, const torch::Tensor& rhs) {
    using Pair = std::tuple<torch::Tensor, torch::Tensor>;
    return callOpUnboxed<Pair, const torch::Tensor&, const torch::Tensor&>(
        handle, lhs, rhs);
  };
  auto with_grad =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto without_grad = torch::tensor({1.}, {torch::kFloat32});
  ASSERT_THROWS_WITH(
      dispatched(with_grad, without_grad),
      "can only create view relationships between the first");
}

// Runs the shared basic checks against the first element of an op that
// returns a list of tensors.
TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
  REGISTER_TEST_OP(
      "ret_tensor_vector",
      "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]",
      ret_tensor_vector);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tensor_vector", "");
  // Invoke the op through the dispatcher and keep only the first element.
  auto first_output = [&handle](
                          const torch::Tensor& lhs, const torch::Tensor& rhs) {
    auto outputs = callOpUnboxed<
        std::vector<at::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(handle, lhs, rhs);
    return outputs[0];
  };
  assertBasicChecks(first_output);
}

// An op taking a tensor list: the requires-grad status of individual list
// elements must be honoured, and the result must match the raw kernel.
TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
  REGISTER_TEST_OP(
      "tensorlist_op",
      "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor",
      tensorlist_op);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::tensorlist_op", "");
  // Invoke the op through its dispatcher handle instead of calling the kernel.
  auto dispatched = [&handle](torch::Tensor self, at::TensorList others) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(
        handle, self, others);
  };

  auto base = torch::tensor({1.}, {torch::kFloat32});
  auto plain_item = torch::tensor({1.}, {torch::kFloat32});
  auto grad_item =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  std::vector<torch::Tensor> items{plain_item, grad_item};
  auto result = dispatched(base, items);

  // The first list element does not require grad.
  ASSERT_THROWS_WITH(
      torch::autograd::grad({result}, {items[0]}),
      "element 0 of the input tensors does not require grad");
  // The second one does, so differentiation reaches the fallback node.
  ASSERT_THROWS_WITH(
      torch::autograd::grad({result}, {items[1]}), "is not implemented");

  ASSERT_TRUE(
      at::allclose(dispatched(base, items), tensorlist_op(base, items)));
}

// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward
// test_save_output_nr
// test_free_deep_graph_pyfunction
// test_naughty_anomaly_access
// test_naughty_autograd_function_stashing_ctx
// test_custom_autograd_repeated_grad_grad
// test_return_leaf
// test_anomaly_detect_nan
// test_no_grad_copy
