#include <ATen/LegacyVmapTransforms.h>
#include <ATen/ATen.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>

namespace at {

// Returns true iff the i-th entry of `bdims` refers to dim i for every i,
// i.e. the batch dims occupy the leading dims of the tensor, in order.
// An empty `bdims` trivially satisfies this.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
  int64_t expected_dim = 0;
  for (const auto& bdim : bdims) {
    if (bdim.dim() != expected_dim) {
      return false;
    }
    ++expected_dim;
  }
  return true;
}

// Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
// and then returns a physical version of the Tensor.
//
// If the batch dims are already at the front (in order), the underlying
// physical tensor is returned as-is. Otherwise the result is a `permute` of
// the physical tensor whose leading dims are the batch dims (in the order they
// appear in `bdims`), followed by the remaining dims in their original
// relative order.
static Tensor permuteBatchDimsToFront(BatchedTensorImpl* batched) {
  auto bdims = batched->bdims();
  const Tensor& physical_tensor = batched->value();
  if (areBdimsAtFrontInOrder(bdims)) {
    return physical_tensor;
  }
  const auto sizes = physical_tensor.sizes();
  // The sized constructor already allocates `sizes.size()` elements, so no
  // separate `reserve` call is needed.
  VmapDimVector permutation(sizes.size(), 0);
  const auto is_bdim = createBatchDimBitset(bdims);
  int64_t idx = 0;
  // First, the batch dims...
  for (const auto& bdim : bdims) {
    permutation[idx++] = bdim.dim();
  }
  // ...then every physical dim that is not flagged in `is_bdim`.
  for (const auto ptr : c10::irange(sizes.size())) {
    if (is_bdim[ptr]) {
      continue;
    }
    permutation[idx++] = ptr;
  }
  return physical_tensor.permute(permutation);
}

// Converts a single BatchedTensor into a physical view: the underlying tensor
// with its batch dims moved to the front, paired with the bitset of vmap
// levels taken from the tensor's bdims. Asserts that `logical_tensor` is
// actually a BatchedTensor.
VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  TORCH_INTERNAL_ASSERT(
      batched,
      "logicalToPhysical(tensor) should only be passed a BatchedTensor");
  auto physical = permuteBatchDimsToFront(batched);
  const auto levels = createVmapLevelsBitset(batched->bdims());
  return VmapPhysicalView(std::move(physical), levels);
}

// Number of batch dims of this view: the count of bits set in `levels_`.
int64_t VmapPhysicalView::numBatchDims() const {
  return static_cast<int64_t>(levels_.count());
}

// Number of logical (non-batch) dims: the physical tensor's dim count minus
// the number of batch dims.
int64_t VmapPhysicalView::numLogicalDims() const {
  const int64_t physical_ndim = tensor_.dim();
  return physical_ndim - numBatchDims();
}

// Maps logical dims to physical dims by wrapping each one against the logical
// dim count and shifting it past the batch dims.
//
// If `opt_logical_dims` is absent or empty, returns the physical dims of all
// logical dims, in order.
VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const {
  const auto logical_ndim = numLogicalDims();
  const auto bdim_offset = numBatchDims();
  // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
  VmapDimVector physical_dims;
  physical_dims.reserve(logical_ndim);

  const bool use_all_dims =
      !opt_logical_dims.has_value() || opt_logical_dims.value().empty();
  if (use_all_dims) {
    for (int64_t logical_dim = 0; logical_dim < logical_ndim; ++logical_dim) {
      physical_dims.push_back(logical_dim + bdim_offset);
    }
    return physical_dims;
  }

  for (const auto logical_dim : opt_logical_dims.value()) {
    physical_dims.push_back(maybe_wrap_dim(logical_dim, logical_ndim) + bdim_offset);
  }
  return physical_dims;
}

// Maps a single logical dim to its physical dim: wrap it against the logical
// dim count, then shift it past the batch dims.
int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
  const auto wrapped_dim = maybe_wrap_dim(logical_dim, numLogicalDims());
  return wrapped_dim + numBatchDims();
}

// Builds a physical shape from a logical one: the sizes of this view's
// leading `numBatchDims()` dims followed by `logical_shape`.
VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
  const auto num_bdims = numBatchDims();
  const auto physical_sizes = tensor_.sizes();

  VmapDimVector physical_shape;
  physical_shape.reserve(logical_shape.size() + num_bdims);
  // Batch sizes first...
  physical_shape.append(physical_sizes.begin(), physical_sizes.begin() + num_bdims);
  // ...then the requested logical shape.
  physical_shape.append(logical_shape.begin(), logical_shape.end());
  return physical_shape;
}

// Builds a BatchDims list with one entry per level set in `levels_bitset`.
// Levels are visited in increasing order and assigned consecutive dims
// starting at 0.
static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
  BatchDims front_bdims;
  int64_t next_dim = 0;
  for (const auto level : c10::irange(kVmapNumLevels)) {
    if (levels_bitset[level]) {
      front_bdims.emplace_back(level, next_dim);
      ++next_dim;
    }
  }
  return front_bdims;
}

// Accepts either a regular Tensor or a BatchedTensor.
// - BatchedTensor: returns the underlying physical tensor with all batch dims
//   permuted to the front, plus the bitset of vmap levels from its bdims.
// - Regular Tensor: returns `self` unchanged together with an empty bitset.
static std::pair<Tensor,std::bitset<kVmapNumLevels>>
getPhysicalTensorAndLevels(const Tensor& self) {
  auto* batched = maybeGetBatchedImpl(self);
  if (!batched) {
    return {self, 0};
  }
  return {permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims())};
}

// Given a Tensor or a BatchedTensor, creates a physical view of the tensor
// that has one batch dimension per level in `requested_levels` followed by
// exactly `requested_example_dim` non-batch dimensions.
//
// This is used to prepare physical views that can be fed to broadcasting
// operations. For example, to add two BatchedTensors of sizes [B0, 3] and
// [B0, B1, 2, 3] (the Bi being batch dimensions), the batch dimensions and the
// non-batch ("example") dimensions must be aligned separately, producing
// tensors of size [B0, 1, 1, 3] and [B0, B1, 2, 3] that can then be added.
//
// Concretely, for the two tensors above:
//
// 1) alignBatchDimsAtFront([B0, 3], requested_levels={0, 1}, requested_example_dim=2)
// returns a physical view of size [B0, 1, 1, 3]: one size-1 dim is added for
// level 1 and one size-1 dim pads the example dims up to 2.
//
// 2) alignBatchDimsAtFront([B0, B1, 2, 3], requested_levels={0, 1}, requested_example_dim=2)
// returns a physical view of size [B0, B1, 2, 3].
//
// Preconditions (internally asserted): `requested_levels` is a superset of the
// tensor's levels, and the tensor has at most `requested_example_dim` example
// dims.
static Tensor alignBatchDimsAtFront(
    const Tensor& self,
    std::bitset<kVmapNumLevels> requested_levels,
    int64_t requested_example_dim) {
  auto [physical_tensor, tensor_levels] = getPhysicalTensorAndLevels(self);

  TORCH_INTERNAL_ASSERT(
    (tensor_levels | requested_levels) == requested_levels,
    "`requested_levels` must be a superset of `self`'s levels");

  const auto physical_sizes = physical_tensor.sizes();
  const auto num_tensor_bdims = static_cast<int64_t>(tensor_levels.count());
  const auto tensor_example_dim =
      static_cast<int64_t>(physical_sizes.size()) - num_tensor_bdims;
  TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);

  // Fast path: the physical tensor already has the requested layout, so no
  // extra view is needed.
  const bool already_aligned =
      tensor_levels == requested_levels && tensor_example_dim == requested_example_dim;
  if (already_aligned) {
    return physical_tensor;
  }

  // Start from all ones; any dim the tensor doesn't have stays size 1.
  const auto num_requested_bdims = requested_levels.count();
  VmapDimVector aligned_sizes(num_requested_bdims + requested_example_dim, 1);

  // Right-align the example (non-batch) dims:
  // aligned_sizes[-tensor_example_dim:] = physical_sizes[-tensor_example_dim:]
  std::copy_n(physical_sizes.rbegin(), tensor_example_dim, aligned_sizes.rbegin());

  // Fill in the batch dims. Each requested level, in increasing order, owns the
  // next output slot; a slot takes the tensor's next batch size only if the
  // tensor is batched at that level.
  int64_t out_dim = 0;
  int64_t src_dim = 0;
  for (const auto level : c10::irange(kVmapNumLevels)) {
    if (!requested_levels[level]) {
      continue;
    }
    if (tensor_levels[level]) {
      aligned_sizes[out_dim] = physical_sizes[src_dim];
      ++src_dim;
    }
    ++out_dim;
  }
  return physical_tensor.view(aligned_sizes);
}

// Converts a list of logical tensors (BatchedTensors and/or regular Tensors)
// into physical views that all share the same leading batch dims.
//
// The algorithm is as follows:
// 1. Figure out the collective set of vmap levels across `logical_tensors`.
// 2. Move all batch dims to the front of the tensors and add extra dims
//    of size 1. At this point, every tensor will have a dimension for
//    each of the collective levels.
// 3. Compute the batch_sizes.
// 4. Expand each physical tensor so that they have output batch size equal
//    to `batch_sizes`
//
// Returns one VmapPhysicalView per input tensor, in input order, each tagged
// with the collective levels.
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
  // Figure out all of the collective vmap levels in `logical_tensors`.
  std::bitset<kVmapNumLevels> collective_levels;
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (batched) {
      collective_levels |= createVmapLevelsBitset(batched->bdims());
    }
  }

  // Populate physical_tensors.
  // This contains a list of regular (non-Batched) Tensors where all of the
  // batch dims have been moved to the front of the tensor. Any previously
  // non-existing batch dims get added to the tensors as new dimensions of size 1.
  std::vector<Tensor> physical_tensors;
  // The final size is known up front; avoid repeated reallocation.
  physical_tensors.reserve(logical_tensors.size());
  int64_t num_batch_dims = collective_levels.count();
  for (const auto& logical_tensor : logical_tensors) {
    auto requested_example_dim = /*logical_dim*/logical_tensor.dim();
    auto physical_tensor = alignBatchDimsAtFront(
        logical_tensor, collective_levels, requested_example_dim);
    physical_tensors.push_back(std::move(physical_tensor));
  }

  // Compute batch_sizes: for each batch dim, take a size that differs from 1
  // if any tensor has one (the last such tensor wins). Mismatches between
  // tensors are not checked here.
  VmapDimVector batch_sizes(num_batch_dims, 1);
  for (const auto& physical_tensor : physical_tensors) {
    auto physical_sizes = physical_tensor.sizes();
    for (const auto dim : c10::irange(num_batch_dims)) {
      if (physical_sizes[dim] != 1) {
        batch_sizes[dim] = physical_sizes[dim];
      }
    }
  }

  // Expand each physical_tensor so that it has batch sizes `batch_sizes`
  VmapPhysicalViewVec result;
  result.reserve(physical_tensors.size());
  for (const auto& physical_tensor : physical_tensors) {
    // expanded_size = batch_sizes + (this tensor's own example sizes)
    VmapDimVector expanded_size(batch_sizes.begin(), batch_sizes.end());
    auto physical_sizes = physical_tensor.sizes();
    expanded_size.insert(
        expanded_size.end(),
        physical_sizes.begin() + num_batch_dims,
        physical_sizes.end());
    result.emplace_back(physical_tensor.expand(expanded_size), collective_levels);
  }
  return result;
}

// Scans a non-empty list of tensors and returns:
//  - the union of the vmap levels of every BatchedTensor in the list, and
//  - the largest `dim()` reported by any tensor in the list.
static std::pair<std::bitset<kVmapNumLevels>,int64_t>
getLevelsAndLargestLogicalDim(TensorList logical_tensors) {
  TORCH_INTERNAL_ASSERT(!logical_tensors.empty());
  std::bitset<kVmapNumLevels> combined_levels;
  int64_t max_logical_dim = -1;
  for (const auto& tensor : logical_tensors) {
    if (auto* batched = maybeGetBatchedImpl(tensor)) {
      combined_levels |= createVmapLevelsBitset(batched->bdims());
    }
    max_logical_dim = std::max<int64_t>(max_logical_dim, /*logical dim*/tensor.dim());
  }
  return { combined_levels, max_logical_dim };
}

// Produces physical views of exactly two logical tensors such that both views
// have a batch dim for every level present in either tensor and the same
// number of example (non-batch) dims, ready for a broadcasting op.
VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
  TORCH_INTERNAL_ASSERT(
      logical_tensors.size() == 2,
      "This function has only been tested for two tensors. Please add more tests ",
      "before removing this check ");

  const auto levels_and_dim = getLevelsAndLargestLogicalDim(logical_tensors);
  const auto& levels = levels_and_dim.first;
  const auto largest_logical_dim = levels_and_dim.second;

  VmapPhysicalViewVec physical_views;
  for (const auto& tensor : logical_tensors) {
    // NB: Aligning `tensor` is not always strictly necessary.
    // For example, when adding two tensors of size (B, 2), and (3, 2), where
    // the first Tensor is a BatchedTensor with batch dim B and the second is
    // a regular Tensor, we will return views of size (B, 1, 2) and (1, 3, 2).
    // However, the view on the second tensor is unnecessary: broadcasting
    // semantics allow for the addition of two tensors of size (B, 1, 2) and (3, 2)!
    //
    // If this unnecessary view is a problem, consider optimizing it away in
    // the future. This may involve creating a new type of VmapPhysicalView
    physical_views.emplace_back(
        alignBatchDimsAtFront(tensor, levels, largest_logical_dim), levels);
  }
  return physical_views;
}

// Returns a physical-to-logical map constructed from this view's `levels_`.
VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
  VmapPhysicalToLogicalMap physical_to_logical(levels_);
  return physical_to_logical;
}

// Wraps `physical_tensor` via `makeBatched`, using front-positioned batch dims
// derived from `levels_` (one per set level, assigned dims 0, 1, ...).
Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
  auto front_bdims = computeFrontBatchDimsFromLevels(levels_);
  return makeBatched(physical_tensor, std::move(front_bdims));
}

void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
  for (auto & physical_tensor : physical_tensors) {
    physical_tensor = apply(physical_tensor);
  }
}

} // namespace at
