#include <gtest/gtest.h>

#include <test/cpp/jit/test_custom_class_registrations.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/custom_class.h>
#include <torch/script.h>

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

namespace torch {
namespace jit {

// Exercises two ways of producing a custom-class IValue:
//   * the make_custom_class factory, and
//   * wrapping a c10::intrusive_ptr directly in a c10::IValue.
// For each, it passes the object into a TorchScript method and checks that
// the object handed back is the very same instance.
TEST(CustomClassTest, TorchbindIValueAPI) {
  script::Module m("m");

  // test make_custom_class API
  auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
      std::vector<std::string>{"foo", "bar"});
  m.define(R"(
    def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString):
      return s.pop(), s
  )");

  // Runs `forward` on `obj` and checks two things:
  //   * the popped string equals `expected`;
  //   * the object returned in slot 1 is the same instance as `obj`, not a copy.
  // Arguments are taken by const reference; run_method copies into its own
  // argument list anyway.
  auto test_with_obj = [&m](const IValue& obj, const std::string& expected) {
    auto res = m.run_method("forward", obj);
    auto tup = res.toTuple();
    AT_ASSERT(tup->elements().size() == 2);
    // `tup` owns the tuple for the rest of this scope, so referencing the
    // element's string (instead of copying it) is safe.
    const auto& str = tup->elements()[0].toStringRef();
    auto other_obj =
        tup->elements()[1].toCustomClass<MyStackClass<std::string>>();
    AT_ASSERT(str == expected);
    auto ref_obj = obj.toCustomClass<MyStackClass<std::string>>();
    AT_ASSERT(other_obj.get() == ref_obj.get());
  };

  test_with_obj(custom_class_obj, "bar");

  // test IValue() API
  auto my_new_stack = c10::make_intrusive<MyStackClass<std::string>>(
      std::vector<std::string>{"baz", "boo"});
  auto new_stack_ivalue = c10::IValue(my_new_stack);

  test_with_obj(new_stack_ivalue, "boo");
}

// Saves and reloads a module whose attribute is a ScalarTypeClass object
// constructed from at::kFloat, and checks that the attribute is present on the
// loaded module.
TEST(CustomClassTest, ScalarTypeClass) {
  script::Module m("m");

  // Attach the custom-class object as a plain (non-parameter) attribute.
  auto cc = make_custom_class<ScalarTypeClass>(at::kFloat);
  m.register_attribute(
      "s",
      cc.type(),
      cc,
      // NOLINTNEXTLINE(bugprone-argument-comment)
      /*is_parameter=*/false);

  std::ostringstream oss;
  m.save(oss);
  std::istringstream iss(oss.str());
  auto loaded_module = torch::jit::load(iss, torch::kCPU);

  // The custom-class attribute must survive the save/load round trip.
  AT_ASSERT(loaded_module.hasattr("s"));
}

// Minimal custom class used to test doc-string propagation. It is registered
// further down as _TorchBindTest._TorchBindTestClass; its single method is
// exposed both with and without a doc string.
class TorchBindTestClass : public torch::jit::CustomClassHolder {
 public:
  // Returns a fixed greeting; the content itself is irrelevant to the tests.
  std::string get() {
    return std::string{"Hello, I am your test custom class"};
  }
};

// Doc strings passed to the TorchBindTestClass registration below. The
// TestDocString test compares the registered doc strings against these same
// constants, so the expected values can never drift from what is registered.
// The exact bytes (including the leading newline and indentation of the raw
// string) are what gets compared.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char class_doc_string[] = R"(
  I am docstring for TorchBindTestClass
  Args:
      What is an argument? Oh never mind, I don't take any.

  Return:
      How would I know? I am just a holder of some meaningless test methods.
  )";
// Doc string attached only to the "get_with_docstring" method; the plain
// "get" method is registered without one.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char method_doc_string[] =
    "I am docstring for TorchBindTestClass get_with_docstring method";

namespace {
// Registers TorchBindTestClass, which TestDocString later looks up as
// "__torch__.torch.classes._TorchBindTest._TorchBindTestClass".
// TorchBindTestClass::get is exposed twice:
//   * as "get", with no doc string;
//   * as "get_with_docstring", with method_doc_string.
// This lets the test cover both cases. The registration is performed by this
// namespace-scope initializer when the translation unit is initialized.
// No `static` is needed: the anonymous namespace already gives internal
// linkage.
auto reg =
    torch::class_<TorchBindTestClass>(
        "_TorchBindTest",
        "_TorchBindTestClass",
        class_doc_string)
        .def("get", &TorchBindTestClass::get)
        .def("get_with_docstring", &TorchBindTestClass::get, method_doc_string);

} // namespace

// Verifies that doc strings supplied to torch::class_ (for the class and for
// individual methods) are visible on the resulting ClassType. A method
// registered without a doc string must report an empty one.
TEST(CustomClassTest, TestDocString) {
  const auto class_type = getCustomClass(
      "__torch__.torch.classes._TorchBindTest._TorchBindTestClass");
  AT_ASSERT(class_type);
  AT_ASSERT(class_type->doc_string() == class_doc_string);

  auto& undocumented_method = class_type->getMethod("get");
  AT_ASSERT(undocumented_method.doc_string().empty());

  auto& documented_method = class_type->getMethod("get_with_docstring");
  AT_ASSERT(documented_method.doc_string() == method_doc_string);
}

// Checks that a module holding a custom-class attribute behaves correctly
// across all of these states:
//   * eager (unfrozen);
//   * frozen;
//   * saved and reloaded, both unfrozen and frozen.
// In every case, `forward` must still return the expected value.
TEST(CustomClassTest, Serialization) {
  script::Module m("m");

  // test make_custom_class API
  auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
      std::vector<std::string>{"foo", "bar"});
  m.register_attribute(
      "s",
      custom_class_obj.type(),
      custom_class_obj,
      // NOLINTNEXTLINE(bugprone-argument-comment)
      /*is_parameter=*/false);
  m.define(R"(
    def forward(self):
      return self.s.return_a_tuple()
  )");

  // Runs `forward` and checks it yields a 2-tuple whose second element is 123.
  auto test_with_obj = [](script::Module& mod) {
    auto res = mod.run_method("forward");
    auto tup = res.toTuple();
    AT_ASSERT(tup->elements().size() == 2);
    auto i = tup->elements()[1].toInt();
    AT_ASSERT(i == 123);
  };

  // Serializes `mod` into an in-memory buffer and loads it back on CPU.
  auto save_and_load = [](script::Module& mod) {
    std::ostringstream oss;
    mod.save(oss);
    std::istringstream iss(oss.str());
    return torch::jit::load(iss, torch::kCPU);
  };

  auto frozen_m = torch::jit::freeze_module(m.clone());

  test_with_obj(m);
  test_with_obj(frozen_m);

  // Run the reloaded modules too. Otherwise a round trip that loads but
  // yields an unusable custom-class object would go unnoticed.
  auto loaded_module = save_and_load(m);
  test_with_obj(loaded_module);

  auto loaded_frozen_module = save_and_load(frozen_m);
  test_with_obj(loaded_frozen_module);
}

} // namespace jit
} // namespace torch
