#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/Resize.h>
#include <c10/util/SmallBuffer.h>
#include <ATen/TensorSubclassLikeUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cross_entropy_loss_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/log_softmax.h>
#include <ATen/ops/nll_loss.h>
#include <ATen/ops/nll_loss2d.h>
#include <ATen/ops/nll_loss_backward_native.h>
#include <ATen/ops/nll_loss_forward.h>
#include <ATen/ops/nll_loss_forward_native.h>
#include <ATen/ops/nll_loss_native.h>
#include <ATen/ops/nll_loss_nd.h>
#include <ATen/ops/nll_loss_nd_native.h>
#endif

#include <c10/core/TensorOptions.h>
#include <c10/util/irange.h>

#include <utility>

namespace at::meta {
// Meta function for nll_loss_forward: validates the shapes of `self`,
// `target` and the optional `weight`, then declares the two outputs
// (output 0 is the loss, output 1 is a scalar).
TORCH_META_FUNC(nll_loss_forward)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index) {
  const auto input_rank = self.dim();

  TORCH_CHECK(
      input_rank > 0 && input_rank <= 2, "input tensor should be 1D or 2D");
  TORCH_CHECK(
      target.dim() <= 1,
      "0D or 1D target tensor expected, multi-target not supported");

  // A 1D input paired with a 0D target carries no batch dimension; in every
  // other case the leading sizes of input and target have to agree.
  const bool unbatched = input_rank == 1 && target.dim() == 0;
  TORCH_CHECK(
      unbatched || (self.size(0) == target.size(0)),
      "size mismatch (got input: ",
      self.sizes(),
      ", target: ",
      target.sizes(),
      ")");

  // The class dimension is always the last one.
  const auto n_classes = self.size(-1);
  const Tensor& weight = weight_opt.getTensorRef();

  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
      "weight tensor should be defined either for all ",
      n_classes,
      " classes or no classes"
      " but got weight tensor of shape: ",
      weight.sizes());

  const bool per_sample_output =
      reduction == Reduction::None && input_rank == 2;
  if (per_sample_output) {
    // no reduction on a batched input: one loss value per batch row
    set_output_raw_strided(0, {self.size(0)}, {}, self.options());
  } else {
    // reduced, or 1D input: the loss is a scalar
    set_output_raw_strided(0, {}, {}, self.options());
  }

  // the second output is always a scalar
  set_output_raw_strided(1, {}, {}, self.options());
}

// Meta function for nll_loss_backward: validates the shapes of the inputs
// and declares grad_input with the same sizes as `self`.
TORCH_META_FUNC(nll_loss_backward)
(const Tensor& grad_output,
 const Tensor& self,
 const Tensor& target,
 OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& total_weight) {
  const auto input_rank = self.dim();

  TORCH_CHECK(
      input_rank > 0 && input_rank <= 2, "input tensor should be 1D or 2D");
  TORCH_CHECK(
      target.dim() <= 1,
      "0D or 1D target tensor expected, multi-target not supported");

  // A 1D input paired with a 0D target carries no batch dimension; in every
  // other case the leading sizes of input and target have to agree.
  const bool unbatched = input_rank == 1 && target.dim() == 0;
  TORCH_CHECK(
      unbatched || (self.size(0) == target.size(0)),
      "size mismatch (got input: ",
      self.sizes(),
      ", target: ",
      target.sizes(),
      ")");
  TORCH_CHECK(
      total_weight.numel() == 1,
      "expected total_weight to be a  single element tensor, got: ",
      total_weight.sizes(),
      " (",
      total_weight.numel(),
      " elements)");

  // When a weight is given it must hold one entry per class (last dim of self).
  const auto& weight = weight_opt.getTensorRef();
  TORCH_CHECK(
      !weight.defined() || weight.numel() == self.size(-1),
      "weight tensor should be defined either for all or no classes");

  const bool per_sample_grad =
      reduction == Reduction::None && input_rank == 2;
  if (!per_sample_grad) {
    // reduced loss (or 1D input): the incoming gradient is a single value
    TORCH_CHECK(
        grad_output.dim() <= 1 && grad_output.numel() == 1,
        "Expected a single element grad_output tensor, but got: ",
        grad_output.sizes());
  } else {
    // unreduced loss: grad_output must be 1D with one entry per batch row
    check_dim_size(grad_output, 1, 0, self.size(0));
  }

  set_output_raw_strided(
      0,
      self.sizes(),
      {},
      self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT));
}
} // namespace at::meta

namespace at::native {

namespace {

// Passes an undefined tensor through untouched; a defined tensor is
// replaced by the result of contiguous().
inline Tensor optional_contiguous(const Tensor& source) {
  if (!source.defined()) {
    return source;
  }
  return source.contiguous();
}

// Yields the tensor's data pointer typed as scalar_t*, or nullptr when the
// tensor is undefined. A const-qualified scalar_t selects the const
// accessor, any other scalar_t the mutable one.
template <typename scalar_t>
inline scalar_t* optional_data(const Tensor& source) {
  if (!source.defined()) {
    return nullptr;
  }
  if constexpr (std::is_const<scalar_t>::value) {
    return source.const_data_ptr<scalar_t>();
  } else {
    return source.data_ptr<scalar_t>();
  }
}

template <typename scalar_t, typename target_t>
static void nll_loss_out_frame(
    const Tensor& output,
    const Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  const auto n_dims = input.dim();
  const auto n_classes = input.size(-1);

  scalar_t* total_weight_data = total_weight.data_ptr<scalar_t>();
  *total_weight_data = 0;

  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  if (reduction == Reduction::None && n_dims == 2) {
    const auto batch_size = input.size(0);
    at::native::resize_output(output, {batch_size});

    auto input_acc = input.accessor<const scalar_t, 2>();
    auto target_acc = target.accessor<const target_t, 1>();
    auto output_acc = output.accessor<scalar_t, 1>();

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        const auto cur_target = target_acc[i];

        if (cur_target == ignore_index) {
          output_acc[i] = 0;
          continue;
        }

        TORCH_CHECK_INDEX(
            cur_target >= 0 && cur_target < n_classes,
            "Target ",
            cur_target,
            " is out of bounds.");

        scalar_t cur_weight = weight_data != nullptr ? weight_data[cur_target]
                                                     : static_cast<scalar_t>(1);
        output_acc[i] = -input_acc[i][cur_target] * cur_weight;
      }
    });

    return;
  }

  // produce scalar outputs for the reduction case
  at::native::resize_output(output, {});

  if (target.numel() == 0) {
    // Here target (and input) have zero elements
    // Mean reduction on empty tensors produces NaN. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if (reduction == Reduction::Mean) {
      output.fill_(std::numeric_limits<double>::quiet_NaN());
    } else {
      output.zero_();
    }
    total_weight.zero_();
    return;
  }

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();

  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const target_t* target_data = target_contiguous.const_data_ptr<target_t>();

  const int64_t ndim = input.dim();
  const int64_t batch_size = ndim == 1 ? 1 : input.size(0);

  constexpr int64_t cascade_sum_num_levels = 8;
  const int64_t level_power =
      std::max(int64_t(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels);
  const int64_t level_step = (1 << level_power);
  const int64_t level_mask = level_step - 1;

  int64_t num_ignored = 0;

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
  for (const auto b : c10::irange(batch_size)) {
    const int64_t cur_target = target_data[b];
    if (cur_target == ignore_index) {
      ++num_ignored;
      continue;
    }

    TORCH_CHECK_INDEX(
        cur_target >= 0 && cur_target < n_classes,
        "Target ",
        cur_target,
        " is out of bounds.");

    const auto data = input_data[b * n_classes + cur_target];
    if (weight_data) {
      const scalar_t weight_val = weight_data[cur_target];
      loss_partial_sums[0] -= data * weight_val;
      weight_partial_sums[0] += weight_val;
    } else {
      loss_partial_sums[0] -= data;
    }

    for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
      const auto mask = (level_mask << (j * level_power));
      if (C10_LIKELY((b & mask) != 0)) {
        break;
      }

      weight_partial_sums[j + 1] += weight_partial_sums[j];
      loss_partial_sums[j + 1] += loss_partial_sums[j];

      weight_partial_sums[j] = 0;
      loss_partial_sums[j] = 0;
    }
  }

  const scalar_t total_weight_val = !weight_data ?
    static_cast<scalar_t>(batch_size - num_ignored) :
    std::accumulate(std::begin(weight_partial_sums),
                    std::end(weight_partial_sums),
                    scalar_t{0});

  scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
                                        std::end(loss_partial_sums),
                                        scalar_t{0});

  if (reduction == Reduction::Mean) {
    output_val /= total_weight_val;
  }

  // write result to output tensors
  *output.data_ptr<scalar_t>() = output_val;
  *total_weight_data = total_weight_val;
}

// Dispatches the forward computation on the floating-point dtype of `input`
// (including BFloat16 and Half) and on the integer dtype of `target`.
void nll_loss_forward_out_cpu_template(
    const Tensor& output,
    const Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  const bool byte_target = target.scalar_type() == kByte;

  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16,
      ScalarType::Half,
      input.scalar_type(),
      "nll_loss_out_frame",
      [&] {
        if (!byte_target) {
          // every non-byte target is handled as int64
          nll_loss_out_frame<scalar_t, int64_t>(
              output,
              total_weight,
              input,
              target,
              weight,
              reduction,
              ignore_index);
        } else {
          nll_loss_out_frame<scalar_t, uint8_t>(
              output,
              total_weight,
              input,
              target,
              weight,
              reduction,
              ignore_index);
        }
      });
}

// Computes the gradient of the NLL loss with respect to `input` on CPU.
//
//   grad_input    destination, same rank as `input`. Only the entry at the
//                 target class of each non-ignored sample is written; the
//                 caller in this file zero-fills it beforehand.
//   grad_output   per-row gradients when reduction == Reduction::None and
//                 input is 2D, otherwise a single value.
//   target        class indices of type target_t (0D or 1D); entries equal
//                 to ignore_index are skipped.
//   weight        optional per-class weights (may be undefined).
//   total_weight  single-element tensor; divides the gradient for
//                 Reduction::Mean.
//
// Any non-ignored target outside [0, n_classes) fails TORCH_CHECK_INDEX in
// every path, so an invalid target can never index outside `weight` or
// `grad_input`.
template <typename scalar_t, typename target_t>
static void nll_loss_backward_out_frame(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  const auto n_dims = input.dim();
  const auto n_classes = input.size(-1);

  // View a 0D target as a one-element 1D tensor so a single accessor type
  // serves both the batched and the unbatched case.
  auto target_ = target;
  if (target.dim() == 0) {
    target_ = target.unsqueeze(0);
  }
  auto target_acc = target_.accessor<const target_t, 1>();

  // weight_data stays nullptr when no weight tensor was supplied.
  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  if (reduction == Reduction::None && n_dims == 2) {
    // Unreduced, batched case: each row has its own upstream gradient.
    const auto batch_size = input.size(0);
    auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
    auto grad_output_acc = grad_output.accessor<const scalar_t, 1>();
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        // Widened to int64_t so a uint8_t target prints as a number.
        const int64_t cur_target = target_acc[i];
        if (cur_target == ignore_index) {
          continue;
        }
        // Validate before indexing: without this check an out-of-range
        // target would read past `weight` and write past `grad_input`.
        TORCH_CHECK_INDEX(
            cur_target >= 0 && cur_target < n_classes,
            "Target ",
            cur_target,
            " is out of bounds.");
        const scalar_t w =
            weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
        grad_input_acc[i][cur_target] = -w * grad_output_acc[i];
      }
    });
    return;
  }

  // Reduced loss (or 1D input): a single upstream gradient value.
  const scalar_t total_weight_value = *total_weight.const_data_ptr<scalar_t>();

  const scalar_t grad_output_value = *grad_output.const_data_ptr<scalar_t>();

  if (n_dims == 1) {
    auto grad_input_acc = grad_input.accessor<scalar_t, 1>();

    const int64_t t = target_acc[0];
    if (t != ignore_index) {
      TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
      const auto grad = -(reduction == Reduction::Mean ? grad_output_value / total_weight_value
                                                       : grad_output_value);
      grad_input_acc[t] = weight_data != nullptr ? weight_data[t] * grad
                                                 : grad;
    }
  } else if (n_dims == 2) {
    auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
    // The same scaled gradient applies to every row; only the weight varies.
    const auto grad = -(reduction == Reduction::Mean ? grad_output_value / total_weight_value
                                                     : grad_output_value);

    const auto batch_size = input.size(0);

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        const int64_t t = target_acc[i];
        if (t != ignore_index) {
          TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
          grad_input_acc[i][t] = weight_data != nullptr ? weight_data[t] * grad
                                                        : grad;
        }
      }
    });
  }
}

// Zero-fills `grad_input`, then dispatches the backward computation on the
// floating-point dtype of `input` (including BFloat16 and Half) and on the
// integer dtype of `target`.
void nll_loss_backward_out_cpu_template(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  grad_input.zero_();

  const bool byte_target = target.scalar_type() == kByte;

  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16,
      ScalarType::Half,
      input.scalar_type(),
      "nll_loss_backward_out_frame",
      [&] {
        if (!byte_target) {
          // every non-byte target is handled as int64
          nll_loss_backward_out_frame<scalar_t, int64_t>(
              grad_input,
              grad_output,
              input,
              target,
              weight,
              reduction,
              ignore_index,
              total_weight);
        } else {
          nll_loss_backward_out_frame<scalar_t, uint8_t>(
              grad_input,
              grad_output,
              input,
              target,
              weight,
              reduction,
              ignore_index,
              total_weight);
        }
      });
}

} // namespace

// CPU implementation of nll_loss_forward: unwraps the optional weight and
// hands everything to the dtype-dispatching template.
TORCH_IMPL_FUNC(nll_loss_forward_out_cpu)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& output,
 const Tensor& total_weight) {
  nll_loss_forward_out_cpu_template(
      output,
      total_weight,
      self,
      target,
      weight_opt.getTensorRef(),
      reduction,
      ignore_index);
}

// CPU implementation of nll_loss_backward: unwraps the optional weight and
// hands everything to the dtype-dispatching template.
TORCH_IMPL_FUNC(nll_loss_backward_out_cpu)
(const Tensor& grad_output,
 const Tensor& self,
 const Tensor& target,
 OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& total_weight,
 const Tensor& grad_input
) {
  nll_loss_backward_out_cpu_template(
      grad_input,
      grad_output,
      self,
      target,
      weight_opt.getTensorRef(),
      reduction,
      ignore_index,
      total_weight);
}

// Cross entropy for targets given as per-class values with the same shape as
// the input. The loss is -(log_softmax(self) * target [* weight]), reduced
// according to `reduction`. A positive `label_smoothing` blends the target
// with the constant label_smoothing / n_classes before the product.
static Tensor cross_entropy_loss_prob_target(
    const Tensor& self,
    const Tensor& target_,
    const Tensor& weight,
    int64_t reduction,
    double label_smoothing) {
  // The class dimension is 0 for a 1D input and 1 otherwise.
  const auto class_dim = self.dim() == 1 ? 0 : 1;
  const auto n_classes = self.size(class_dim);
  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
      "cross_entropy: weight tensor should be defined either for all ",
      n_classes,
      " classes or no classes"
      " but got weight tensor of shape: ",
      weight.sizes());

  auto input = at::log_softmax(self, class_dim, self.scalar_type());

  Tensor target = target_;
  if (label_smoothing > 0.0) {
    TORCH_CHECK(label_smoothing <= 1.0, "label_smoothing must be between 0.0 and 1.0. Got: ", label_smoothing);
    target = target_ * (1 - label_smoothing) + label_smoothing / n_classes;
  }

  // When a weight is given, reshape it to [1, C, 1, ...] so that it
  // broadcasts against input / target along the class dimension.
  Tensor weight_;
  if (weight.defined()) {
    weight_ = weight;
    if (input.dim() > 1) {
      auto broadcast_dims = SmallBuffer<int64_t, 5>(input.dim());
      std::fill(broadcast_dims.begin(), broadcast_dims.end(), 1);
      broadcast_dims[1] = weight.size(0);
      weight_ = weight.view(broadcast_dims);
    }
  }

  // Element-wise product that gets reduced below. Kept lazy so that nothing
  // is computed when the reduction value turns out to be invalid.
  const auto elementwise = [&]() -> Tensor {
    if (weight.defined()) {
      return input * target * weight_;
    }
    return input * target;
  };

  switch (reduction) {
    case Reduction::Mean:
      if (input.sym_numel() == 0) {
        // the mean over an empty input is NaN
        return -elementwise().sum().fill_(std::numeric_limits<double>::quiet_NaN());
      }
      // divide by the number of elements per class
      return -elementwise().sum() / (input.sym_numel() / n_classes);
    case Reduction::Sum:
      return -elementwise().sum();
    case Reduction::None:
      return -elementwise().sum(class_dim);
    default:
      TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
  }
}

// Cross entropy for class-index targets with label smoothing. The result
// blends the regular NLL loss with a "smooth" loss, i.e. the negated
// (optionally class-weighted) sum of the log-probabilities over the class
// dimension:
//   (1 - label_smoothing) * nll + smooth * (label_smoothing / n_classes)
// Positions whose target equals `ignore_index` contribute zero smooth loss.
static Tensor cross_entropy_loss_label_smoothing(
    const Tensor& self,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    c10::SymInt ignore_index,
    double label_smoothing) {
  // The class dimension is 0 for a 1D input and 1 otherwise.
  auto class_dim = self.dim() == 1 ? 0 : 1;
  auto input = at::log_softmax(self, class_dim, self.scalar_type());
  auto nllloss = at::nll_loss_nd_symint(input, target, weight, reduction, ignore_index);

  auto n_classes = input.sym_size(class_dim);

  Tensor smooth_loss;
  if (!weight.defined()) {
    smooth_loss = -input.sum(class_dim);
  } else {
    // Reshape the weight so that it broadcasts against input along the
    // class dimension (all other dims are 1).
    auto broadcast_dims = SmallBuffer<int64_t, 5>(input.dim());
    std::fill(broadcast_dims.begin(), broadcast_dims.end(), 1);
    broadcast_dims[class_dim] = weight.size(0);

    smooth_loss = -(input * weight.view(broadcast_dims)).sum(class_dim);
  }

  auto ignore_mask = target == std::move(ignore_index);
  smooth_loss.masked_fill_(ignore_mask, 0.0);

  Tensor ret;
  switch (reduction) {
    case Reduction::Mean:
      if (!weight.defined()) {
        // unweighted: average over the positions that are not ignored
        auto kept_mask = ~ignore_mask;
        ret = smooth_loss.sum() / kept_mask.sum();
      } else if (isTensorSubclassLike(weight)) {
        // Gather weights using index 0 (always valid) in place of ignored
        // targets, then zero those gathered weights out again.
        auto safe_target = target.masked_fill(ignore_mask, 0);
        auto gathered = weight.gather(0, safe_target.flatten());
        auto weight_total =
            gathered.masked_fill_(ignore_mask.flatten(), 0).sum();
        ret = smooth_loss.sum() / weight_total;
      } else {
        // TODO: This code path can be removed if #61309 is resolved.
        // The loss is normalized by the weights of the non-ignored targets
        // to be consistent with nll_loss_nd.
        ret = smooth_loss.sum() /
            weight.gather(0, target.masked_select(~ignore_mask).flatten())
                .sum();
      }
      break;
    case Reduction::Sum:
      ret = smooth_loss.sum();
      break;
    case Reduction::None:
      ret = smooth_loss;
      break;
    default:
      TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
  }
  return (1 - label_smoothing) * nllloss + ret * (label_smoothing / n_classes);
}

// Entry point for cross entropy. Picks one of three strategies:
//   1. input and target have identical shapes -> target holds per-class
//      values (must be floating point; ignore_index must be negative);
//   2. label_smoothing > 0                    -> class-index target with
//      label smoothing;
//   3. otherwise                              -> nll_loss_nd applied to
//      log_softmax(self).
Tensor cross_entropy_loss_symint(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    int64_t reduction,
    c10::SymInt ignore_index,
    double label_smoothing) {
  if (self.sym_sizes() == target.sym_sizes()) {
    // Assume soft targets when input and target shapes are the same
    TORCH_CHECK(at::isFloatingType(target.scalar_type()),
        "Expected floating point type for target with class probabilities, got ", target.scalar_type());
    TORCH_CHECK(ignore_index < 0, "ignore_index is not supported for floating point target");

    // See [Note: hacky wrapper removal for optional tensor]
    c10::MaybeOwned<Tensor> borrowed_weight = at::borrow_from_optional_tensor(weight);
    return cross_entropy_loss_prob_target(
        self, target, *borrowed_weight, reduction, label_smoothing);
  }

  if (label_smoothing > 0.0) {
    TORCH_CHECK(label_smoothing <= 1.0, "label_smoothing must be between 0.0 and 1.0. Got: ", label_smoothing);

    // See [Note: hacky wrapper removal for optional tensor]
    c10::MaybeOwned<Tensor> borrowed_weight = at::borrow_from_optional_tensor(weight);
    return cross_entropy_loss_label_smoothing(
        self, target, *borrowed_weight, reduction, std::move(ignore_index), label_smoothing);
  }

  // Plain class-index targets: the class dimension is 0 for a 1D input and
  // 1 otherwise.
  auto class_dim = self.dim() == 1 ? 0 : 1;
  return at::nll_loss_nd_symint(
      at::log_softmax(self, class_dim, self.scalar_type()),
      target,
      weight,
      reduction,
      std::move(ignore_index));
}

// Out-variant of nll_loss: runs nll_loss_forward_out with a throw-away
// total_weight tensor and returns the first element of its result.
Tensor & nll_loss_out(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, Tensor & output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> borrowed_weight = at::borrow_from_optional_tensor(weight_opt);

  // Scratch tensor for the second output of the forward op; not returned.
  Tensor scratch_total_weight = at::empty({0}, self.options());
  auto forward_outputs = at::nll_loss_forward_out(
      output,
      scratch_total_weight,
      self,
      target,
      *borrowed_weight,
      reduction,
      ignore_index);
  return std::get<0>(forward_outputs);
}

// Functional nll_loss: runs nll_loss_forward and returns only the first
// element of its result, dropping the rest.
Tensor nll_loss_symint(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, c10::SymInt ignore_index) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> borrowed_weight = at::borrow_from_optional_tensor(weight_opt);

  auto forward_outputs = at::nll_loss_forward_symint(
      self, target, *borrowed_weight, reduction, std::move(ignore_index));
  return std::get<0>(std::move(forward_outputs));
}

// N-dimensional NLL loss. Routes to the kernel matching the input rank:
//   - 1D / 2D input -> nll_loss
//   - 4D input      -> nll_loss2d
//   - 3D or >4D     -> input is viewed as (N, C, 1, -1) and target as
//                      (N, 1, -1), nll_loss2d is applied, and for
//                      Reduction::None the result is viewed back to
//                      (N, d1, d2, ...).
//
// Fails TORCH_CHECK_VALUE when the input has no dimensions, when a batched
// input (rank != 1) is paired with a 0D target, or when the batch sizes of
// input and target differ. Fails TORCH_CHECK when, in the 3D / >4D case,
// the trailing target sizes do not match the trailing input sizes.
Tensor nll_loss_nd_symint(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    int64_t reduction,
    c10::SymInt ignore_index) {
  TORCH_CHECK_VALUE(
      self.dim() >= 1, "Expected 1 or more dimensions (got ", self.dim(), ")");

  if (self.dim() != 1) {
    // A batched input needs a target that carries a batch dimension. Check
    // this before touching target.sym_sizes()[0] / .slice(1): indexing the
    // empty size array of a 0D target is an out-of-bounds read.
    TORCH_CHECK_VALUE(
        target.dim() >= 1,
        "Expected target with 1 or more dimensions for a ",
        self.dim(),
        "D input (got a 0D target)");

    if (self.sym_sizes()[0] != target.sym_sizes()[0]) {
      TORCH_CHECK_VALUE(
          false,
          "Expected input batch_size (",
          self.sym_sizes()[0],
          ") to match target batch_size (",
          target.sym_sizes()[0],
          ").");
    }
  }

  Tensor ret;
  Tensor input_ = self;
  Tensor target_ = target;
  if (input_.dim() == 1 || input_.dim() == 2) {
    ret = at::nll_loss_symint(input_, target_, weight, reduction, std::move(ignore_index));
  } else if (input_.dim() == 4) {
    ret = at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
  } else {
    // dim == 3 or dim > 4
    auto n = input_.sym_sizes()[0];
    auto c = input_.sym_sizes()[1];
    // Expected target / unreduced output shape: (N, d1, d2, ...).
    auto out_size = input_.sym_sizes().slice(2).vec();
    out_size.insert(out_size.begin(), n);
    if (target_.sym_sizes().slice(1) != input_.sym_sizes().slice(2)) {
      TORCH_CHECK(
          false,
          "Expected target size ",
          SymIntArrayRef(out_size),
          ", got ",
          target_.sym_sizes());
    }
    // view_symint below requires a layout compatible with the new shape.
    input_ = input_.contiguous();
    target_ = target_.contiguous();
    // support empty batches, see #15870
    if (input_.sym_numel() > 0) {
      input_ = input_.view_symint({n, std::move(c), 1, -1});
    } else {
      input_ = input_.view_symint({n, std::move(c), 0, 0});
    }
    if (target_.sym_numel() > 0) {
      target_ = target_.view_symint({std::move(n), 1, -1});
    } else {
      target_ = target_.view_symint({std::move(n), 0, 0});
    }
    if (reduction != Reduction::None) {
      ret = at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
    } else {
      // Unreduced: restore the original spatial layout of the per-element
      // losses.
      auto out =
          at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
      ret = out.view_symint(out_size);
    }
  }
  return ret;
}

} // namespace at::native
