#include <gtest/gtest.h>

#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/util/tempfile.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <cstdio>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

using namespace torch::test;
using namespace torch::nn;
using namespace torch::optim;

namespace {
Sequential xor_model() {
  return Sequential(
      Linear(2, 8),
      Functional(at::sigmoid),
      Linear(8, 1),
      Functional(at::sigmoid));
}

torch::Tensor save_and_load(torch::Tensor input) {
  std::stringstream stream;
  torch::save(input, stream);
  torch::Tensor tensor;
  torch::load(tensor, stream);
  return tensor;
}
} // namespace

template <typename DerivedOptions>
void is_optimizer_param_group_equal(
    const OptimizerParamGroup& lhs,
    const OptimizerParamGroup& rhs) {
  const auto& lhs_params = lhs.params();
  const auto& rhs_params = rhs.params();

  ASSERT_TRUE(lhs_params.size() == rhs_params.size());
  for (const auto j : c10::irange(lhs_params.size())) {
    ASSERT_TRUE(torch::equal(lhs_params[j], rhs_params[j]));
  }
  ASSERT_TRUE(
      static_cast<const DerivedOptions&>(lhs.options()) ==
      static_cast<const DerivedOptions&>(rhs.options()));
}

template <typename DerivedOptimizerParamState>
void is_optimizer_state_equal(
    const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
        lhs_state,
    const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
        rhs_state) {
  ASSERT_TRUE(lhs_state.size() == rhs_state.size());
  for (const auto& value : lhs_state) {
    auto found = rhs_state.find(value.first);
    ASSERT_TRUE(found != rhs_state.end());
    const DerivedOptimizerParamState& lhs_curr_state =
        static_cast<const DerivedOptimizerParamState&>(*(value.second.get()));
    const DerivedOptimizerParamState& rhs_curr_state =
        static_cast<const DerivedOptimizerParamState&>(*(found->second.get()));
    ASSERT_TRUE(lhs_curr_state == rhs_curr_state);
  }
}

template <
    typename OptimizerClass,
    typename DerivedOptimizerOptions,
    typename DerivedOptimizerParamState>
void test_serialize_optimizer(
    DerivedOptimizerOptions options,
    bool only_has_global_state = false) {
  torch::manual_seed(0);
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Models 1, 2, 3 will have the same parameters.
  auto model_tempfile = c10::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);

  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    ASSERT_TRUE(p->allclose(param2[p.key()]));
    ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
  }
  // Make some optimizers
  auto optim1 = OptimizerClass(
      {torch::optim::OptimizerParamGroup(model1->parameters())}, options);
  auto optim2 = OptimizerClass(model2->parameters(), options);
  auto optim2_2 = OptimizerClass(model2->parameters(), options);
  auto optim3 = OptimizerClass(model3->parameters(), options);
  auto optim3_2 = OptimizerClass(model3->parameters(), options);
  for (auto& param_group : optim3_2.param_groups()) {
    const double lr = param_group.options().get_lr();
    // change the learning rate, which will be overwritten by the loading
    // otherwise, test cannot check if options are saved and loaded correctly
    param_group.options().set_lr(lr + 0.01);
  }

  auto x = torch::ones({10, 5});

  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);
  };

  // Do 2 steps of model1
  step(optim1, model1);
  step(optim1, model1);

  // Do 2 steps of model 2 without saving the optimizer
  step(optim2, model2);
  step(optim2_2, model2);

  // Do 1 step of model 3
  step(optim3, model3);

  // save the optimizer
  auto optim_tempfile = c10::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);

  auto& optim3_2_param_groups = optim3_2.param_groups();
  auto& optim3_param_groups = optim3.param_groups();
  auto& optim3_2_state = optim3_2.state();
  auto& optim3_state = optim3.state();

  // optim3_2 and optim1 should have param_groups and state of size 1 and
  // state_size respectively
  ASSERT_TRUE(optim3_2_param_groups.size() == 1);
  // state_size = 2 for all optimizers except LBFGS as LBFGS only maintains one
  // global state
  unsigned state_size = only_has_global_state ? 1 : 2;
  ASSERT_TRUE(optim3_2_state.size() == state_size);

  // optim3_2 and optim1 should have param_groups and state of same size
  ASSERT_TRUE(optim3_2_param_groups.size() == optim3_param_groups.size());
  ASSERT_TRUE(optim3_2_state.size() == optim3_state.size());

  // checking correctness of serialization logic for optimizer.param_groups_ and
  // optimizer.state_
  for (const auto i : c10::irange(optim3_2_param_groups.size())) {
    is_optimizer_param_group_equal<DerivedOptimizerOptions>(
        optim3_2_param_groups[i], optim3_param_groups[i]);
    is_optimizer_state_equal<DerivedOptimizerParamState>(
        optim3_2_state, optim3_state);
  }

  // Do step2 for model 3
  step(optim3_2, model3);

  param1 = model1->named_parameters();
  param2 = model2->named_parameters();
  param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& name = p.key();
    // Model 1 and 3 should be the same
    ASSERT_TRUE(
        param1[name].norm().item<float>() == param3[name].norm().item<float>());
    ASSERT_TRUE(
        param1[name].norm().item<float>() != param2[name].norm().item<float>());
  }
}

/// Utility function to save a value of `int64_t` type.
void write_int_value(
    torch::serialize::OutputArchive& archive,
    const std::string& key,
    const int64_t& value) {
  archive.write(key, c10::IValue(value));
}
// Utility function to save a vector of buffers.
template <typename BufferContainer>
void write_tensors_to_archive(
    torch::serialize::OutputArchive& archive,
    const std::string& key,
    const BufferContainer& buffers) {
  archive.write(
      key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
  for (const auto index : c10::irange(buffers.size())) {
    archive.write(
        key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
  }
}

// Utility function to save a vector of step buffers.
void write_step_buffers(
    torch::serialize::OutputArchive& archive,
    const std::string& key,
    const std::vector<int64_t>& steps) {
  std::vector<torch::Tensor> tensors;
  tensors.reserve(steps.size());
  for (const auto& step : steps) {
    tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
  }
  write_tensors_to_archive(archive, key, tensors);
}

#define OLD_SERIALIZATION_LOGIC_WARNING_CHECK(funcname, optimizer, filename) \
  {                                                                          \
    WarningCapture warnings;                                                 \
    funcname(optimizer, filename);                                           \
    ASSERT_EQ(                                                               \
        count_substr_occurrences(warnings.str(), "old serialization"), 1);   \
  }

TEST(SerializeTest, KeysFunc) {
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  for (const auto i : c10::irange(3)) {
    output_archive.write(
        "element/" + std::to_string(i), c10::IValue(static_cast<int64_t>(i)));
  }
  output_archive.save_to(tempfile.name);
  torch::serialize::InputArchive input_archive;
  input_archive.load_from(tempfile.name);
  std::vector<std::string> keys = input_archive.keys();
  ASSERT_EQ(keys.size(), 3);
  for (const auto i : c10::irange(keys.size())) {
    ASSERT_EQ(keys[i], "element/" + std::to_string(i));
  }
}

TEST(SerializeTest, TryReadFunc) {
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  for (const auto i : c10::irange(3)) {
    output_archive.write(
        "element/" + std::to_string(i), c10::IValue(static_cast<int64_t>(i)));
  }
  output_archive.save_to(tempfile.name);
  torch::serialize::InputArchive input_archive;
  input_archive.load_from(tempfile.name);
  c10::IValue ivalue;
  ASSERT_FALSE(input_archive.try_read("1", ivalue));
  ASSERT_TRUE(input_archive.try_read("element/1", ivalue));
  ASSERT_EQ(ivalue.toInt(), 1);
}

TEST(SerializeTest, Basic) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, MathBits) {
  torch::manual_seed(0);

  auto options = torch::TensorOptions{}.dtype(torch::kComplexFloat);
  auto x = torch::randn({5, 5}, options);
  {
    auto expected = torch::conj(x);
    auto actual = save_and_load(expected);

    ASSERT_TRUE(actual.defined());
    ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
    ASSERT_TRUE(actual.allclose(expected));
  }

  {
    auto expected = torch::_neg_view(x);
    auto actual = save_and_load(expected);

    ASSERT_TRUE(actual.defined());
    ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
    ASSERT_TRUE(actual.allclose(expected));
  }

  {
    auto expected = torch::conj(torch::_neg_view(x));
    auto actual = save_and_load(expected);

    ASSERT_TRUE(actual.defined());
    ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
    ASSERT_TRUE(actual.allclose(expected));
  }

  {
    // We don't support serializing `ZeroTensor` as it is not public facing yet.
    // If in future, `ZeroTensor` serialization is supported, this test should
    // start failing!
    auto t = torch::_efficientzerotensor({5, 5});
    ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,");
  }
}

TEST(SerializeTest, BasicToFile) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  auto tempfile = c10::make_tempfile();
  torch::save(x, tempfile.name);

  torch::Tensor y;
  torch::load(y, tempfile.name);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, BasicViaFunc) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  std::string serialized;
  torch::save(x, [&](const void* buf, size_t n) {
    serialized.append(reinterpret_cast<const char*>(buf), n);
    return n;
  });
  torch::Tensor y;
  torch::load(y, serialized.data(), serialized.size());

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));

  torch::Tensor z;
  torch::load(
      z,
      [&](uint64_t pos, void* buf, size_t n) -> size_t {
        if (pos >= serialized.size())
          return 0;
        size_t nbytes =
            std::min(static_cast<size_t>(pos) + n, serialized.size()) - pos;
        memcpy(buf, serialized.data() + pos, nbytes);
        return nbytes;
      },
      [&]() -> size_t { return serialized.size(); });
  ASSERT_TRUE(z.defined());
  ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
  ASSERT_TRUE(x.allclose(z));
}

TEST(SerializeTest, Resized) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x.resize_({5, 5});
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, Sliced) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x = x.slice(0, 1, 5);
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, NonContiguous) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x = x.slice(1, 1, 4);
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, ErrorOnMissingKey) {
  struct B : torch::nn::Module {
    B(const std::string& name_c) {
      register_buffer(name_c, torch::ones(5, torch::kFloat));
    }
  };
  struct A : torch::nn::Module {
    A(const std::string& name_b, const std::string& name_c) {
      register_module(name_b, std::make_shared<B>(name_c));
    }
  };
  struct M : torch::nn::Module {
    M(const std::string& name_a,
      const std::string& name_b,
      const std::string& name_c) {
      register_module(name_a, std::make_shared<A>(name_b, name_c));
    }
  };

  // create a hierarchy of models with names differing below the top level
  auto model1 = std::make_shared<M>("a", "b", "c");
  auto model2 = std::make_shared<M>("a", "b", "x");
  auto model3 = std::make_shared<M>("a", "x", "c");

  std::stringstream stream;
  torch::save(model1, stream);
  // We want the errors to contain hierarchy information, too.
  ASSERT_THROWS_WITH(
      torch::load(model2, stream), "No such serialized tensor 'a.b.x'");
  stream.seekg(0, stream.beg);
  ASSERT_THROWS_WITH(
      torch::load(model3, stream), "No such serialized submodule: 'a.x'");
}

TEST(SerializeTest, XOR) {
  // We better be able to save and load an XOR model!
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (const auto i : c10::irange(batch_size)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);
}

TEST(SerializeTest, Optim) {
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Models 1, 2, 3 will have the same parameters.
  auto model_tempfile = c10::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);

  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    ASSERT_TRUE(p->allclose(param2[p.key()]));
    ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
  }

  // Make some optimizers with momentum (and thus state)
  auto optim1 = torch::optim::SGD(
      model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2_2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3_2 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));

  auto x = torch::ones({10, 5});

  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };

  // Do 2 steps of model1
  step(optim1, model1);
  step(optim1, model1);

  // Do 2 steps of model 2 without saving the optimizer
  step(optim2, model2);
  step(optim2_2, model2);

  // Do 2 steps of model 3 while saving the optimizer
  step(optim3, model3);

  auto optim_tempfile = c10::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);
  step(optim3_2, model3);

  param1 = model1->named_parameters();
  param2 = model2->named_parameters();
  param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& name = p.key();
    // Model 1 and 3 should be the same
    ASSERT_TRUE(
        param1[name].norm().item<float>() == param3[name].norm().item<float>());
    ASSERT_TRUE(
        param1[name].norm().item<float>() != param2[name].norm().item<float>());
  }
}

TEST(SerializeTest, Optim_Adagrad) {
  test_serialize_optimizer<Adagrad, AdagradOptions, AdagradParamState>(
      AdagradOptions(1e-1));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto optim1 = torch::optim::Adagrad(
      model1->parameters(), torch::optim::AdagradOptions(1e-1));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);
  auto optim1_2 =
      Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1));

  // fill up with optim1 sum_buffers
  std::vector<torch::Tensor> sum_buffers;
  // fill up with optim1 state_buffers
  std::vector<int64_t> step_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto& param : params_) {
    auto key_ = param.unsafeGetTensorImpl();
    const AdagradParamState& curr_state_ =
        static_cast<const AdagradParamState&>(*(optim1_state.at(key_).get()));
    sum_buffers.emplace_back(curr_state_.sum());
    step_buffers.emplace_back(curr_state_.step());
  }
  // write sum_buffers and step_buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(output_archive, "sum_buffers", sum_buffers);
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdagradParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, Optim_SGD) {
  test_serialize_optimizer<SGD, SGDOptions, SGDParamState>(
      SGDOptions(1e-1).momentum(0.9));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have a momentum
  // buffer entry
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::SGD(
      model1_params, torch::optim::SGDOptions(0.01).momentum(0.9));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<at::Tensor> momentum_buffers;
  int64_t iteration_{0};
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const SGDParamState& curr_state_ =
          static_cast<const SGDParamState&>(*(optim1_state.at(key_).get()));
      momentum_buffers.emplace_back(curr_state_.momentum_buffer());
    }
  }
  ASSERT_TRUE(momentum_buffers.size() == (params_.size() - 1));
  // write momentum_buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(
      output_archive, "momentum_buffers", momentum_buffers);
  write_int_value(output_archive, "iteration_", iteration_);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 =
      SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9));
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<SGDParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, Optim_Adam) {
  test_serialize_optimizer<Adam, AdamOptions, AdamParamState>(
      AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have entry in
  // buffers
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::Adam(
      model1_params, torch::optim::AdamOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const AdamParamState& curr_state_ =
          static_cast<const AdamParamState&>(*(optim1_state.at(key_).get()));
      step_buffers.emplace_back(curr_state_.step());
      exp_average_buffers.emplace_back(curr_state_.exp_avg());
      exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
      if (curr_state_.max_exp_avg_sq().defined()) {
        max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = Adam(model1_params, torch::optim::AdamOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, Optim_AdamW) {
  test_serialize_optimizer<AdamW, AdamWOptions, AdamWParamState>(
      AdamWOptions().lr(0.99999).amsgrad(true).betas(
          std::make_tuple(0.999, 0.1)));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have entry in
  // buffers
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::AdamW(
      model1_params, torch::optim::AdamWOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const AdamWParamState& curr_state_ =
          static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
      step_buffers.emplace_back(curr_state_.step());
      exp_average_buffers.emplace_back(curr_state_.exp_avg());
      exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
      if (curr_state_.max_exp_avg_sq().defined()) {
        max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamWParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, Optim_RMSprop) {
  auto options = RMSpropOptions(0.1).momentum(0.9).centered(true);
  test_serialize_optimizer<RMSprop, RMSpropOptions, RMSpropParamState>(options);

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();

  // added a tensor for lazy init check - when all params do not have a momentum
  // buffer entry
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::RMSprop(model1_params, options);

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

  std::vector<at::Tensor> square_average_buffers;
  std::vector<at::Tensor> momentum_buffers;
  std::vector<at::Tensor> grad_average_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const RMSpropParamState& curr_state_ =
          static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
      square_average_buffers.emplace_back(curr_state_.square_avg());
      if (curr_state_.momentum_buffer().defined()) {
        momentum_buffers.emplace_back(curr_state_.momentum_buffer());
      }
      if (curr_state_.grad_avg().defined()) {
        grad_average_buffers.emplace_back(curr_state_.grad_avg());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(
      output_archive, "square_average_buffers", square_average_buffers);
  write_tensors_to_archive(
      output_archive, "momentum_buffers", momentum_buffers);
  write_tensors_to_archive(
      output_archive, "grad_average_buffers", grad_average_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = RMSprop(model1_params, options);
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  const auto& params1_2_ = optim1_2.param_groups()[0].params();
  auto& optim1_2_state = optim1_2.state();
  // old RMSprop didn't track step value
  for (const auto i : c10::irange(params1_2_.size())) {
    if (i != (params1_2_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const RMSpropParamState& curr_state_ =
          static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
      RMSpropParamState& curr_state1_2_ =
          static_cast<RMSpropParamState&>(*(optim1_2_state.at(key_).get()));
      curr_state1_2_.step(curr_state_.step());
    }
  }
  is_optimizer_state_equal<RMSpropParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, Optim_LBFGS) {
  test_serialize_optimizer<LBFGS, LBFGSOptions, LBFGSParamState>(
      LBFGSOptions(), true);
  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have entry in
  // buffers
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 =
      torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions());

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);
  };

  step(optim1, model1);

  at::Tensor d, t, H_diag, prev_flat_grad, prev_loss;
  std::deque<at::Tensor> old_dirs, old_stps;

  const auto& params_ = optim1.param_groups()[0].params();
  auto key_ = params_[0].unsafeGetTensorImpl();
  const auto& optim1_state =
      static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get()));
  d = optim1_state.d();
  t = at::tensor(optim1_state.t());
  H_diag = optim1_state.H_diag();
  prev_flat_grad = optim1_state.prev_flat_grad();
  prev_loss = at::tensor(optim1_state.prev_loss());
  old_dirs = optim1_state.old_dirs();

  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  output_archive.write("d", d, /*is_buffer=*/true);
  output_archive.write("t", t, /*is_buffer=*/true);
  output_archive.write("H_diag", H_diag, /*is_buffer=*/true);
  output_archive.write("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true);
  output_archive.write("prev_loss", prev_loss, /*is_buffer=*/true);
  write_tensors_to_archive(output_archive, "old_dirs", old_dirs);
  write_tensors_to_archive(output_archive, "old_stps", old_stps);
  output_archive.save_to(optim_tempfile_old_format.name);

  auto optim1_2 = LBFGS(model1_params, torch::optim::LBFGSOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);

  const auto& params1_2_ = optim1_2.param_groups()[0].params();
  auto param_key = params1_2_[0].unsafeGetTensorImpl();
  auto& optim1_2_state =
      static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get()));

  // old LBFGS didn't track func_evals, n_iter, ro, al values
  optim1_2_state.func_evals(optim1_state.func_evals());
  optim1_2_state.n_iter(optim1_state.n_iter());
  optim1_2_state.ro(optim1_state.ro());
  optim1_2_state.al(optim1_state.al());

  is_optimizer_state_equal<LBFGSParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, XOR_CUDA) {
  torch::manual_seed(0);
  // We better be able to save and load a XOR model!
  auto getLoss = [](Sequential model,
                    uint32_t batch_size,
                    bool is_cuda = false) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    if (is_cuda) {
      inputs = inputs.cuda();
      labels = labels.cuda();
    }
    for (const auto i : c10::irange(batch_size)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);

  model2->to(torch::kCUDA);
  loss = getLoss(model2, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);

  auto tempfile2 = c10::make_tempfile();
  torch::save(model2, tempfile2.name);
  torch::load(model3, tempfile2.name);

  loss = getLoss(model3, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);
}

TEST(
    SerializeTest,
    CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
  struct C : torch::nn::Module {
    C() {
      register_buffer("foo", torch::ones(5, torch::kInt32));
    }
  };
  struct B : torch::nn::Module {};
  struct A : torch::nn::Module {
    A() {
      register_module("b", std::make_shared<B>());
      register_module("c", std::make_shared<C>());
    }
  };
  struct M : torch::nn::Module {
    M() {
      register_module("a", std::make_shared<A>());
    }
  };

  auto out = std::make_shared<M>();
  std::stringstream ss;
  torch::save(out, ss);
  auto in = std::make_shared<M>();
  torch::load(in, ss);

  const int output = in->named_buffers()["a.c.foo"].sum().item<int>();
  ASSERT_EQ(output, 5);
}

TEST(SerializeTest, VectorOfTensors) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> x_vec = {
      torch::randn({1, 2}), torch::randn({3, 4})};

  std::stringstream stream;
  torch::save(x_vec, stream);

  std::vector<torch::Tensor> y_vec;
  torch::load(y_vec, stream);

  for (const auto i : c10::irange(x_vec.size())) {
    auto& x = x_vec[i];
    auto& y = y_vec[i];
    ASSERT_TRUE(y.defined());
    ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
    ASSERT_TRUE(x.allclose(y));
  }
}

TEST(SerializeTest, IValue) {
  c10::IValue ivalue(1);
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  output_archive.write("value", ivalue);
  output_archive.save_to(tempfile.name);

  torch::serialize::InputArchive input_archive;
  input_archive.load_from(tempfile.name);
  c10::IValue ivalue_out;
  input_archive.read("value", ivalue_out);
  ASSERT_EQ(ivalue_out.toInt(), 1);

  ASSERT_THROWS_WITH(
      input_archive.read("bad_key", ivalue_out),
      "does not have a field with name");
}

// NOTE: if a `Module` contains unserializable submodules (e.g.
// `nn::Functional`), we expect those submodules to be skipped when the `Module`
// is being serialized.
TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) {
  struct A : torch::nn::Module {
    A() {
      register_module("relu", torch::nn::Functional(torch::relu));
    }
  };

  auto out = std::make_shared<A>();
  std::stringstream ss;
  torch::save(out, ss);

  torch::serialize::InputArchive archive;
  archive.load_from(ss);
  torch::serialize::InputArchive relu_archive;

  // Submodule with name "relu" should not exist in the `InputArchive`,
  // because the "relu" submodule is an `nn::Functional` and is not
  // serializable.
  ASSERT_FALSE(archive.try_read("relu", relu_archive));
}

// NOTE: If a `Module` contains unserializable submodules (e.g.
// `nn::Functional`), we don't check the existence of those submodules in the
// `InputArchive` when deserializing.
TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) {
  struct B : torch::nn::Module {
    B() {
      register_module("relu1", torch::nn::Functional(torch::relu));
      register_buffer("foo", torch::zeros(5, torch::kInt32));
    }
  };
  struct A : torch::nn::Module {
    A() {
      register_module("b", std::make_shared<B>());
      register_module("relu2", torch::nn::Functional(torch::relu));
    }
  };

  auto out = std::make_shared<A>();
  // Manually change the values of "b.foo", so that we can check whether the
  // buffer contains these values after deserialization.
  out->named_buffers()["b.foo"].fill_(1);
  auto tempfile = c10::make_tempfile();
  torch::save(out, tempfile.name);

  torch::serialize::InputArchive archive;
  archive.load_from(tempfile.name);
  torch::serialize::InputArchive archive_b;
  torch::serialize::InputArchive archive_relu;
  torch::Tensor tensor_foo;

  ASSERT_TRUE(archive.try_read("b", archive_b));
  ASSERT_TRUE(archive_b.try_read("foo", tensor_foo, /*is_buffer=*/true));

  // Submodule with name "relu1" should not exist in `archive_b`, because the
  // "relu1" submodule is an `nn::Functional` and is not serializable.
  ASSERT_FALSE(archive_b.try_read("relu1", archive_relu));

  // Submodule with name "relu2" should not exist in `archive`, because the
  // "relu2" submodule is an `nn::Functional` and is not serializable.
  ASSERT_FALSE(archive.try_read("relu2", archive_relu));

  auto in = std::make_shared<A>();
  // `torch::load(...)` works without error, even though `A` contains the
  // `nn::Functional` submodules while the serialized file doesn't, because the
  // `nn::Functional` submodules are not serializable and thus ignored when
  // deserializing.
  torch::load(in, tempfile.name);

  // Check that the "b.foo" buffer is correctly deserialized from the file.
  const int output = in->named_buffers()["b.foo"].sum().item<int>();
  // `output` should equal to the sum of the values we manually assigned to
  // "b.foo" before serialization.
  ASSERT_EQ(output, 5);
}
