#pragma once

#include <test/cpp/common/support.h>

#include <gtest/gtest.h>

#include <ATen/TensorIndexing.h>
#include <c10/util/Exception.h>
#include <torch/nn/cloneable.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <mutex>
#include <string>
#include <utility>
#include <vector>

namespace torch {
namespace test {

// Lets experimental code use a module container without defining a new
// module class: child modules are registered on the fly via `add`.
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
 public:
  // Nothing to (re)initialize here; reset() is intentionally a no-op.
  void reset() override {}

  // Registers `module_holder` as a child module called `name` (empty unless
  // given) and returns the holder produced by `Module::register_module`.
  template <typename Holder>
  Holder add(Holder module_holder, std::string name = std::string()) {
    return Module::register_module(std::move(name), module_holder);
  }
};

// gtest fixture: every test using it starts by seeding torch's random number
// generation with 0, so randomly initialized tensors are reproducible.
struct SeedingFixture : ::testing::Test {
  SeedingFixture() {
    torch::manual_seed(/*seed=*/0);
  }
};

// Installs itself as the active warning handler for its lifetime and records
// the message of every warning it receives. The handler that was active at
// construction time is reinstalled on destruction, so nested captures must be
// destroyed in reverse order of construction (the natural order for stack
// objects).
struct WarningCapture : public WarningHandler {
  WarningCapture() : prev_(WarningUtils::get_warning_handler()) {
    WarningUtils::set_warning_handler(this);
  }

  // Non-copyable: a copy would also reinstall `prev_` when destroyed, while
  // the installed handler would still point at only one of the two objects.
  WarningCapture(const WarningCapture&) = delete;
  WarningCapture& operator=(const WarningCapture&) = delete;

  ~WarningCapture() override {
    WarningUtils::set_warning_handler(prev_);
  }

  // All captured warning messages, in the order they were received.
  const std::vector<std::string>& messages() const {
    return messages_;
  }

  // The captured messages joined with "\n".
  std::string str() const {
    return c10::Join("\n", messages_);
  }

  void process(const c10::Warning& warning) override {
    messages_.push_back(warning.msg());
  }

 private:
  WarningHandler* prev_;
  std::vector<std::string> messages_;
};

// Returns true when both tensors report the same data pointer, i.e. their
// data starts at the same memory address.
inline bool pointer_equal(at::Tensor first, at::Tensor second) {
  const void* const lhs = first.data_ptr();
  const void* const rhs = second.data_ptr();
  return lhs == rhs;
}

// This mirrors the `isinstance(x, torch.Tensor) and isinstance(y,
// torch.Tensor)` branch in `TestCase.assertEqual` in
// torch/testing/_internal/common_utils.py
//
// Asserts (with gtest ASSERT_*) that `a` and `b` have the same sizes and that,
// after converting `b` to `a`'s dtype and device, no element differs by more
// than 1e-5. For floating-point tensors NaNs must sit at the same positions.
// When `allow_inf` is true, infinities must sit at the same positions with the
// same sign; when it is false, any infinity makes the comparison fail.
//
// Note: a failing ASSERT_* only returns from this helper, not from the calling
// test body.
inline void assert_tensor_equal(
    at::Tensor a,
    at::Tensor b,
    bool allow_inf = false) {
  ASSERT_TRUE(a.sizes() == b.sizes());
  if (a.numel() == 0) {
    return;
  }

  // CPU half/bfloat16 and CUDA bfloat16 tensors don't have the methods we
  // need below, so compare them in float32.
  const bool is_cpu_reduced_float = a.device().type() == torch::kCPU &&
      (a.scalar_type() == torch::kFloat16 ||
       a.scalar_type() == torch::kBFloat16);
  const bool is_cuda_bfloat16 = a.device().type() == torch::kCUDA &&
      a.scalar_type() == torch::kBFloat16;
  if (is_cpu_reduced_float || is_cuda_bfloat16) {
    a = a.to(torch::kFloat32);
  }
  // From here on `b` has exactly `a`'s dtype and device, so the two tensors
  // are always either both bool or both non-bool.
  b = b.to(a);

  if (a.scalar_type() == torch::kBool) {
    // bool doesn't support subtraction, so compare as int.
    a = a.to(torch::kInt);
    b = b.to(torch::kInt);
  }

  auto diff = a - b;
  if (a.is_floating_point()) {
    // NaNs must be in the same locations; they are then excluded from diff.
    auto nan_mask = torch::isnan(a);
    ASSERT_TRUE(torch::equal(nan_mask, torch::isnan(b)));
    diff.index_put_({nan_mask}, 0);
    if (allow_inf) {
      // Infinities must be in the same locations, and the infinite values
      // themselves must match so that +inf vs -inf is detected (the sign of
      // the boolean mask alone would only compare locations).
      auto inf_mask = torch::isinf(a);
      ASSERT_TRUE(torch::equal(inf_mask, torch::isinf(b)));
      ASSERT_TRUE(torch::equal(
          a.masked_select(inf_mask), b.masked_select(inf_mask)));
      diff.index_put_({inf_mask}, 0);
    }
  }
  // Take |diff| for every signed dtype, int8 included; otherwise negative
  // differences would slip under the tolerance below.
  if (diff.is_signed()) {
    diff = diff.abs();
  }
  auto max_err = diff.max().item<double>();
  ASSERT_LE(max_err, 1e-5);
}

// This mirrors the `isinstance(x, torch.Tensor) and isinstance(y,
// torch.Tensor)` branch in `TestCase.assertNotEqual` in
// torch/testing/_internal/common_utils.py
//
// Asserts that `x` and `y` differ:
//  - tensors of different sizes are considered different;
//  - otherwise `y` is converted to `x`'s dtype and device, and tensors whose
//    NaN positions differ are considered different;
//  - otherwise the largest absolute element-wise difference (ignoring NaN
//    positions) must be at least 1e-5.
inline void assert_tensor_not_equal(at::Tensor x, at::Tensor y) {
  if (x.sizes() != y.sizes()) {
    return;
  }
  ASSERT_GT(x.numel(), 0);
  // Match dtype and device (any device type, not just CPU/CUDA).
  y = y.type_as(x).to(x.device());

  auto nan_mask = x != x;
  if (!torch::equal(nan_mask, y != y)) {
    // NaNs in different places: the tensors already differ.
    return;
  }

  if (x.scalar_type() == torch::kBool) {
    // bool doesn't support subtraction, so compare as int.
    x = x.to(torch::kInt);
    y = y.to(torch::kInt);
  }

  auto diff = x - y;
  if (diff.is_signed()) {
    diff = diff.abs();
  }
  diff.index_put_({nan_mask}, 0);
  // Use `item()` to work around:
  // https://github.com/pytorch/pytorch/issues/22301
  auto max_err = diff.max().item<double>();
  ASSERT_GE(max_err, 1e-5);
}

// Counts non-overlapping occurrences of `substr` in `str`, scanning left to
// right (e.g. "aa" occurs twice in "aaaaa").
//
// An empty `substr` yields 0: it would trivially match at every position, and
// advancing by its (zero) length would never terminate the scan.
inline int count_substr_occurrences(
    const std::string& str,
    const std::string& substr) {
  if (substr.empty()) {
    return 0;
  }
  int count = 0;
  for (size_t pos = str.find(substr); pos != std::string::npos;
       pos = str.find(substr, pos + substr.size())) {
    ++count;
  }
  return count;
}

// An RAII guard that changes the default dtype upon construction and restores
// the previous default dtype upon destruction.
//
// Each guard holds `default_dtype_mutex` for its entire lifetime, so at most
// one guard is active at any time across threads. A guard constructed in
// another thread blocks until the active one is destroyed. Reading the
// previous dtype and restoring it both happen while the mutex is held.
// Nesting two guards on the same thread is not supported (std::mutex is not
// recursive).
struct AutoDefaultDtypeMode {
  static std::mutex default_dtype_mutex;

  AutoDefaultDtypeMode(c10::ScalarType default_dtype)
      : lock_(default_dtype_mutex),
        prev_default_dtype(
            torch::typeMetaToScalarType(torch::get_default_dtype())) {
    torch::set_default_dtype(torch::scalarTypeToTypeMeta(default_dtype));
  }
  // The destructor body runs before `lock_` is destroyed, so the restore
  // happens before the mutex is released.
  ~AutoDefaultDtypeMode() {
    torch::set_default_dtype(torch::scalarTypeToTypeMeta(prev_default_dtype));
  }

  AutoDefaultDtypeMode(const AutoDefaultDtypeMode&) = delete;
  AutoDefaultDtypeMode& operator=(const AutoDefaultDtypeMode&) = delete;

 private:
  // Declared before `prev_default_dtype` so the mutex is acquired before the
  // previous default dtype is read. If the constructor throws, this member is
  // still destroyed, which releases the mutex.
  std::lock_guard<std::mutex> lock_;

 public:
  c10::ScalarType prev_default_dtype;
};

// Asserts that `x` carries differentiable-view autograd metadata with a
// backward view, and that its recorded creation meta equals `creation_meta`.
// Throws via TORCH_CHECK in any of these cases:
//  - `x` has no autograd meta;
//  - the meta is not a DifferentiableViewMeta;
//  - it has no backward view.
inline void assert_tensor_creation_meta(
    torch::Tensor& x,
    torch::autograd::CreationMeta creation_meta) {
  auto* autograd_meta = x.unsafeGetTensorImpl()->autograd_meta();
  TORCH_CHECK(autograd_meta, "Expected tensor to have autograd meta");
  // The autograd meta is not necessarily a DifferentiableViewMeta (presumably
  // only views carry one). A checked cast turns a non-view into a clear
  // failure instead of undefined behavior from an unchecked downcast.
  auto* view_meta =
      dynamic_cast<torch::autograd::DifferentiableViewMeta*>(autograd_meta);
  TORCH_CHECK(view_meta, "Expected tensor to be a differentiable view");
  TORCH_CHECK(
      view_meta->has_bw_view(), "Expected tensor to have a backward view");
  ASSERT_EQ(view_meta->get_creation_meta(), creation_meta);
}
} // namespace test
} // namespace torch
