// clang-format off
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/irange.h>
// clang-format on
#include <gtest/gtest.h>

namespace {
float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  uint32_t bytes;
  bytes = 0;
  bytes |= sign;
  bytes <<= 8;
  bytes |= exponent;
  bytes <<= 23;
  bytes |= fraction;

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float res;
  std::memcpy(&res, &bytes, sizeof(res));
  return res;
}

TEST(BFloat16Conversion, FloatToBFloat16AndBack) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float in[100];
  for (const auto i : c10::irange(100)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
    in[i] = i + 1.25;
  }

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  c10::BFloat16 bfloats[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float out[100];

  for (const auto i : c10::irange(100)) {
    bfloats[i].x = c10::detail::bits_from_f32(in[i]);
    out[i] = c10::detail::f32_from_bits(bfloats[i].x);

    // The relative error should be less than 1/(2^7) since BFloat16
    // has 7 bits mantissa.
    EXPECT_LE(std::fabs(out[i] - in[i]) / in[i], 1.0 / 128);
  }
}

TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float in[100];
  for (const auto i : c10::irange(100)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
    in[i] = i + 1.25;
  }

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  c10::BFloat16 bfloats[100];
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
  float out[100];

  for (const auto i : c10::irange(100)) {
    bfloats[i].x = c10::detail::round_to_nearest_even(in[i]);
    out[i] = c10::detail::f32_from_bits(bfloats[i].x);

    // The relative error should be less than 1/(2^7) since BFloat16
    // has 7 bits mantissa.
    EXPECT_LE(std::fabs(out[i] - in[i]) / in[i], 1.0 / 128);
  }
}

TEST(BFloat16Conversion, NaN) {
  float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF);
  EXPECT_TRUE(std::isnan(inNaN));

  c10::BFloat16 a = c10::BFloat16(inNaN);
  float out = c10::detail::f32_from_bits(a.x);

  EXPECT_TRUE(std::isnan(out));
}

TEST(BFloat16Conversion, Inf) {
  float inInf = float_from_bytes(0, 0xFF, 0);
  EXPECT_TRUE(std::isinf(inInf));

  c10::BFloat16 a = c10::BFloat16(inInf);
  float out = c10::detail::f32_from_bits(a.x);

  EXPECT_TRUE(std::isinf(out));
}

TEST(BFloat16Conversion, SmallestDenormal) {
  float in = std::numeric_limits<float>::denorm_min(); // The smallest non-zero
                                                       // subnormal number
  c10::BFloat16 a = c10::BFloat16(in);
  float out = c10::detail::f32_from_bits(a.x);

  EXPECT_FLOAT_EQ(in, out);
}

TEST(BFloat16Math, Addition) {
  // This test verifies that if only first 7 bits of float's mantissa are
  // changed after addition, we should have no loss in precision.

  // input bits
  // S | Exponent | Mantissa
  // 0 | 10000000 | 10010000000000000000000 = 3.125
  float input = float_from_bytes(0, 0, 0x40480000);

  // expected bits
  // S | Exponent | Mantissa
  // 0 | 10000001 | 10010000000000000000000 = 6.25
  float expected = float_from_bytes(0, 0, 0x40c80000);

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  c10::BFloat16 b;
  b.x = c10::detail::bits_from_f32(input);
  b = b + b;

  float res = c10::detail::f32_from_bits(b.x);
  EXPECT_EQ(res, expected);
}

TEST(BFloat16Math, Subtraction) {
  // This test verifies that if only first 7 bits of float's mantissa are
  // changed after subtraction, we should have no loss in precision.

  // input bits
  // S | Exponent | Mantissa
  // 0 | 10000001 | 11101000000000000000000 = 7.625
  float input = float_from_bytes(0, 0, 0x40f40000);

  // expected bits
  // S | Exponent | Mantissa
  // 0 | 10000000 | 01010000000000000000000 = 2.625
  float expected = float_from_bytes(0, 0, 0x40280000);

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  c10::BFloat16 b;
  b.x = c10::detail::bits_from_f32(input);
  b = b - 5;

  float res = c10::detail::f32_from_bits(b.x);
  EXPECT_EQ(res, expected);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Math, NextAfterZero) {
  const c10::BFloat16 zero{0};

  auto check_nextafter =
      [](c10::BFloat16 from, c10::BFloat16 to, c10::BFloat16 expected) {
        c10::BFloat16 actual = std::nextafter(from, to);
        // Check for bitwise equality!
        ASSERT_EQ(actual.x ^ expected.x, uint16_t{0});
      };
  check_nextafter(zero, zero, /*expected=*/zero);
  check_nextafter(zero, -zero, /*expected=*/-zero);
  check_nextafter(-zero, zero, /*expected=*/zero);
  check_nextafter(-zero, -zero, /*expected=*/-zero);
}

float BinaryToFloat(uint32_t bytes) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float res;
  std::memcpy(&res, &bytes, sizeof(res));
  return res;
}

struct BFloat16TestParam {
  uint32_t input;
  uint16_t rne;
};

class BFloat16Test : public ::testing::Test,
                     public ::testing::WithParamInterface<BFloat16TestParam> {};

TEST_P(BFloat16Test, BFloat16RNETest) {
  float value = BinaryToFloat(GetParam().input);
  uint16_t rounded = c10::detail::round_to_nearest_even(value);
  EXPECT_EQ(GetParam().rne, rounded);
}

INSTANTIATE_TEST_SUITE_P(
    BFloat16TestInstantiation,
    BFloat16Test,
    ::testing::Values(
        BFloat16TestParam{0x3F848000, 0x3F84},
        BFloat16TestParam{0x3F848010, 0x3F85},
        BFloat16TestParam{0x3F850000, 0x3F85},
        BFloat16TestParam{0x3F858000, 0x3F86},
        BFloat16TestParam{0x3FFF8000, 0x4000}));

} // namespace
