// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/IndexKernel.h>
#include <ATen/native/IndexingUtils.h>
#include <torch/library.h>


namespace at::functorch {

namespace {
static bool any_has_value(ArrayRef<std::optional<int64_t>> bdims) {
  for (const auto& bdim : bdims) {
    if (bdim.has_value()) {
      return true;
    }
  }
  return false;
}

static int64_t get_num_leading_nones(ArrayRef<std::optional<Tensor>> indices) {
  int64_t result = 0;
  for (const auto& idx : indices) {
    if (!idx.has_value() || !idx->defined()) {
      result++;
    } else {
      return result;
    }
  }
  return result;
}

static int64_t get_max_index_logical_dim(
    ArrayRef<std::optional<Tensor>> indices,
    ArrayRef<std::optional<int64_t>> indices_bdims) {
  int64_t max_logical_dim = -1;
  TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
  TORCH_INTERNAL_ASSERT(!indices.empty());
  for (const auto i : c10::irange(0, indices.size())) {
    const auto& maybe_tensor = indices[i];
    if (!maybe_tensor.has_value() || !maybe_tensor->defined()) {
      continue;
    }
    auto logical_dim = rankWithoutBatchDim(maybe_tensor.value(), indices_bdims[i]);
    max_logical_dim = std::max(logical_dim, max_logical_dim);
  }
  return max_logical_dim;
}

static std::vector<std::optional<Tensor>> batchIndices(
  ArrayRef<std::optional<Tensor>> indices,
  ArrayRef<std::optional<int64_t>> indices_bdims,
  int64_t batch_size,
  std::optional<int64_t> self_bdim,
  std::optional<int64_t> values_bdim = std::nullopt) {
  // There are 3 main cases:
  // 1. self is batched, indices/values are not batched
  // In this case, we just need to augment indices with a None at the front to
  // basically broadcast the indexing across the batch dimension of self.
  //
  // 2. self is not batched, some indices are batched.
  // In this case, we don't need to do anything - indices will automatically
  // broadcast to work with the unbatched self.
  //
  // 3. self is batched, some indices are batched.
  // In this case, we simply need to add an arange that indexes along the first
  // dimension (i.e. the batch dimension). We also need to make sure this
  // broadcasts with the rest of the indices.
  //
  // In all three cases, depending on if advanced indices are adjacent we will
  // have to permute the output.
  // See NOTE: [advanced indexing (index.Tensor) batch rule] for more details
  //
  // There is one more case worth mentioning - boolean tensor indices. If we
  // have "batched" boolean tensor indices, that is unrepresentable, as each
  // batch would result in a tensor with different values.
  std::vector<std::optional<Tensor>> indices_;

  int64_t maxLogicalRank = get_max_index_logical_dim(indices, indices_bdims);
  bool indices_batched = any_has_value(indices_bdims);

  for (size_t i = 0; i < indices.size(); i++) {
    auto index = indices[i];
    if (index.has_value() && index->numel() != 0) {
      const auto idx_bdim = indices_bdims[i];
      indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, maxLogicalRank));
      if (index.value().dtype() == kBool && indices_bdims[i].has_value()) {
        throw std::runtime_error("vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask.");
      }
    } else {
      indices_.push_back(index);
    }
  }

  auto maxIndexDim = maxLogicalRank;
  if (indices_batched || values_bdim.has_value()) {
    maxIndexDim += 1;
  }

  if (!indices_batched && self_bdim.has_value()) {
    indices_.insert(indices_.begin(), std::nullopt);
  } else if (indices_batched && !self_bdim.has_value()) {
    // do nothing
  } else if (indices_batched && (self_bdim.has_value() || values_bdim.has_value())) {
    auto arange_index = at::arange(0, batch_size);
    while (arange_index.dim() < maxIndexDim) {
      arange_index = arange_index.unsqueeze(-1);
    }
    // TODO: this is O(N)
    indices_.insert(indices_.begin(), arange_index);
  }
  return indices_;
}

// Define an "advanced index" to be a selection object that is
// a non-trivial Tensor (i.e. it does not represent :).
static bool is_advanced_index(const std::optional<Tensor>& idx) {
  if (!idx.has_value()) {
    return false;
  }
  if (!idx->defined()) {
    return false;
  }
  return true;
}

// See NOTE: [advanced indices adjacent] for definition
static bool are_advanced_indices_adjacent(ArrayRef<std::optional<Tensor>> indices) {
  int64_t num_advanced_indices_regions = 0;
  bool in_advanced_indices_region = false;
  for (const auto& idx : indices) {
    if (!in_advanced_indices_region && is_advanced_index(idx)) {
      num_advanced_indices_regions++;
      in_advanced_indices_region = true;
      continue;
    }
    if (in_advanced_indices_region && !is_advanced_index(idx)) {
      in_advanced_indices_region = false;
      continue;
    }
  }
  return num_advanced_indices_regions <= 1;
}

// Given a Tensor[B, <first_region>, <second_region>, ...]
// Swaps the regions to produce Tensor[B, <second_region>, <first_region>, ...]
//
// Concretely speaking, given
// - tensor: Tensor[B, 2, 3, 4, 5, 6, 7, 8]
// - first_region_size: 2
// - second_region_size: 3
// Produces:
// - result: Tensor[B, 4, 5, 6, 2, 3, 7, 8]
//                     -------  ----
//                     region2  region1
static Tensor swap_regions(const Tensor& tensor, int64_t first_region_size, int64_t second_region_size) {
  VmapDimVector permutation(tensor.dim(), 0);
  std::iota(permutation.begin(), permutation.end(), 0);
  std::rotate(
      permutation.begin() + 1,
      permutation.begin() + 1 + first_region_size,
      permutation.begin() + 1 + first_region_size + second_region_size);
  return tensor.permute(permutation);
}

std::tuple<Tensor, std::optional<int64_t>> index_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    ArrayRef<std::optional<Tensor>> indices,
    ArrayRef<std::optional<int64_t>> indices_bdims) {

  // NOTE: [advanced indexing (index.Tensor) batch rule]
  //
  // This is a three step procedure:
  // 1. batch `indices`. Depends on self_bdim and indices_bdim.
  // 2. call at::index
  // 3. (maybe) reorder the dimensions in the result.
  // Why is step 3 necessary? Let's take a detour first.
  //
  // NOTE: [advanced indices adjacent]
  // Definition: In a list of std::optional<Tensor> indices,
  // we say that "advanced indices are adjacent" if ALL advanced indices are
  // not separated by a None (slice).
  //
  // So, for example,
  // [:, :, (0, 1), (0, 1), :] -> True
  // [:, (0, 1), :, (0, 1), :] -> False, the advanced indices are separated by a slice
  //
  // See https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
  // for more details.
  //
  // NOTE: [Why is step 3 necessary?]
  //
  // In the original self[*indices] expression,
  // depending on whether or not the "advanced indices inside `indices` are
  // adjacent", something different happens.
  //
  // For example:
  // - self: Tensor[4, 5, 6, 7]
  // - indices: [:, (0, 1), (0, 1), :] (advanced indices are adjacent)
  // - self[*indices]: Tensor[4, 2, 7]
  // If advanced indices are adjacent, you get the output you would expect.
  // (0, 1), (0, 1) says "please index these two dimensions at (0, 0) and (1, 1)
  // to produce two elements".
  //
  // If advanced indices are not adjacent, it is ambiguous to where the new
  // dimension of size 2 should go. The numpy spec says it should go at the very
  // front of the Tensor.
  //
  // - self: Tensor[4, 5, 6, 7]
  // - indices: [:, (0, 1), :, (0, 1)] (advanced indices not adjacent)
  // - self[*indices]: Tensor[2, 4, 6]
  //
  // Now, this leads to some weird interactions with vmap.
  // The indices might originally have adjacent advanced indices, but after
  // batching them with "batchIndices", they may no longer be adjacent!
  // - indices: [:, (0, 1), (0, 1)]
  // - batched_indices (for example): [(0, 1), :, (0, 1), (0, 1)]
  // This leads to the dimension of size 2 appearing somewhere else.
  //
  // There are a couple of different cases that we walk through in the code below.
  //
  // Background reading for why we care about if the advanced indices are adjacent:
  // https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
  auto self_ = moveBatchDimToFront(self, self_bdim);
  TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
  bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(indices);

  // Step 1
  const auto batched_indices = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
  auto num_leading_nones = get_num_leading_nones(indices);
  auto max_index_dim = get_max_index_logical_dim(indices, indices_bdims);

  // Step 2
  auto res = at::index(self_, List<std::optional<Tensor>>(batched_indices));

  // Step 3: There are three cases (these match the cases outlined in batchIndices)
  bool self_batched = self_bdim.has_value();
  bool indices_batched = any_has_value(indices_bdims);

  TORCH_INTERNAL_ASSERT(self_batched || indices_batched, "Requires at least one batched to get here");

  // Case 1
  if (self_batched && !indices_batched) {
    if (advanced_indices_are_adjacent) {
      // self: Tensor[B, 5, 6, 7, 8]
      // indices: [:, Tensor[2, 2], Tensor[2, 2], :]
      // batched_indices: [:, :, Tensor[2, 2], Tensor[2, 2], :]
      // res: Tensor[B, 5, 2, 2, 8]
      return std::make_tuple(res, 0);
    } else {
      // self: Tensor[B, 5, 6, 7]
      // indices: [Tensor[2, 2], :, Tensor[2, 2]]
      // batched_indices: [:, Tensor[2, 2], :, Tensor[2, 2]]
      // res: Tensor[2, 2, B, 6]
      return std::make_tuple(res, max_index_dim);
    }
  }

  // Case 2
  if (!self_batched && indices_batched) {
    if (advanced_indices_are_adjacent) {
      // self: Tensor[5, 6, 7, 8]
      // indices: [:, :, Tensor[B, 2, 2], Tensor[2, 2]]
      // batched_indices: indices (no change)
      // res: Tensor[5, 6, B, 2, 2]
      return std::make_tuple(res, num_leading_nones);
    } else {
      // self: Tensor[5, 6, 7, 8, 9]
      // indices: [:, :, Tensor[B, 2, 2], :, Tensor[2, 2]]
      // batched_indices: indices (no change)
      // res: Tensor[B, 2, 2, 5, 6, 8]
      return std::make_tuple(res, 0);
    }
  }

  // Case 3: self_batched and indices_batched
  TORCH_INTERNAL_ASSERT(self_batched && indices_batched);
  if (!advanced_indices_are_adjacent) {
    // self: Tensor[B, 5, 6, 7, 8]
    // indices: [:, Tensor[B, 2, 2], :, Tensor[2, 2]]
    // batched_indices: [arange(B).expand(B, 2, 2), :, Tensor[B, 2, 2], :, Tensor[2, 2]]
    // res: Tensor[B, 2, 2, 5, 7]
    return std::make_tuple(res, 0);
  }
  // In other words, in batched_indices, advanced indices are adjacent
  if (num_leading_nones == 0) {
    // self: Tensor[B, 5, 6, 7, 8]
    // indices: [Tensor[B, 2, 2], Tensor[2, 2], :, :]
    // batched_indices: [arange(B).expand(B, 2, 2), Tensor[B, 2, 2], Tensor[2, 2], :, :]
    // res: Tensor[B, 2, 2, 7, 8]
    return std::make_tuple(res, 0);
  }
  // This is the tricky case. In indices, advanced indices are adjacent.
  // In batched_indices, advanced indices are no longer adjacent
  //
  // self: Tensor[B, 5, 6, 7, 8, 9]
  // indices: [:, :, Tensor[B, 2, 3], Tensor[2, 3], :]
  // batched_indices: [arange(B).expand(B, 2, 3), :, :, Tensor[B, 2, 3], Tensor[2, 3], :]
  // res: Tensor[B, 2, 3, 5, 6, 9]
  // expected: Tensor[B, 5, 6, 2, 3, 9]
  //
  // The resolution is to move dims around until we get the right shape.
  // The result is set up as [B, <maxIndexDim>, <leading_nones>, ...]
  // we just have to move the <leading_nones> to before the <maxIndexDim> to produce
  // [B, <leading_nones>, <maxIndexDim>, ...]
  return std::make_tuple(swap_regions(res, max_index_dim, num_leading_nones), 0);
}

// plumbing done since we don't support List<std::optional<Tensor>> in codegen
Tensor index_plumbing(const Tensor & self, const List<std::optional<Tensor>> & indices
) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "index_plumbing");
  int64_t cur_level = maybe_layer->layerId();
  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) {
    return at::index(self, indices);
  }
  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
  std::vector<std::optional<Tensor>> indices_value;
  std::vector<std::optional<int64_t>> indices_bdims;
  for (const auto&& indRef : indices) {
      std::optional<Tensor> ind = indRef;
      std::optional<Tensor> index;
      std::optional<int64_t> index_bdim;
      if (ind.has_value()) {
        std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
      }
    indices_value.push_back(index);
    indices_bdims.push_back(index_bdim);
  }
  auto results = index_batch_rule(self_value, self_bdim, indices_value, indices_bdims);
  return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
}

namespace {
  // Code is mostly duplicated from
  // https://github.com/pytorch/pytorch/blob/fb0e27d38a8fdab4e1c14d6378c9e41cb30fd6a3
  // /aten/src/ATen/native/TensorAdvancedIndexing.cpp#L294-L312
  VmapDimVector compute_indexed_shape(const Tensor &src, TensorList indices_list)
  {
    int64_t dims_before = 0, dims_indexed = 0;
    IntArrayRef replacement_shape;
    for (const auto dim : c10::irange(indices_list.size())) {
      if (!indices_list[dim].defined()) {
        if (dims_indexed == 0) {
          dims_before++;
        }
      } else {
        dims_indexed++;
        replacement_shape = indices_list[dim].sizes();
      }
    }

    // Replace indexed dimensions in src with stride 0 and the size of the result tensor.
    // The offset in these dimensions is computed by the kernel using the index tensor's
    // values and the stride of src. The new shape is not meaningful. It's used to make
    // the shape compatible with the result tensor.
    auto shape = VmapDimVector(src.sizes());
    int64_t end = dims_before + dims_indexed;
    shape.erase(shape.begin() + dims_before, shape.begin() + end);
    shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end());
    return shape;
  }

  // Code is mostly duplicated from
  // https://github.com/pytorch/pytorch/blob/fb0e27d38a8fdab4e1c14d6378c9e41cb30fd6a3
  // /aten/src/ATen/native/TensorAdvancedIndexing.cpp#L379-L405
  VmapDimVector get_indexed_shape(Tensor self, const torch::List<std::optional<at::Tensor>> &orig)
  {
    at::native::checkIndexTensorTypes(orig, /*allow_int*/ true);
    // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
    auto indices = at::native::expandTensors(self, orig);
    // next broadcast all index tensors together
    try {
      indices = at::expand_outplace(indices);
    } catch (std::exception &e) {
      TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"
                               " with shapes ");
    }
    // add missing null Tensors so that it matches self.dim()
    while (indices.size() < static_cast<size_t>(self.dim())) {
      indices.emplace_back();
    }
    // if the non-null indices are not all adjacent, transpose self and indices
    // together so that they're adjacent at the front
    if (!at::native::hasContiguousSubspace(indices)) {
      std::tie(self, indices) = at::native::transposeToFront(self, indices);
    }
    return compute_indexed_shape(self, indices);
  }

  std::tuple<Tensor, std::vector<std::optional<Tensor>>, Tensor>
  index_put_batch_rule_helper(const Tensor &self,
                              std::optional<int64_t> self_bdim,
                              ArrayRef<std::optional<Tensor>> indices,
                              ArrayRef<std::optional<int64_t>> indices_bdims,
                              const Tensor &values,
                              std::optional<int64_t> values_bdim,
                              std::optional<int64_t> opt_batch_size = {}) {

    Tensor self_ = moveBatchDimToFront(self, self_bdim);
    Tensor values_ = moveBatchDimToFront(values, values_bdim);
    // for inplace variants `index_put_` and `_index_put_impl_` we find the batch_size
    // here while for `index_put` does it outside of this function.
    const auto batch_size = opt_batch_size ? opt_batch_size.value() : self_.size(0);
    self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
    values_ = ensure_has_bdim(values_, values_bdim.has_value(), batch_size);
    TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());

    // we've already made sure that self has bdim at 0.
    const auto indices_ = batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim);

    auto indexed_shape = get_indexed_shape(self_, List<std::optional<Tensor>>(indices_));

    // handle broadcasting support for values
    // Eg. Given `indexed_shape.size()` is 5 and
    // shape of `values` is (N, 2, 3), then following block
    // will reshape `values` to (N, 1, 1, 2, 3).
    if ( (int64_t) indexed_shape.size() > values_.dim()) {
      auto values_sizes = values_.sizes();

      // number of unit dims (for broadcasting value to indexed_shape)
      auto n_unit_dims = indexed_shape.size() - values_sizes.size();
      VmapDimVector new_values_shape(values_sizes.size() + n_unit_dims);

      // add the batch-dim
      new_values_shape[0] = batch_size;

      // insert the unit dims for broadcasting.
      for (const auto idx : c10::irange(n_unit_dims)) {
        // since batch-dim is already be filled.
        new_values_shape[idx + 1] = 1;
      }
      for (const auto idx: c10::irange(1, values_sizes.size())) {
        // since batch and unit dims are already be filled.
        new_values_shape[idx + n_unit_dims] = values_sizes[idx];
      }
      values_ = values_.view(new_values_shape);
    }

    return std::make_tuple(self_, indices_, values_);
  }

  auto unpackSelfAndIndicesAndValuesAtCurrentLevel(const Tensor &self,
                                                   const List<std::optional<Tensor>> &indices,
                                                   const Tensor &values, int64_t cur_level)
  {
    auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
    std::vector<std::optional<Tensor>> indices_value;
    std::vector<std::optional<int64_t>> indices_bdims;
    for (const auto &&indRef : indices)
    {
      std::optional<Tensor> ind = indRef;
      std::optional<Tensor> index;
      std::optional<int64_t> index_bdim;
      if (ind.has_value()) {
        std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
      }
      indices_value.push_back(index);
      indices_bdims.push_back(index_bdim);
    }
    auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level);
    return std::make_tuple(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim);
  }

}  // namespace

void index_put__batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    ArrayRef<std::optional<Tensor>> indices,
    ArrayRef<std::optional<int64_t>> indices_bdims,
    const Tensor& values,
    std::optional<int64_t> values_bdim,
    bool accumulate) {
  if (!self_bdim.has_value()) {
    vmapIncompatibleInplaceError("index_put_");
  }
  auto [self_, indices_, values_] = index_put_batch_rule_helper(
      self, self_bdim, indices, indices_bdims, values, values_bdim);
  at::index_put_(self_, List<std::optional<Tensor>>(indices_), values_, accumulate);
}

// plumbing done since we don't support List<std::optional<Tensor>> in codegen
Tensor& index_put__plumbing(Tensor & self, const List<std::optional<Tensor>> & indices
, const Tensor & values, bool accumulate) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "index_put__plumbing");
  int64_t cur_level = maybe_layer->layerId();

  // on device mismatch, we can move 0d tensors to self device
  auto values_ = values;
  if (values.device() != self.device() && values.numel() == 1 && values.dim() == 0) {
    values_ = values.to(self.device());
  }

  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values_, cur_level)) {
    return self.index_put_(indices, values_, accumulate);
  }
  auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
      unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values_, cur_level);
  index_put__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
  return self;
}

void _index_put_impl__batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    ArrayRef<std::optional<Tensor>> indices,
    ArrayRef<std::optional<int64_t>> indices_bdims,
    const Tensor& values,
    std::optional<int64_t> values_bdim,
    bool accumulate,
    bool unsafe) {
  if (!self_bdim.has_value()) {
    vmapIncompatibleInplaceError("_index_put_impl_");
  }
  auto [self_, indices_, values_] = index_put_batch_rule_helper(
      self, self_bdim, indices, indices_bdims, values, values_bdim);
  at::_index_put_impl_(self_, List<std::optional<Tensor>>(indices_), values_, accumulate, unsafe);
}

// plumbing done since we don't support List<std::optional<Tensor>> in codegen
Tensor &_index_put_impl__plumbing(Tensor &self, const List<std::optional<Tensor>> &indices,
                                  const Tensor &values, bool accumulate, bool unsafe) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "_index_put_impl__plumbing");
  int64_t cur_level = maybe_layer->layerId();
  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) {
    return at::_index_put_impl_(self, indices, values, accumulate, unsafe);
  }
  auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
      unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values, cur_level);
  _index_put_impl__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate, unsafe);
  return self;
}

static Tensor maybe_permute_values(
    const Tensor& values,
    ArrayRef<std::optional<Tensor>> orig_indices,
    ArrayRef<std::optional<int64_t>> orig_indices_bdims) {
  bool indices_batched = any_has_value(orig_indices_bdims);
  bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(orig_indices);
  auto num_leading_nones = get_num_leading_nones(orig_indices);
  auto max_index_dim = get_max_index_logical_dim(orig_indices, orig_indices_bdims);

  // NB: values has its B dimension at the front
  if (!indices_batched) {
    if (advanced_indices_are_adjacent) {
      // self: Tensor[B, 5, 6, 7, 8]
      // indices: [:, Tensor[2, 2], Tensor[2, 2], :]
      // batched_indices: [:, :, Tensor[2, 2], Tensor[2, 2], :]
      // required values: Tensor[B, 5, 2, 2, 8]
      return values;
    }
    // self: Tensor[B, 5, 6, 7]
    // indices: [Tensor[2, 2], :, Tensor[2, 2]]
    // batched_indices: [:, Tensor[2, 2], :, Tensor[2, 2]]
    // required values: Tensor[2, 2, B, 6]
    return values.movedim(0, max_index_dim);
  }
  if (!advanced_indices_are_adjacent) {
    // self: Tensor[B, 5, 6, 7, 8]
    // indices: [:, Tensor[B, 2, 2], :, Tensor[2, 2]]
    // batched_indices: [arange(B).expand(B, 2, 2), :, Tensor[B, 2, 2], :, Tensor[2, 2]]
    // required values: Tensor[B, 2, 2, 5, 7]
    return values;
  }
  // In other words, in batched_indices, advanced indices are adjacent
  if (num_leading_nones == 0) {
    // self: Tensor[B, 5, 6, 7, 8]
    // indices: [Tensor[B, 2, 2], Tensor[2, 2], :, :]
    // batched_indices: [arange(B).expand(B, 2, 2), Tensor[B, 2, 2], Tensor[2, 2], :, :]
    // required values: Tensor[B, 2, 2, 7, 8]
    return values;
  }
  // This is the tricky case. In indices, advanced indices are adjacent.
  // In batched_indices, advanced indices are no longer adjacent
  //
  // self: Tensor[B, 5, 6, 7, 8, 9]
  // indices: [:, :, Tensor[B, 2, 3], Tensor[2, 3], :]
  // batched_indices: [arange(B).expand(B, 2, 3), :, :, Tensor[B, 2, 3], Tensor[2, 3], :]
  // required values: Tensor[B, 2, 3, 5, 6, 9]
  // actual values: Tensor[B, 5, 6, 2, 3, 9]
  //
  // The resolution is to move dims around until we get the right shape.
  // The values is set up as [B, <leading_nones>, <maxIndexDim>, ...]
  // we just have to move the <maxIndexDim> to before the <leading_nones> to produce
  // [B, <maxIndexDim>, <leading_nones>, ...]
  return swap_regions(values, num_leading_nones, max_index_dim);
}

std::tuple<Tensor, std::optional<int64_t>> index_put_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    ArrayRef<std::optional<Tensor>> indices,
    ArrayRef<std::optional<int64_t>> indices_bdims,
    const Tensor& values,
    std::optional<int64_t> values_bdim,
    bool accumulate) {
  TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());

  // find the batch_size
  int64_t batch_size = 0;
  if (self_bdim || values_bdim) {
    batch_size = get_bdim_size2(self, self_bdim, values, values_bdim);
  } else {
    // one or more of the indices is batched.
    for (size_t i = 0; i < indices.size(); i++) {
      if (indices_bdims[i] && indices[i].has_value()) {
        batch_size = indices[i].value().size(*indices_bdims[i]);
        break;
      }
    }
  }

  auto [self_, indices_, values_] = index_put_batch_rule_helper(
      self, self_bdim, indices, indices_bdims, values, values_bdim, batch_size);

  // Why do we need to permute values?
  // See NOTE [Advanced indexing (index.Tensor) batch rule] for details,
  // but the gist is that index_put effectively does the following:
  // - result = self_.clone()
  // - result[indices_] = values
  // - return result
  // Now, the problem is, result[indices_] might return a Tensor whose shape is
  // the shape of values, but permuted. This is because the shape of result[indices_]
  // depends on if the original indices "have adjacent advanced indices"
  // and the batched `indices_` might change the "have adjacent advanced indices" property
  values_ = maybe_permute_values(values_, indices, indices_bdims);

  auto result = at::index_put(self_, List<std::optional<Tensor>>(indices_), values_, accumulate);
  return std::make_tuple(result, 0);
}

// plumbing done since we don't support List<std::optional<Tensor>> in codegen
Tensor index_put_plumbing(const Tensor & self, const List<std::optional<Tensor>> & indices,
                          const Tensor & values, bool accumulate) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "index_put_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  // on device mismatch, we can move 0d tensors to self device
  auto values_ = values;
  if (values.device() != self.device() && values.numel() == 1 && values.dim() == 0) {
    values_ = values.to(self.device());
  }

  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values_, cur_level)) {
    return self.index_put(indices, values_, accumulate);
  }
  auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
      unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values_, cur_level);
  auto results = index_put_batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
  return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
}

namespace {

template<typename Func, typename ...Args>
std::tuple<Tensor, std::optional<int64_t>> scatter_batch_rule(
    Func f,
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Scalar& value, Args... args) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);

  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto index_ = moveBatchDimToFront(index, index_bdim);

  if (self_logical_rank == 0) {
    self_ = self_.unsqueeze(-1);
  }
  if (index_logical_rank == 0) {
    index_ = index_.unsqueeze(-1);
  }
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
  auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);

  auto result = f(self_, physical_dim, index_, value, args...);
  // result should have same shape as self
  if (self_logical_rank == 0) {
    result = result.squeeze(-1);
  }
  return std::make_tuple(result, 0);
}

template <typename Func, typename ...Args>
inline std::tuple<Tensor, std::optional<int64_t>> scatter_batch_rule(
    Func f,
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& src, std::optional<int64_t> src_bdim, Args... args) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  auto src_logical_rank = rankWithoutBatchDim(src, src_bdim);
  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, src, src_bdim);

  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto index_ = moveBatchDimToFront(index, index_bdim);
  auto src_ = moveBatchDimToFront(src, src_bdim);

  if (self_logical_rank == 0) {
    self_ = self_.unsqueeze(-1);
  }
  if (index_logical_rank == 0) {
    index_ = index_.unsqueeze(-1);
  }
  if (src_logical_rank == 0) {
    src_ = src_.unsqueeze(-1);
  }
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
  src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);
  auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);

  auto result = f(self_, physical_dim, index_, src_, args...);
  // result should have same shape as self
  if (self_logical_rank == 0) {
    result = result.squeeze(-1);
  }
  return std::make_tuple(result, 0);
}

} // namespace

std::tuple<Tensor, std::optional<int64_t>> scatter_value_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Scalar& value) {
  return scatter_batch_rule(ATEN_FN2(scatter, value),
                            self, self_bdim, dim, index, index_bdim, value);
}

std::tuple<Tensor, std::optional<int64_t>> scatter_src_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& src, std::optional<int64_t> src_bdim) {
  return scatter_batch_rule(ATEN_FN2(scatter, src),
                            self, self_bdim, dim, index, index_bdim, src, src_bdim);
}

std::tuple<Tensor, std::optional<int64_t>> scatter_add_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& src, std::optional<int64_t> src_bdim) {
  return scatter_batch_rule(ATEN_FN(scatter_add),
                            self, self_bdim, dim, index, index_bdim, src, src_bdim);
}

std::tuple<Tensor, std::optional<int64_t>> scatter_reduce_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& src, std::optional<int64_t> src_bdim,
    const c10::string_view reduce) {
  return scatter_batch_rule(ATEN_FN2(scatter, reduce),
                            self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce);
}

std::tuple<Tensor, std::optional<int64_t>> scatter_value_reduce_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Scalar& src,
    const c10::string_view reduce) {
  return scatter_batch_rule(ATEN_FN2(scatter, value_reduce),
                            self, self_bdim, dim, index, index_bdim, src, reduce);
}

std::tuple<Tensor, std::optional<int64_t>> gather_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    bool sparse_grad) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);

  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto index_ = moveBatchDimToFront(index, index_bdim);

  if (self_logical_rank == 0) {
    self_ = self_.unsqueeze(-1);
  }
  if (index_logical_rank == 0) {
    index_ = index_.unsqueeze(-1);
  }
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
  auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);

  auto result = at::gather(self_, physical_dim, index_, sparse_grad);
  // result should have same rank as index
  if (index_logical_rank == 0) {
    result = result.squeeze(-1);
  }
  return std::make_tuple(result, 0);
}

Tensor get_expanded_index(const Tensor& index, IntArrayRef self_size, int64_t dim) {
  if (index.dim() == 0) {
    return index.expand(self_size);
  }
  dim = maybe_wrap_dim(dim, static_cast<int64_t>(self_size.size()));

  // setup new_index_shape as [BS, 1, ..., idx_size, ..., 1]
  // to reshape index_
  auto idx_size = index.size(0);  // get non-batch size of index tensor
  Tensor index_;
  {
    VmapDimVector new_index_shape(self_size.size(), 1);
    new_index_shape[dim] = idx_size;
    index_ = index.view(new_index_shape);
  }
  // Now apply expand to index_
  {
    VmapDimVector new_index_shape = {self_size.begin(), self_size.end()};
    new_index_shape[dim] = idx_size;
    index_ = index_.expand(new_index_shape);
  }
  return index_;
}

Tensor index_select_decomp(const Tensor &self, int64_t dim, const Tensor &index)
{
  Tensor index_ = index;
  if (self.dim() > index.dim()) {
    index_ = get_expanded_index(index, self.sizes(), dim);
  }

  auto result = at::gather(self, dim, index_);

  // output of gather has same dimension as `index` while
  // output of index_select has same dimension as self
  // Eg. t = torch.tensor(1)
  //     idx = torch.tensor([0])
  //     torch.index_select(t, 0, idx) # 0-D
  //     torch.gather(t, 0, idx) # 1-D
  if (self.dim() == 0 && result.dim() != 0) {
    result = result.squeeze(-1);
  }

  return result;
}

Tensor index_copy_decomp(
    const Tensor &self, int64_t dim,
    const Tensor &index, const Tensor &source)
{
  Tensor index_ = index;
  if (self.dim() > index.dim()) {
    index_ = get_expanded_index(index, self.sizes(), dim);
  }

  return at::scatter(self, dim, index_, source);  ;
}

// Note [Fix vmap slice_scatter]
// registers a decomposition for `slice_scatter` that calls into `slice.src`
// *_scatter operators have some special semantics though, that we can't easily
// through a decomposition: slice_scatter's output needs to have the same
// size, size, strides and storage_offset as the input.
Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src,
                            int64_t dim, std::optional<int64_t> start,
                            std::optional<int64_t> end, int64_t step)
{
  auto idx = at::arange(start.value_or(0), end.value_or(self.size(dim)), step, self.options().dtype(kLong));
  idx = get_expanded_index(idx, self.sizes(), dim);
  return at::scatter(self, dim, idx, src);
}

Tensor select_scatter_decomp(
    const Tensor &self, const Tensor &source,
    int64_t dim, int64_t index)
{
  // supports negative index
  index = maybe_wrap_dim(index, self.size(dim));
  auto index_ = at::scalar_tensor(index, self.options().dtype(kLong));

  return at::scatter(self, dim, index_.expand_as(self), source.unsqueeze(dim).expand_as(self));
}

std::tuple<Tensor, std::optional<int64_t>> diagonal_scatter_batch_rule(
    const Tensor &self, std::optional<int64_t> self_bdim,
    const Tensor &src, std::optional<int64_t> src_bdim,
    int64_t offset, int64_t dim1, int64_t dim2)
{
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto src_ = moveBatchDimToFront(src, src_bdim);

  auto batch_size = get_bdim_size2(self, self_bdim, src, src_bdim);

  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);

  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  dim1 = maybe_wrap_dim(dim1, self_logical_rank) + 1;
  dim2 = maybe_wrap_dim(dim2, self_logical_rank) + 1;

  return std::make_tuple(at::diagonal_scatter(self_, src_, offset, dim1, dim2), 0);
}

std::tuple<Tensor, std::optional<int64_t>> index_add_batch_rule_impl(
    Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim,
    const Scalar& alpha,
    const bool inplace) {

  if (inplace && !self_bdim.has_value()){
    vmapIncompatibleInplaceError("index_add_");
  }

  if (!index_bdim) {
    // Handle scalar tensors... self, other can be scalar tensors
    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
    const auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
    auto self_ = moveBatchDimToFront(self, self_bdim);
    if (self_logical_rank == 0) {
      self_ = self_.unsqueeze(-1);
    }
    auto other_ = moveBatchDimToFront(other, other_bdim);
    if (other_logical_rank == 0) {
      other_ = other_.unsqueeze(-1);
    }
    dim = maybe_wrap_dim(dim, self_logical_rank);

    const auto batch_size = get_bdim_size2(self, self_bdim, other, other_bdim);
    self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
    other_ = ensure_has_bdim(other_, other_bdim.has_value(), batch_size);

    if (inplace) {
      self_.index_add_(dim + 1, index, other_, alpha);
      if (self_logical_rank == 0) {
        self_ = self_.squeeze(-1);
      }
      return std::make_tuple(self, 0);
    }

    auto result = self_.index_add(dim + 1, index, other_, alpha);
    if (self_logical_rank == 0) {
      result = result.squeeze(-1);
    }
    return std::make_tuple(result, 0);
  }

  // Index is batched. For-loop and stack is the best thing I can come up with
  // right now. We really want generalized index_add kernel in PyTorch
  auto batch_size = get_bdim_size3(self, self_bdim, other, other_bdim, index, index_bdim);
  std::vector<Tensor> results;
  if (!inplace) {
    results.reserve(batch_size);
  }
  for (const auto i : c10::irange(0, batch_size)) {
    const auto& self_slice = self_bdim.has_value() ?
      self.select(*self_bdim, i) : self;
    const auto& other_slice = other_bdim.has_value() ?
      other.select(*other_bdim, i) : other;
    const auto& index_slice = index_bdim.has_value() ?
      index.select(*index_bdim, i) : index;

    if (inplace) {
      self_slice.index_add_(dim, index_slice, other_slice, alpha);
    } else {
      results.push_back(at::index_add(self_slice, dim, index_slice, other_slice, alpha));
    }
  }
  if (inplace) {
    return std::make_tuple(at::stack(self), 0);
  }
  return std::make_tuple(at::stack(results), 0);
}

void index_add__batch_rule(
    Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim,
    const Scalar& alpha) {
  index_add_batch_rule_impl(self, self_bdim, dim, index, index_bdim, other,
                            other_bdim, alpha, true);
}

std::tuple<Tensor, std::optional<int64_t>> index_add_batch_rule(
    Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim,
    const Scalar& alpha) {
  auto self_ = self.clone(at::MemoryFormat::Preserve);
  return index_add_batch_rule_impl(self_, self_bdim, dim, index, index_bdim,
                                   other, other_bdim, alpha, false);
}

static std::tuple<Tensor,Tensor> binary_pointwise_align(
    const Tensor & self,
    std::optional<int64_t> self_bdim,
    const Tensor & mask,
    std::optional<int64_t> mask_bdim) {
  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(mask, mask_bdim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(self, self_bdim);
  auto other_ = moveBatchDimToFront(mask, mask_bdim);

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, self_bdim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, mask_bdim, max_logical_rank);

  return std::make_tuple(tensor_, other_);
}

std::tuple<Tensor, std::optional<int64_t>> masked_fill_scalar_batch_rule(
    const Tensor & self,
    std::optional<int64_t> self_bdim,
    const Tensor & mask,
    std::optional<int64_t> mask_bdim,
    const Scalar& source) {
  auto tensors = binary_pointwise_align(self, self_bdim, mask, mask_bdim);
  auto result = at::masked_fill(std::get<0>(tensors), std::get<1>(tensors), source);
  return std::make_tuple(result, 0);
}

std::tuple<Tensor, std::optional<int64_t>> index_fill_batch_rule_helper(
  int64_t batch_size,
  int64_t self_logical_rank,
  int64_t index_logical_rank,
  Tensor & self_,
  int64_t dim,
  Tensor & index_,
  const Scalar & value
  ){
  if (self_logical_rank != 0){
    auto index_offset = at::arange(
      batch_size,
      at::TensorOptions().dtype(index_.scalar_type()).device(index_.device())
    );
    if (index_logical_rank == 0){
      index_ = index_.unsqueeze(-1);
    }
    index_ = index_.add(index_offset.unsqueeze(-1), self_.size(dim + 1));
    index_ = reshape_dim_into(0, 0, index_);
    self_ = reshape_dim_into(0, dim, self_);
    self_.index_fill_(dim, index_, value);
    self_ = reshape_dim_outof(dim, batch_size, self_);
    return std::make_tuple(self_, dim);
  }

  // If self_logical_rank == 0, the batch dim is certainly 0, and we must apply batched indices to each row.
  if (index_logical_rank != 0){
    index_ = reshape_dim_into(0, 0, index_);
  }
  self_.unsqueeze_(-1);
  self_.index_fill_(dim + 1, index_, value);
  self_.squeeze_(-1);

  return std::make_tuple(self_, 0);
}

std::tuple<Tensor, std::optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
    Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Scalar & value,
    const bool inplace) {
  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  Tensor self_ = moveBatchDimToFront(self, self_bdim);
  Tensor index_ = moveBatchDimToFront(index, index_bdim);
  dim = maybe_wrap_dim(dim, self_logical_rank);

  if (inplace && !self_bdim.has_value()) {
    vmapIncompatibleInplaceError("index_fill_");
  }

  if (!index_bdim) {
    if (self_logical_rank == 0){
      self_.unsqueeze_(-1);
    }
    self_.index_fill_(dim + 1, index_, value);
    if (self_logical_rank == 0) {
      self_.squeeze_(-1);
    }
    return std::make_tuple(self_, 0);
  }

  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);

  if (inplace) {
    // Do for-loop for in-place because we cannot reshape
    // `self_` having an incompatible stride without copying.
    for (const auto i : c10::irange(0, batch_size)) {
      const auto& self_slice = self_.select(0, i);
      const auto& index_slice = index_.select(0, i);
      self_slice.index_fill_(
        dim,
        index_slice,
        value
      );
    }
    return std::make_tuple(self_, 0);
  }

  self_ = self_bdim.has_value() ? self_ : self_.clone();

  return index_fill_batch_rule_helper(batch_size, self_logical_rank, index_logical_rank, self_, dim, index_, value);
}

std::tuple<Tensor, std::optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
    Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Tensor & value, std::optional<int64_t> value_bdim,
    const bool inplace) {
  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  Tensor self_ = moveBatchDimToFront(self, self_bdim);
  Tensor index_ = moveBatchDimToFront(index, index_bdim);
  Tensor value_ = moveBatchDimToFront(value, value_bdim);
  dim = maybe_wrap_dim(dim, self_logical_rank);

  if (inplace && !self_bdim.has_value()) {
    vmapIncompatibleInplaceError("index_fill_");
  }

  if (!index_bdim && !value_bdim) {
    if (self_logical_rank == 0){
      self_.unsqueeze_(-1);
    }
    self_.index_fill_(dim + 1, index_, value);
    if (self_logical_rank == 0) {
      self_.squeeze_(-1);
    }
    return std::make_tuple(self_, 0);
  }

  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);

  if (inplace || value_bdim.has_value()) {
    // Do for-loop for in-place because we cannot reshape
    // `self_` having an incompatible stride without copying.
    // If value has a batch dim, we do for-loop as well because
    // index_fill_ supports 1-element tensor only.
    for (const auto i : c10::irange(0, batch_size)) {
      const auto& self_slice = self_.select(0, i);
      const auto& index_slice = index_.select(0, i);
      self_slice.index_fill_(
        dim,
        index_slice,
        value_bdim.has_value() ? value_.select(0, i) : value_
      );
    }
    return std::make_tuple(self_, 0);
  }

  self_ = self_bdim.has_value() ? self_ : self_.clone();

  // calling .item() on value is safe here because value is guaranteed to not be a batched tensor.
  return index_fill_batch_rule_helper(batch_size, self_logical_rank, index_logical_rank, self_, dim, index_, value.item());
}

void index_fill__int_scalar_batch_rule(
    Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Scalar & value) {
  index_fill_int_scalar_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, true);
}

void index_fill__int_tensor_batch_rule(
    Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Tensor & value, std::optional<int64_t> value_bdim) {
  index_fill_int_tensor_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, value_bdim, true);
}

std::tuple<Tensor, std::optional<int64_t>> index_fill_int_scalar_batch_rule(
    const Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Scalar & value) {
  auto self_ = self.clone(at::MemoryFormat::Preserve);
  return index_fill_int_scalar_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, false);
}

std::tuple<Tensor, std::optional<int64_t>> index_fill_int_tensor_batch_rule(
    const Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Tensor & value, std::optional<int64_t> value_bdim) {
  auto self_ = self.clone(at::MemoryFormat::Preserve);
  return index_fill_int_tensor_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, value_bdim, false);
}

}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  m.impl("index.Tensor", index_plumbing);
  m.impl("index_put_", index_put__plumbing);
  m.impl("index_put", index_put_plumbing);
  m.impl("_index_put_impl_", _index_put_impl__plumbing);
  m.impl("slice_scatter", slice_scatter_decomp);
  m.impl("select_scatter", select_scatter_decomp);
  m.impl("index_copy", index_copy_decomp);
  m.impl("index_select", index_select_decomp);
  VMAP_SUPPORT2(masked_fill, Scalar, masked_fill_scalar_batch_rule);
  VMAP_SUPPORT2(index_fill_, int_Tensor, index_fill__int_tensor_batch_rule);
  VMAP_SUPPORT2(index_fill_, int_Scalar, index_fill__int_scalar_batch_rule);
  VMAP_SUPPORT2(index_fill, int_Tensor, index_fill_int_tensor_batch_rule);
  VMAP_SUPPORT2(index_fill, int_Scalar, index_fill_int_scalar_batch_rule);
  VMAP_SUPPORT(index_add_, index_add__batch_rule);
  VMAP_SUPPORT(index_add, index_add_batch_rule);
  VMAP_SUPPORT(diagonal_scatter, diagonal_scatter_batch_rule);
  VMAP_SUPPORT(gather, gather_batch_rule);
  VMAP_SUPPORT2(scatter, value, scatter_value_batch_rule);
  VMAP_SUPPORT2(scatter, src, scatter_src_batch_rule);
  VMAP_SUPPORT(scatter_add, scatter_add_batch_rule);
  VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule);
  VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule);
  // as_strided_scatter does not work with the for-loop fallback today,
  // because as_strided_scatter will return an output that matches
  // the strides/storage_offset of its input.
  // With the for loop fallback, each input tensor is a slice into
  // the larger batched tensor.
  m.impl("as_strided_scatter", torch::CppFunction::makeFromBoxedFunction<&vmapErrorFallback>());
}

} // namespace at::functorch
