// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at::functorch {

// convolution_batch_rule translated from jax with modifications:
// https://github.com/google/jax/blob/master/jax/_src/lax/lax.py#L3143

// PyTorch's convolution is different from JAX's conv_general_dilated:
// we do not support batch_group_count (which is needed for convolution backwards).
// Instead, there's a convolution_backward op that needs a batching rule.
static std::tuple<Tensor, std::optional<int64_t>>
convolution_batch_rule(const Tensor& lhs, std::optional<int64_t> lhs_bdim, const Tensor& rhs, std::optional<int64_t> rhs_bdim, const std::optional<Tensor>& bias, std::optional<int64_t> bias_bdim, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) {
  DimVector lhs_spec(stride.size() + 2);
  std::iota(lhs_spec.begin(), lhs_spec.end(), 0);
  DimVector rhs_spec = lhs_spec;
  DimVector out_spec = lhs_spec;
  if (transposed) {
    rhs_spec[0] = 1;
    rhs_spec[1] = 0;
  }

  // If we have a batched bias or weight, we need to perform the computation separately.
  std::optional<Tensor> unbatched_bias;
  bool separate_bias = false;
  if ((rhs_bdim && bias && bias->defined()) || bias_bdim) {
    TORCH_INTERNAL_ASSERT(bias.has_value());
    TORCH_INTERNAL_ASSERT(bias->defined());
    unbatched_bias = std::nullopt;
    separate_bias = true;
  } else {
    unbatched_bias = bias;
    separate_bias = false;
  }
  std::tuple<Tensor, std::optional<int64_t>> result;
  if (lhs_bdim && !rhs_bdim) {
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[0], lhs);
    auto out = at::convolution_symint(new_x, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof_symint(out_spec[0], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[0]);
  } else if (!lhs_bdim && rhs_bdim) {
    if (groups == 1) {
      auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[0], rhs);
      auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
      out = reshape_dim_outof_symint(out_spec[1], rhs.size(*rhs_bdim), out);
      result = std::make_tuple(out, out_spec[1]);
    } else {
      if (transposed) {
        // conv_transpose with groups is normally NIHW, IOHW -> N(GO)HW
        // With RHS batched, we do the following:
        // NIHW, BIOHW -> NIHW, I(BO)HW -> N(GBO)HW -> BN(GO)HW
        // NB: the following isn't written using rhs_spec
        // (PyTorch convs have a fixed dimension order)

        // BIOHW -> I(BO)HW
        auto new_w = reshape_dim_into(*rhs_bdim, 1, rhs);
        // NIHW, I(BO)HW -> N(GBO)HW
        auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
        // N(GBO)HW -> NG(BO)HW
        out = reshape_dim_outof_symint(1, groups, out);
        // NG(BO)HW -> NGBOHW
        out = reshape_dim_outof_symint(2, rhs.size(*rhs_bdim), out);
        // NGBOHW -> NB(GO)HW
        out = reshape_dim_into(1, 2, out);
        result = std::make_tuple(out, 1);
      } else {
        // conv with groups is normally N(GI)HW, (GO)IHW -> N(GO)HW
        // With RHS batched, we do the following:
        // N(GI)HW, B(GO)IHW -> N(GI)HW, (GBO)IHW -> N(GBO)HW -> BN(GO)HW
        // NB: the following isn't written using rhs_spec
        // (PyTorch convs have a fixed dimension order)

        // B(GO)IHW -> BGOIHW
        auto new_w = reshape_dim_outof_symint(0 + (*rhs_bdim == 0), groups, rhs);
        // BGOIHW -> G(BO)IHW
        new_w = reshape_dim_into(*rhs_bdim + (*rhs_bdim > 0), 1, new_w);
        // G(BO)IHW -> (GBO)IHW
        new_w = reshape_dim_into(0, 0, new_w);
        // N(GI)HW, (GBO)IHW -> N(GBO)HW
        auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
        // N(GBO)HW -> NG(BO)HW
        out = reshape_dim_outof_symint(1, groups, out);
        // NG(BO)HW -> NGBOHW
        out = reshape_dim_outof_symint(2, rhs.size(*rhs_bdim), out);
        // NGBOHW -> NB(GO)HW
        out = reshape_dim_into(1, 2, out);
        result = std::make_tuple(out, 1);
      }
    }
  } else if (lhs_bdim && rhs_bdim) {
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[1], lhs);
    groups *= lhs.sizes()[*lhs_bdim];
    auto dim_with_groups = transposed ? 1 : 0;
    auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[dim_with_groups], rhs);
    auto out = at::convolution_symint(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof_symint(out_spec[1], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[1]);
  } else {
    result = std::make_tuple(at::convolution_symint(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), std::nullopt);
  }
  if (separate_bias) {
    auto A = std::get<0>(result);
    auto A_batch_dim = std::get<1>(result);
    auto B = *bias;
    auto B_batch_dim = bias_bdim;
    A = moveBatchDimToFront(A, A_batch_dim);
    B = moveBatchDimToFront(B, B_batch_dim);
    for (size_t i = 0; i < out_spec.size() - 2; i++) {
      B = B.unsqueeze(-1);
    }
    B = maybePadToLogicalRank(B, B_batch_dim, rankWithoutBatchDim(A, A_batch_dim));

    return std::make_tuple(at::add(A, B), 0);
  } else {
    return result;
  }
}

// Decomposes _convolution into a plain at::convolution call. The trailing
// flags (benchmark, deterministic, cudnn_enabled, allow_tf32) are accepted
// only to match the signature and are deliberately not forwarded: a caller
// that reached this op in the normal way should be fine without them.
static Tensor _convolution_decomp(
    const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    [[maybe_unused]] bool benchmark, [[maybe_unused]] bool deterministic,
    [[maybe_unused]] bool cudnn_enabled, [[maybe_unused]] bool allow_tf32) {
  return at::convolution(
      input_r,
      weight_r,
      bias_r_opt,
      stride_,
      padding_,
      dilation_,
      transposed_,
      output_padding_,
      groups_);
}

// Computes the bias gradient for convolution_backward.
//
// Returns an undefined Tensor when output_mask[2] is false. Otherwise sums
// grad_output_ over every dimension except dimension 1 (dimension 0 and
// dimensions 2 .. rank-1), leaving a 1-D result indexed by dimension 1.
// NOTE(review): dimension 1 is presumably the output-channel dimension of
// grad_output_ -- confirm against the caller.
//
// grad_output_ must have at least two dimensions. Without that check the
// reduction list below would be built with a zero or negative size and then
// written out of bounds.
static Tensor compute_grad_bias(
    const Tensor& grad_output_, std::array<bool, 3> output_mask) {
  if (!output_mask[2]) {
    return Tensor();
  }
  const int64_t rank = grad_output_.dim();
  TORCH_INTERNAL_ASSERT(
      rank >= 2,
      "compute_grad_bias: expected grad_output to have at least 2 dimensions, but got ",
      rank);
  // reduce_dims = {0, 2, 3, ..., rank - 1}
  DimVector reduce_dims;
  reduce_dims.resize(rank - 1);
  reduce_dims[0] = 0;
  std::iota(reduce_dims.begin() + 1, reduce_dims.end(), 2);
  return grad_output_.sum(reduce_dims);
}

// Builds a placeholder tensor with the shape `tensor` would have if
// batch_size were folded into dimension `dim`:
//   - if tensor_bdim is set, that dimension is dropped by selecting index 0;
//   - dimension `dim` is then given size batch_size * (its original size);
//   - every other dimension keeps its size.
// The result is a 0-dim new_empty() tensor expanded to that shape, so its
// contents are not meaningful; callers use it where only the shape matters.
static Tensor make_dummy(
    const Tensor& tensor, std::optional<int64_t> tensor_bdim,
    int64_t dim, int64_t batch_size) {
  const Tensor unbatched =
      tensor_bdim.has_value() ? tensor.select(*tensor_bdim, 0) : tensor;
  const auto folded_size = batch_size * unbatched.size(dim);

  // Narrow `dim` to a single element before reading the remaining sizes.
  const Tensor narrowed = unbatched.slice(dim, 0, 1);
  const auto narrowed_sizes = narrowed.sizes();
  DimVector target_shape(narrowed_sizes.begin(), narrowed_sizes.end());
  target_shape[dim] = folded_size;

  return narrowed.new_empty({}).expand(target_shape);
}

static std::tuple<Tensor, std::optional<int64_t>>
convolution_backward_input_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, const c10::SymInt& groups) {
  const std::array<bool, 3> mask = {true, false, false};
  if (grad_output_bdim && weight_bdim) {
    // regular: BNO, BOI -> N(BO), (BO)I -> N(BI)
    // transposed: BNO, BIO -> N(BO), (BI)O -> N(BI)
    const auto batch_size = weight.size(*weight_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
    const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
    auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, dummy_input, weight_, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups * batch_size, mask);
    const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
    return std::make_tuple(grad_input, 1);
  } else if (grad_output_bdim && !weight_bdim) {
    // BNO, OI -> (BN)O, OI -> (BN)I
    // transposed is the same.
    const auto batch_size = grad_output.size(*grad_output_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
    auto dummy_input = make_dummy(input, input_bdim, 0, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, dummy_input, weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
    return std::make_tuple(grad_input, 0);
  } else if (!grad_output_bdim && weight_bdim) {
    const auto batch_size = weight.size(*weight_bdim);
    if (groups == 1) {
      // regular: NO, BOI -> NO, O(BI) -> N(BI)
      // transposed: NO, BIO -> NO, (BI)O -> N(BI)
      const auto in_ch_dim = transposed ? 0 : 1;
      const auto weight_ = reshape_dim_into(*weight_bdim, in_ch_dim, weight);
      auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
      return std::make_tuple(grad_input, 1);
    }
    Tensor grad_input;
    if (!transposed) {
      // N(GO), B(GO)I -> N(GO), (GO)(BI) -> N(GBI)
      const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight);
      auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      grad_input = std::get<0>(result); // N(GBI)
    } else {
      // N(GO), B(GI)O -> N(GO), (GBI)O -> N(GBI)
      auto weight_ = moveBatchDimToFront(weight, weight_bdim); // B(GI)O
      weight_ = reshape_dim_outof_symint(1, groups, weight_);         // BGIO
      weight_ = weight_.transpose(0, 1);                       // GBIO
      weight_ = weight_.flatten(0, 2);                         // (GBI)O
      const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      grad_input = std::get<0>(result); // N(GBI)
    }
    // N(GBI) -> NG(BI) -> NGBI -> NBGI -> NB(GI)
    grad_input = reshape_dim_outof_symint(1, groups, grad_input);
    grad_input = reshape_dim_outof_symint(2, batch_size, grad_input);
    grad_input = grad_input.transpose(1, 2);
    grad_input = reshape_dim_into(2, 2, grad_input);
    return std::make_tuple(grad_input, 1);
  } else {
    TORCH_INTERNAL_ASSERT(input_bdim);
    const auto dummy_input = make_dummy(input, input_bdim, 0, 1);
    const auto result = at::convolution_backward_symint(
        grad_output, dummy_input, weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    return std::make_tuple(std::get<0>(result), std::nullopt);
  }
}
static std::tuple<Tensor, std::optional<int64_t>>
convolution_backward_weight_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, const c10::SymInt& groups) {
  const std::array<bool, 3> mask = {false, true, false};
  if (grad_output_bdim && input_bdim) {
    // BNO, BNI -> N(BO), N(BI) -> (BO)I (regular) (BI)O (transposed)
    const auto batch_size = input.size(*input_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
    const auto input_ = reshape_dim_into(*input_bdim, 1, input);
    const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, input_, dummy_weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups * batch_size, mask);
    auto grad_weight = std::get<1>(result);
    grad_weight = reshape_dim_outof_symint(0, batch_size, grad_weight);
    return std::make_tuple(grad_weight, 0);
  } else if (grad_output_bdim && !input_bdim) {
    const auto batch_size = grad_output.size(*grad_output_bdim);
    if (groups == 1) {
      // regular: BNO, NI -> N(BO), NI -> (BO)I
      // transposed: BNO, NI -> N(BO), NI -> I(BO)
      const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
      const auto out_ch_dim = transposed ? 1 : 0;
      const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output_, input, dummy_weight, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      auto grad_weight = std::get<1>(result);
      grad_weight = reshape_dim_outof_symint(out_ch_dim, batch_size, grad_weight);
      return std::make_tuple(grad_weight, out_ch_dim);
    } else {
      auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); // BN(GO)
      grad_output_ = reshape_dim_outof_symint(2, groups, grad_output_);              // BNGO
      grad_output_ = grad_output_.movedim(0, 2);                              // NGBO
      grad_output_ = grad_output_.flatten(1, 3);                              // N(GBO)
      if (!transposed) {
        // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I
        const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output_, input, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBOI
        grad_weight = grad_weight.transpose(0, 1);                          // BGOI
        grad_weight = grad_weight.flatten(1, 2);                            // B(GO)I
        return std::make_tuple(grad_weight, 0);
      } else {
        // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO)
        const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output_, input, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
        return std::make_tuple(grad_weight, 1);
      }
    }
  } else if (!grad_output_bdim && input_bdim) {
    const auto batch_size = input.size(*input_bdim);
    if (groups == 1) {
      // regular: NO, BNI -> NO, N(BI) -> O(BI)
      // transposed: NO, BNI -> NO, N(BI) -> (BI)O
      const auto input_ = reshape_dim_into(*input_bdim, 1, input);
      const auto in_ch_dim = transposed ? 0 : 1;
      const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, input_, dummy_weight, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      auto grad_weight = std::get<1>(result);
      grad_weight = reshape_dim_outof_symint(in_ch_dim, batch_size, grad_weight);
      return std::make_tuple(grad_weight, in_ch_dim);
    } else {
      auto input_ = moveBatchDimToFront(input, input_bdim); // BN(GI)
      input_ = reshape_dim_outof_symint(2, groups, input_);        // BNGI
      input_ = input_.movedim(0, 2);                        // NGBI
      input_ = input_.flatten(1, 3);                        // N(GBI)
      if (!transposed) {
        // regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI)
        const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output, input_, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
        return std::make_tuple(grad_weight, 1);
      } else {
        // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O
        const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output, input_, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBIO
        grad_weight = grad_weight.transpose(0, 1);                          // BGIO
        grad_weight = grad_weight.flatten(1, 2);                            // B(GI)O
        return std::make_tuple(grad_weight, 0);
      }
    }
  } else {
    TORCH_INTERNAL_ASSERT(weight_bdim);
    const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1);
    const auto result = at::convolution_backward_symint(
        grad_output, input, dummy_weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    return std::make_tuple(std::get<1>(result), std::nullopt);

  }
}

// Plumbing for convolution_backward at the current vmap level.
//
// Flow:
//   1. If none of grad_output_ / input_ / weight_ is batched at the current
//      level, call convolution_backward directly with the FuncTorchBatched
//      key excluded.
//   2. Otherwise unwrap the three operands, compute grad_bias from the
//      still-wrapped grad_output_, and clear output_mask[2] so the calls
//      below never request it.
//   3. If all three operands are batched, run one convolution_backward with
//      the vmap dim folded into the channel dims and groups scaled by it.
//   4. Otherwise compute grad_input and grad_weight independently through
//      their own batch rules, each only if requested by output_mask.
static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
    const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
    const c10::OptionalArrayRef<SymInt> bias_sizes_opt,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, c10::SymInt groups, std::array<bool, 3> output_mask) {
  const auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "convolution_backward_plumbing");
  const int64_t cur_level = maybe_layer->layerId();

  const bool any_batched =
      areAnyBatchedAtLevel({grad_output_, input_, weight_}, cur_level);
  if (!any_batched) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::convolution_backward_symint(
        grad_output_, input_, weight_, bias_sizes_opt, stride, padding,
        dilation, transposed, output_padding, groups, output_mask);
  }

  const auto [grad_output, grad_output_bdim] = unwrapTensorAtLevel(grad_output_, cur_level);
  const auto [input, input_bdim] = unwrapTensorAtLevel(input_, cur_level);
  const auto [weight, weight_bdim] = unwrapTensorAtLevel(weight_, cur_level);

  // grad_bias is handled here, on the wrapped tensor and before any
  // dispatch-key guard is installed; nothing below asks for it again.
  const auto grad_bias = compute_grad_bias(grad_output_, output_mask);
  output_mask[2] = false;

  // TODO: A little bird says that unfold + matmul is actually faster than
  // group convolution in many cases. We should benchmark some of
  // the common cases and replace things with unfold + matmul as necessary.

  // Notation:
  // B - a batch dimension
  // G - groups (sometimes omitted because it doesn't matter)
  // NO - grad_output
  // NI - input
  // OI - weight
  // "(BO)I" - we don't actually care about the values of this Tensor,
  //           we just need to create a tensor on the same device with the
  //           correct shape and pray that the implementation is smart enough
  //           to not do anything with it.

  // BNO, BNI, BOI
  // AKA one of the model ensembling case
  const bool all_batched = grad_output_bdim && input_bdim && weight_bdim;
  if (all_batched) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto batch_size = weight.size(*weight_bdim);

    // BNO, BNI, BOI -> N(BO), N(BI), (BO)I
    const auto folded_grad_output = reshape_dim_into(*grad_output_bdim, 1, grad_output);
    const auto folded_input = reshape_dim_into(*input_bdim, 1, input);
    const auto folded_weight = reshape_dim_into(*weight_bdim, 0, weight);
    const auto result = at::convolution_backward_symint(
        folded_grad_output, folded_input, folded_weight, std::nullopt,
        stride, padding, dilation, transposed, output_padding,
        batch_size * groups, output_mask);

    // N(BI), (BO)I -> NBI, BOI
    // Outputs that were not requested stay undefined.
    Tensor batched_grad_input;
    if (output_mask[0]) {
      batched_grad_input = makeBatched(
          reshape_dim_outof(1, batch_size, std::get<0>(result)), 1, cur_level);
    }
    Tensor batched_grad_weight;
    if (output_mask[1]) {
      batched_grad_weight = makeBatched(
          reshape_dim_outof(0, batch_size, std::get<1>(result)), 0, cur_level);
    }
    return std::make_tuple(batched_grad_input, batched_grad_weight, grad_bias);
  }

  Tensor grad_input;
  if (output_mask[0]) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto input_result = convolution_backward_input_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_input = makeBatched(
        std::get<0>(input_result), std::get<1>(input_result), cur_level);
  }

  Tensor grad_weight;
  if (output_mask[1]) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto weight_result = convolution_backward_weight_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_weight = makeBatched(
        std::get<0>(weight_result), std::get<1>(weight_result), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);

  // Someone's definitely going to find a problem with this batching rule so
  // I'm leaving the following fallback if we need it back.
  // static auto op = c10::Dispatcher::singleton()
  //   .findSchemaOrThrow("aten::convolution_backward", "");
  // auto result = slow_fallback<Tensor,Tensor,Tensor>(op, {
  //   grad_output_, input_, weight_, bias_sizes_opt,
  //   stride, padding, dilation, transposed, output_padding, groups, output_mask
  // });
  // return std::make_tuple(grad_input, std::get<1>(result), grad_bias);
}


// Registers the convolution handlers defined in this file for the `aten`
// library under the FuncTorchBatched dispatch key.
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  // convolution: hooked up through the VMAP_SUPPORT macro with
  // convolution_batch_rule as its batch rule.
  VMAP_SUPPORT(convolution, convolution_batch_rule);
  // _convolution: forwarded to at::convolution by _convolution_decomp.
  m.impl("_convolution", _convolution_decomp);
  // convolution_backward: uses the hand-written plumbing function, which
  // does its own unwrapping and re-wrapping of batched tensors.
  m.impl("convolution_backward", convolution_backward_plumbing);
}

} // namespace at::functorch
