#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/IndexingUtils.h>

namespace at::native {

bool canUse32BitIndexMath(const TensorBase& t, int64_t max_elem) {
  auto elements = t.sym_numel();
  if (elements >= max_elem) {
    return false;
  }
  if (elements == 0) {
    return max_elem > 0;
  }

  c10::SymInt offset = 0;
  auto linearId = elements - 1;

  // NOTE: Assumes all strides are positive, which is true for now
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  for (int i = t.dim() - 1; i >= 0; --i) {
    auto curDimIndex = linearId % t.sym_size(i);
    auto curDimOffset = curDimIndex * t.sym_stride(i);
    offset += curDimOffset;
    linearId /= t.sym_size(i);
  }

  if (offset >= max_elem) {
    return false;
  }

  return true;
}

} // namespace at::native
