// Copyright 2004-present Facebook. All Rights Reserved.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/native/UpSample.h>
#include <c10/util/irange.h>
#include <c10/util/TypeCast.h>

namespace at::native::upsample {

TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
    c10::IntArrayRef input_size,  // Full input tensor size.
    at::OptionalIntArrayRef output_size,
    std::optional<c10::ArrayRef<double>> scale_factors) {
  const auto spatial_dimensions = static_cast<int64_t>(input_size.size()) - 2;
  if (output_size) {
    TORCH_CHECK(!scale_factors, "Must specify exactly one of output_size and scale_factors");
    TORCH_CHECK(static_cast<int64_t>(output_size->size()) == spatial_dimensions);
    return {output_size->data(), output_size->data() + output_size->size()};
  }
  if (scale_factors) {
    TORCH_CHECK(!output_size, "Must specify exactly one of output_size and scale_factors");
    TORCH_CHECK(static_cast<int64_t>(scale_factors->size()) == spatial_dimensions);
    c10::SmallVector<int64_t, 3> ret;
    for (const auto i : c10::irange(spatial_dimensions)) {
      const double odim = static_cast<double>(input_size[i+2]) * scale_factors.value()[i];
      ret.push_back(c10::checked_convert<int64_t>(odim, "int64_t"));
    }
    return ret;
  }
  TORCH_CHECK(false, "Must specify exactly one of output_size and scale_factors");
}

} // namespace at::native::upsample
