#include <gtest/gtest.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <algorithm>
#include <string>

using namespace torch::nn;

// Test fixture for the AnyModule tests; all setup comes from SeedingFixture.
struct AnyModuleTest : public torch::test::SeedingFixture {};

TEST_F(AnyModuleTest, SimpleReturnType) {
  // Nullary forward() returning a plain int.
  struct M : torch::nn::Module {
    int forward() {
      return 123;
    }
  };
  AnyModule wrapped(M{});
  const int result = wrapped.forward<int>();
  ASSERT_EQ(result, 123);
}

TEST_F(AnyModuleTest, SimpleReturnTypeAndSingleArgument) {
  // Identity forward() over a single int argument.
  struct M : torch::nn::Module {
    int forward(int input) {
      return input;
    }
  };
  AnyModule wrapped(M{});
  const int echoed = wrapped.forward<int>(5);
  ASSERT_EQ(echoed, 5);
}

TEST_F(AnyModuleTest, StringLiteralReturnTypeAndArgument) {
  // forward() hands the C string straight back, so the returned pointer must
  // still spell the literal that was passed in.
  struct M : torch::nn::Module {
    const char* forward(const char* text) {
      return text;
    }
  };
  AnyModule wrapped(M{});
  const char* echoed = wrapped.forward<const char*>("hello");
  ASSERT_EQ(std::string("hello"), echoed);
}

TEST_F(AnyModuleTest, StringReturnTypeWithConstArgument) {
  // Mixed int / const double parameters. The sum is truncated toward zero
  // before being turned into a string.
  struct M : torch::nn::Module {
    std::string forward(int whole, const double offset) {
      const int truncated = static_cast<int>(whole + offset);
      return std::to_string(truncated);
    }
  };
  AnyModule wrapped(M{});
  int x = 4;
  const std::string rendered = wrapped.forward<std::string>(x, 3.14);
  ASSERT_EQ(rendered, "7");
}

TEST_F(
    AnyModuleTest,
    TensorReturnTypeAndStringArgumentsWithFunkyQualifications) {
  // forward() takes std::string by value, by const lvalue reference and by
  // rvalue reference. AnyModule must route each argument correctly. The result
  // has one element per character of the concatenated input.
  struct M : torch::nn::Module {
    torch::Tensor forward(
        std::string a,
        const std::string& b,
        std::string&& c) {
      const auto s = a + b + c;
      return torch::ones({static_cast<int64_t>(s.size())});
    }
  };
  AnyModule any(M{});
  // "a" + "ab" + "abc" has 6 characters, so the tensor of ones sums to 6.
  // ASSERT_EQ (rather than ASSERT_TRUE(x == 6)) reports the actual sum on
  // failure.
  ASSERT_EQ(
      any.forward(std::string("a"), std::string("ab"), std::string("abc"))
          .sum()
          .item<int32_t>(),
      6);
}

TEST_F(AnyModuleTest, WrongArgumentType) {
  struct M : torch::nn::Module {
    int forward(float x) {
      // Explicit float -> int conversion instead of an implicit,
      // lint-suppressed narrowing.
      return static_cast<int>(x);
    }
  };
  AnyModule any(M{});
  // 5.0 is a double literal. AnyModule must reject it for a float parameter
  // instead of converting it, and it must name both types in the error.
  ASSERT_THROWS_WITH(
      any.forward(5.0),
      "Expected argument #0 to be of type float, "
      "but received value of type double");
}

// Module whose forward() requires exactly two ints and declares no default
// arguments. It is used to exercise AnyModule's argument-count checks.
struct M_test_wrong_number_of_arguments : torch::nn::Module {
  int forward(int a, int b) {
    const int sum = a + b;
    return sum;
  }
};

TEST_F(AnyModuleTest, WrongNumberOfArguments) {
  AnyModule any(M_test_wrong_number_of_arguments{});
  // The module's type name appears in the error text, and MSVC spells it with
  // a leading "struct ".
#if defined(_MSC_VER)
  const std::string module_name = "struct M_test_wrong_number_of_arguments";
#else
  const std::string module_name = "M_test_wrong_number_of_arguments";
#endif
  // Builds "<name>'s forward() method expects 2 argument(s), but received <n>."
  const auto arity_error = [&module_name](int received) {
    return module_name +
        "'s forward() method expects 2 argument(s), but received " +
        std::to_string(received) + ".";
  };
  // Hint expected after the count mismatch when too few arguments are passed.
  const std::string default_args_hint = " If " + module_name +
      "'s forward() method has default arguments, "
      "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";

  ASSERT_THROWS_WITH(any.forward(), arity_error(0) + default_args_hint);
  ASSERT_THROWS_WITH(any.forward(5), arity_error(1) + default_args_hint);
  // Too many arguments: only the count mismatch itself is checked.
  ASSERT_THROWS_WITH(any.forward(1, 2, 3), arity_error(3));
}

// Module whose forward() has two defaulted trailing parameters (b = 2,
// c = 3.0) and declares them to AnyModule via FORWARD_HAS_DEFAULT_ARGS.
// C++ default arguments are not part of a member function's type, so the
// macro entries restate them so that AnyModule can fill in omitted trailing
// arguments.
struct M_default_arg_with_macro : torch::nn::Module {
  double forward(int a, int b = 2, double c = 3.0) {
    return a + b + c;
  }

 protected:
  // Each entry pairs a parameter index with its default: index 1 -> b = 2,
  // index 2 -> c = 3.0. Keep these in sync with the forward() signature above.
  FORWARD_HAS_DEFAULT_ARGS(
      {1, torch::nn::AnyValue(2)},
      {2, torch::nn::AnyValue(3.0)})
};

// Same signature as M_default_arg_with_macro but without
// FORWARD_HAS_DEFAULT_ARGS, so AnyModule requires all three arguments.
struct M_default_arg_without_macro : torch::nn::Module {
  double forward(int a, int b = 2, double c = 3.0) {
    const double total = a + b + c;
    return total;
  }
};

TEST_F(
    AnyModuleTest,
    PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) {
  // With FORWARD_HAS_DEFAULT_ARGS, omitted trailing arguments are filled in
  // and the accepted argument count becomes a range.
  {
    AnyModule with_macro(M_default_arg_with_macro{});

    ASSERT_EQ(with_macro.forward<double>(1), 6.0);
    ASSERT_EQ(with_macro.forward<double>(1, 3), 7.0);
    ASSERT_EQ(with_macro.forward<double>(1, 3, 5.0), 9.0);

    const std::string range_error_prefix =
        "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received ";
    ASSERT_THROWS_WITH(with_macro.forward(), range_error_prefix + "0.");
    ASSERT_THROWS_WITH(
        with_macro.forward(1, 2, 3.0, 4), range_error_prefix + "4.");
  }
  // Without the macro, every declared parameter must be passed explicitly.
  {
    AnyModule without_macro(M_default_arg_without_macro{});

    ASSERT_EQ(without_macro.forward<double>(1, 3, 5.0), 9.0);

#if defined(_MSC_VER)
    const std::string module_name = "struct M_default_arg_without_macro";
#else
    const std::string module_name = "M_default_arg_without_macro";
#endif
    // Builds "<name>'s forward() method expects 3 argument(s), but received <n>."
    const auto arity_error = [&module_name](int received) {
      return module_name +
          "'s forward() method expects 3 argument(s), but received " +
          std::to_string(received) + ".";
    };
    // Hint expected after the count mismatch when too few arguments are passed.
    const std::string default_args_hint = " If " + module_name +
        "'s forward() method has default arguments, "
        "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";

    ASSERT_THROWS_WITH(
        without_macro.forward(), arity_error(0) + default_args_hint);
    ASSERT_THROWS_WITH(
        without_macro.forward<double>(1), arity_error(1) + default_args_hint);
    ASSERT_THROWS_WITH(
        without_macro.forward<double>(1, 3),
        arity_error(2) + default_args_hint);
    ASSERT_THROWS_WITH(without_macro.forward(1, 2, 3.0, 4), arity_error(4));
  }
}

// Module registered under the name "M" and carrying an int payload. It is
// shared by the get()/ptr() tests that follow.
struct M : torch::nn::Module {
  explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
  int value;
  int forward(float x) {
    // Explicit float -> int conversion (truncates toward zero) instead of an
    // implicit, lint-suppressed narrowing.
    return static_cast<int>(x);
  }
};

TEST_F(AnyModuleTest, GetWithCorrectTypeSucceeds) {
  AnyModule wrapped(M{5});
  // get<M>() downcasts to the concrete module and exposes its members.
  const M& concrete = wrapped.get<M>();
  ASSERT_EQ(concrete.value, 5);
}

TEST_F(AnyModuleTest, GetWithIncorrectTypeThrows) {
  AnyModule wrapped(M{5});
  // N is an unrelated Module subclass, so downcasting an M to it must fail.
  struct N : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };
  ASSERT_THROWS_WITH(wrapped.get<N>(), "Attempted to cast module");
}

TEST_F(AnyModuleTest, PtrWithBaseClassSucceeds) {
  AnyModule wrapped(M{5});
  // The untyped ptr() yields the module through its Module base.
  const std::shared_ptr<Module> base = wrapped.ptr();
  ASSERT_NE(base, nullptr);
  ASSERT_EQ(base->name(), "M");
}

TEST_F(AnyModuleTest, PtrWithGoodDowncastSuccceeds) {
  AnyModule wrapped(M{5});
  // ptr<M>() yields a pointer to the concrete type with its members visible.
  const auto concrete = wrapped.ptr<M>();
  ASSERT_NE(concrete, nullptr);
  ASSERT_EQ(concrete->value, 5);
}

TEST_F(AnyModuleTest, PtrWithBadDowncastThrows) {
  AnyModule wrapped(M{5});
  // Requesting a pointer to an unrelated module type must throw.
  struct N : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };
  ASSERT_THROWS_WITH(wrapped.ptr<N>(), "Attempted to cast module");
}

TEST_F(AnyModuleTest, DefaultStateIsEmpty) {
  // Local M (shadows the file-scope M) with an int payload.
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
    int forward(float x) {
      // Explicit float -> int conversion instead of an implicit,
      // lint-suppressed narrowing.
      return static_cast<int>(x);
    }
  };
  AnyModule any;
  ASSERT_TRUE(any.is_empty());
  // Assigning a module shared_ptr makes the AnyModule non-empty.
  any = std::make_shared<M>(5);
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.get<M>().value, 5);
}

TEST_F(AnyModuleTest, AllMethodsThrowForEmptyAnyModule) {
  struct M : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  // A default-constructed AnyModule holds nothing. Every accessor must throw
  // an error that names the method that was called.
  AnyModule empty;
  ASSERT_TRUE(empty.is_empty());
  ASSERT_THROWS_WITH(
      empty.get<M>(), "Cannot call get() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      empty.ptr<M>(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(empty.ptr(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      empty.type_info(), "Cannot call type_info() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      empty.forward<int>(5), "Cannot call forward() on an empty AnyModule");
}

TEST_F(AnyModuleTest, CanMoveAssignDifferentModules) {
  struct M : torch::nn::Module {
    std::string forward(int x) {
      return std::to_string(x);
    }
  };
  struct N : torch::nn::Module {
    int forward(float x) {
      // Explicit conversion of the float sum instead of an implicit,
      // lint-suppressed narrowing.
      return static_cast<int>(3 + x);
    }
  };
  // One AnyModule can be re-pointed at modules whose forward() signatures
  // differ entirely (int -> std::string, then float -> int).
  AnyModule any;
  ASSERT_TRUE(any.is_empty());
  any = std::make_shared<M>();
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.forward<std::string>(5), "5");
  any = std::make_shared<N>();
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.forward<int>(5.0f), 8);
}

TEST_F(AnyModuleTest, ConstructsFromModuleHolder) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}
    int value;
    int forward(float x) {
      // Explicit float -> int conversion instead of an implicit,
      // lint-suppressed narrowing.
      return static_cast<int>(x);
    }
  };

  // Hand-written ModuleHolder wrapper around MImpl.
  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  // Constructing from a holder stores the held MImpl. It can be retrieved
  // either as the implementation type or re-wrapped in the holder type.
  AnyModule any(M{5});
  ASSERT_EQ(any.get<MImpl>().value, 5);
  ASSERT_EQ(any.get<M>()->value, 5);

  // Same round trip with a built-in holder. Check that the base pointer and
  // the re-wrapped Linear refer to the same Linear(3, 4) module, whose weight
  // is laid out as (out_features, in_features).
  AnyModule module(Linear(3, 4));
  std::shared_ptr<Module> ptr = module.ptr();
  ASSERT_NE(ptr, nullptr);
  Linear linear(module.get<Linear>());
  ASSERT_EQ(linear.ptr().get(), ptr.get());
  ASSERT_EQ(linear->weight.size(0), 4);
  ASSERT_EQ(linear->weight.size(1), 3);
}

TEST_F(AnyModuleTest, ConvertsVariableToTensorCorrectly) {
  struct M : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };

  // When you have an autograd::Variable, it should be converted to a
  // torch::Tensor before being passed to the function (to avoid a type
  // mismatch).
  AnyModule any(M{});
  // ASSERT_EQ (instead of ASSERT_TRUE(x == 5)) reports the actual sum on
  // failure and matches the assertion below.
  ASSERT_EQ(
      any.forward(torch::autograd::Variable(torch::ones(5)))
          .sum()
          .item<float>(),
      5);
  // at::Tensors that are not variables work too.
  ASSERT_EQ(any.forward(at::ones(5)).sum().item<float>(), 5);
}

// Test-only helpers for building AnyValue instances. They live in torch::nn
// under this exact name, presumably because AnyValue grants friend access to
// `TestAnyValue` for its value-taking constructor.
// NOTE(review): confirm against the AnyValue header before renaming or moving.
namespace torch {
namespace nn {
// Wraps an arbitrary value (perfect-forwarded) in an AnyValue.
struct TestAnyValue {
  template <typename T>
  // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
  explicit TestAnyValue(T&& value) : value_(std::forward<T>(value)) {}
  // Moves the stored AnyValue out, so a second call would yield a moved-from
  // value.
  AnyValue operator()() {
    return std::move(value_);
  }
  AnyValue value_;
};
// Builds an AnyValue holding `value`. This is the entry point used by the
// AnyValueTest cases.
template <typename T>
AnyValue make_value(T&& value) {
  return TestAnyValue(std::forward<T>(value))();
}
} // namespace nn
} // namespace torch

// Test fixture for the AnyValue tests; all setup comes from SeedingFixture.
struct AnyValueTest : public torch::test::SeedingFixture {};

TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) {
  auto boxed = make_value<int>(5);
  const int* stored = boxed.try_get<int>();
  ASSERT_NE(stored, nullptr);
  // int and const int share a typeid(), but casting Holder<int> to
  // Holder<const int> is undefined behavior according to UBSAN
  // (https://github.com/pytorch/pytorch/issues/26964), so the const-qualified
  // lookup is intentionally not exercised:
  // ASSERT_NE(boxed.try_get<const int>(), nullptr);
  ASSERT_EQ(boxed.get<int>(), 5);
}
// This test is disabled: make_value appears to decay `const int` to `int`,
// so the stored value is never held as a `const int`.
// TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
//  auto value = make_value<const int>(5);
//  ASSERT_NE(value.try_get<const int>(), nullptr);
//  // ASSERT_NE(value.try_get<int>(), nullptr);
//  ASSERT_EQ(value.get<const int>(), 5);
//}
TEST_F(AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) {
  // A string literal is stored as const char*.
  auto boxed = make_value("hello");
  ASSERT_NE(boxed.try_get<const char*>(), nullptr);
  const std::string text = boxed.get<const char*>();
  ASSERT_EQ(text, "hello");
}
TEST_F(AnyValueTest, CorrectlyAccessesStringWhenCorrectType) {
  auto boxed = make_value(std::string("hello"));
  const std::string* stored = boxed.try_get<std::string>();
  ASSERT_NE(stored, nullptr);
  ASSERT_EQ(boxed.get<std::string>(), "hello");
}
TEST_F(AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) {
  // The pointer itself is stored, so it is retrieved as std::string*.
  std::string s("hello");
  std::string* address = &s;
  auto boxed = make_value(address);
  ASSERT_NE(boxed.try_get<std::string*>(), nullptr);
  std::string* recovered = boxed.get<std::string*>();
  ASSERT_EQ(*recovered, "hello");
}
TEST_F(AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) {
  // Passing a const reference stores the string by value, so it is retrieved
  // as a plain std::string.
  std::string s("hello");
  const std::string& alias = s;
  auto boxed = make_value(alias);
  ASSERT_NE(boxed.try_get<std::string>(), nullptr);
  ASSERT_EQ(boxed.get<std::string>(), "hello");
}

TEST_F(AnyValueTest, TryGetReturnsNullptrForTheWrongType) {
  auto boxed = make_value(5);
  ASSERT_NE(boxed.try_get<int>(), nullptr);
  // Only the exact stored type matches; no numeric conversions are applied.
  ASSERT_EQ(boxed.try_get<float>(), nullptr);
  ASSERT_EQ(boxed.try_get<long>(), nullptr);
  ASSERT_EQ(boxed.try_get<std::string>(), nullptr);
}

TEST_F(AnyValueTest, GetThrowsForTheWrongType) {
  auto boxed = make_value(5);
  ASSERT_NE(boxed.try_get<int>(), nullptr);
  // The error names both the requested type and the actual stored type.
  ASSERT_THROWS_WITH(
      boxed.get<float>(),
      "Attempted to cast AnyValue to float, but its actual type is int");
  ASSERT_THROWS_WITH(
      boxed.get<long>(),
      "Attempted to cast AnyValue to long, but its actual type is int");
}

TEST_F(AnyValueTest, MoveConstructionIsAllowed) {
  auto source = make_value(5);
  // Wrapping an rvalue AnyValue is expected to move-construct from it.
  auto moved = make_value(std::move(source));
  ASSERT_NE(moved.try_get<int>(), nullptr);
  ASSERT_EQ(moved.get<int>(), 5);
}

TEST_F(AnyValueTest, MoveAssignmentIsAllowed) {
  auto source = make_value(5);
  auto target = make_value(10);
  // Move-assignment replaces target's stored 10 with source's 5.
  target = std::move(source);
  ASSERT_NE(target.try_get<int>(), nullptr);
  ASSERT_EQ(target.get<int>(), 5);
}

TEST_F(AnyValueTest, TypeInfoIsCorrectForInt) {
  const auto expected_hash = typeid(int).hash_code();
  auto boxed = make_value(5);
  ASSERT_EQ(boxed.type_info().hash_code(), expected_hash);
}

TEST_F(AnyValueTest, TypeInfoIsCorrectForStringLiteral) {
  // The literal decays, so the recorded type is const char*, not char[6].
  const auto expected_hash = typeid(const char*).hash_code();
  auto boxed = make_value("hello");
  ASSERT_EQ(boxed.type_info().hash_code(), expected_hash);
}

TEST_F(AnyValueTest, TypeInfoIsCorrectForString) {
  const auto expected_hash = typeid(std::string).hash_code();
  auto boxed = make_value(std::string("hello"));
  ASSERT_EQ(boxed.type_info().hash_code(), expected_hash);
}
