#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/MaxPooling.h>
#include <c10/util/irange.h>

namespace at::native {

namespace {

template <typename scalar_t>
inline void max_pool1d_kernel(
    scalar_t* C10_RESTRICT op,
    const scalar_t* C10_RESTRICT ip,
    const PoolingParams1D& p) {
  for (const auto kj : c10::irange(p.KW)) {
    int64_t oj = p.valid_output_start(kj);
    int64_t oe = p.valid_output_end(kj);
    int64_t ij = p.index(kj, oj);
    for (; oj < oe; ++oj, ij += p.SJ) {
      scalar_t val = ip[ij];
      bool update_max = std::isnan(val) || op[oj] < val;
      op[oj] = update_max ? val : op[oj];
    }
  }
}

// CPU implementation bound to max_pool1d_stub.
//
// Computes 1-D max pooling of `input` into `output` for floating types plus
// BFloat16 and Half. The data is treated as p.NB * p.NC independent rows.
// Each row holds p.IW input elements and produces p.OW output elements.
// NOTE(review): NB/NC presumably mean batch and channels; confirm in
// MaxPooling.h.
//
// Parameters:
//   output - destination tensor, written through a flat pointer with a row
//            stride of p.OW. It must already be allocated with the right
//            size and dtype. NOTE(review): assumes the caller allocated it
//            contiguous, since this function does not check.
//   input  - source tensor. It is made contiguous here because rows are
//            addressed with a flat stride of p.IW.
//   p      - pooling geometry (kernel width, stride, sizes) consumed by
//            max_pool1d_kernel.
void max_pool1d_impl(
    Tensor& output,
    const Tensor& input,
    const PoolingParams1D& p) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16,
      ScalarType::Half,
      input.scalar_type(),
      "max_pool1d_impl",
      [&] {
        const Tensor in = input.contiguous();
        // Use the explicit mutable/const accessors so that write access to
        // `output` and read-only access to `in` are both spelled out.
        scalar_t* const OP = output.mutable_data_ptr<scalar_t>();
        const scalar_t* const IP = in.const_data_ptr<scalar_t>();

        // Starting value for every output slot: -inf when the type has an
        // infinity, otherwise the most negative finite value. Any real input
        // (or NaN) seen by the kernel will replace it.
        const scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
            ? -std::numeric_limits<scalar_t>::infinity()
            : std::numeric_limits<scalar_t>::lowest();

        // Rows are independent, so they are split across threads.
        at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
          for (const auto it : c10::irange(begin, end)) {
            scalar_t* op = OP + it * p.OW;
            const scalar_t* ip = IP + it * p.IW;
            std::fill_n(op, p.OW, FILL);
            max_pool1d_kernel(op, ip, p);
          }
        });
      });
}

} // namespace

// Register max_pool1d_impl (defined above) as the implementation behind
// max_pool1d_stub.
REGISTER_DISPATCH(max_pool1d_stub, &max_pool1d_impl);

} // namespace at::native
