#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Parallel.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/ConvolutionMM3d.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/native/cpu/DepthwiseConvKernel.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/xnnpack/Engine.h>
#include <c10/core/GradMode.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <c10/macros/Macros.h>
#include <limits>
#include <utility>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/permute.h>
#endif

#if AT_NNPACK_ENABLED()
#include <nnpack.h>
#endif

#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/Utils.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_conv_depthwise2d.h>
#include <ATen/ops/_convolution.h>
#include <ATen/ops/_convolution_double_backward_native.h>
#include <ATen/ops/_convolution_mode.h>
#include <ATen/ops/_convolution_mode_native.h>
#include <ATen/ops/_convolution_native.h>
#include <ATen/ops/_mps_convolution.h>
#include <ATen/ops/_mps_convolution_transpose.h>
#include <ATen/ops/_nnpack_available.h>
#include <ATen/ops/_nnpack_spatial_convolution.h>
#include <ATen/ops/_slow_conv2d_backward.h>
#include <ATen/ops/_unsafe_view.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/constant_pad_nd.h>
#include <ATen/ops/conv1d_native.h>
#include <ATen/ops/conv2d_native.h>
#include <ATen/ops/conv3d_native.h>
#include <ATen/ops/conv_depthwise3d.h>
#include <ATen/ops/conv_transpose1d_native.h>
#include <ATen/ops/conv_transpose2d_native.h>
#include <ATen/ops/conv_transpose3d_native.h>
#include <ATen/ops/convolution.h>
#include <ATen/ops/convolution_backward_native.h>
#include <ATen/ops/convolution_backward_overrideable.h>
#include <ATen/ops/convolution_backward_overrideable_native.h>
#include <ATen/ops/convolution_native.h>
#include <ATen/ops/convolution_overrideable.h>
#include <ATen/ops/convolution_overrideable_native.h>
#include <ATen/ops/cudnn_convolution.h>
#include <ATen/ops/cudnn_convolution_transpose.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/miopen_convolution.h>
#include <ATen/ops/miopen_convolution_transpose.h>
#include <ATen/ops/miopen_depthwise_convolution.h>
#include <ATen/ops/mkldnn_convolution.h>
#include <ATen/ops/mps_convolution_backward.h>
#include <ATen/ops/mps_convolution_transpose_backward.h>
#include <ATen/ops/slow_conv3d.h>
#include <ATen/ops/slow_conv_dilated2d.h>
#include <ATen/ops/slow_conv_dilated3d.h>
#include <ATen/ops/slow_conv_transpose2d.h>
#include <ATen/ops/slow_conv_transpose3d.h>
#include <ATen/ops/thnn_conv2d.h>
#include <ATen/ops/view_as_real.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

// Largest input.dim() for which ConvParams::use_miopen() will select MIOpen.
constexpr int MIOPEN_DIM_MAX{5};

namespace at::native {


// File-local flag, initially enabled.
// NOTE(review): no reader/writer is visible in this part of the file;
// presumably it is toggled by benchmark-related accessors further down — confirm.
static bool conv_benchmark_empty_cache{true};

// Check workload to activate fast depthwise FP16 cudnn conv kernels
//
// Pure threshold heuristic: reads sizes 0, 1 and 3 of `input` (batch,
// channels and width; height is assumed equal to width) and answers whether
// the (stride, batch, channels, width) combination falls into one of the
// buckets below. Only stride 1 and stride 2 can ever yield `true`.
//
// The comparisons are evaluated in a fixed order (short-circuiting), which
// matters when T is a symbolic integer type.
template <typename T>
bool check_cudnn_depthwise_workload(const at::Tensor& input, T stride) {
  auto width = at::symint::size<T>(input, 3);  // same as h
  auto channels = at::symint::size<T>(input, 1);
  auto batch = at::symint::size<T>(input, 0);

  // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers)
  if (stride == 1) {
    if (width >= 7) {
      // All batch sizes and nb_channels
      if (width >= 112) {
        return true;
      }

      // large nb_channels
      if (channels >= 1024 && (width >= 56 || batch >= 32)) {
        return true;
      }

      // batch_size specific; each bucket fully decides the result
      if (batch >= 128) {
        if (channels >= 512) {
          return true;
        }
        if (channels >= 64) {
          return width >= 14;
        }
        return (channels >= 32) && (width >= 28);
      }
      if (batch >= 64) {
        return ((channels >= 256) && (width >= 14)) ||
               ((channels >= 32) && (width >= 28));
      }
      if (batch >= 32) {
        return ((channels >= 256) && (width >= 14)) ||
               ((channels >= 128) && (width >= 28)) ||
               ((channels >= 32) && (width >= 56));
      }
      if (batch >= 16) {
        return ((channels >= 1024) && (width >= 14)) ||
               ((channels >= 256) && (width >= 28)) ||
               ((channels >= 32) && (width >= 56));
      }
      if (batch >= 8) {
        return ((channels >= 512) && (width >= 28)) ||
               ((channels >= 64) && (width >= 56));
      }
    }
  } else if (stride == 2) {
    // stride 2 is never profitable below 256 channels
    if (channels < 256) {
      return false;
    }

    if (width >= 7) {
      if (batch >= 128) {
        return (channels >= 1024) ||
               ((channels >= 512) && (width >= 14)) ||
               (width >= 28);
      }
      if (batch >= 64) {
        return ((channels >= 512) && (width >= 14)) || (width >= 28);
      }
      if (batch >= 32) {
        return ((channels >= 1024) && (width >= 14)) || (width >= 28);
      }
      if (batch >= 16) {
        return ((channels >= 512) && (width >= 28)) || (width >= 56);
      }
      if (batch >= 8) {
        return ((channels >= 1024) && (width >= 28)) || (width >= 56);
      }
      if (batch >= 1) {
        return (channels >= 512) && (width >= 112);
      }
    }
  }
  // NOLINTEND(cppcoreguidelines-avoid-magic-numbers)
  return false;
}

// simplified version for cudnn 8.2 and above
//
// Decides from the input sizes (dims 0-3), the trailing two weight sizes and
// the stride whether the workload should go to the cudnn depthwise kernels.
// Checks are short-circuited in a fixed order.
template <typename T>
bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, T stride, const at::Tensor& weight) {
  // 1D conv
  if (at::symint::size<T>(input, 2) == 1 && stride == 1) {
    return true;
  }

  // 2d conv
  // only square filters
  if (at::symint::size<T>(weight, 2) != at::symint::size<T>(weight, 3)) {
    return false;
  }
  auto kernel = at::symint::size<T>(weight, 3);
  // only 1/3/5 filter
  if (kernel != 1 && kernel != 3 && kernel != 5) {
    return false;
  }
  // we don't enforce square input but only check width to reduce heuristic space
  auto width = at::symint::size<T>(input, 3);
  if (width < 7) {  // min width 7
    return false;
  }
  // only 1/2 stride, use cudnn for all stride 1
  if (stride == 1) {
    return true;
  }
  if (stride != 2) {
    return false;
  }

  // From here on stride == 2.
  auto channels = at::symint::size<T>(input, 1);
  auto batch = at::symint::size<T>(input, 0);
  // special case since bs1 show good perf in lots of cases
  if (batch == 1) {
    return (kernel == 1 && width <= 28) || kernel == 3 || kernel == 5;
  }
  return (kernel == 1 && batch <= 16 && channels >= 128 && width <= 7) ||
         ((kernel == 3 || kernel == 5) &&
          ((channels >= 512) || (channels >= 256 && width >= 28)));
}


#if defined(C10_MOBILE)
// Concrete-integer overload: the decision is delegated entirely to
// xnnpack::use_convolution2d, with every argument forwarded unchanged.
static bool xnnpack_use_convolution2d(
    const Tensor& input,
    const Tensor& weight,
    const at::OptionalIntArrayRef bias_sizes_opt,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool transposed) {
  const bool xnnpack_eligible = xnnpack::use_convolution2d(
      input,
      weight,
      bias_sizes_opt,
      padding,
      stride,
      dilation,
      groups,
      transposed);
  return xnnpack_eligible;
}

// Symbolic-integer overload: selected when the caller passes SymInt-based
// arguments. All parameters are intentionally ignored.
static bool xnnpack_use_convolution2d(
    const Tensor& /*input*/,
    const Tensor& /*weight*/,
    const at::OptionalSymIntArrayRef /*bias_sizes_opt*/,
    const SymIntArrayRef /*padding*/,
    const SymIntArrayRef /*stride*/,
    const SymIntArrayRef /*dilation*/,
    const c10::SymInt /*groups*/,
    const bool /*transposed*/) {
  // Never use xnnpack for symbolic tracing
  return false;
}
#endif

// This struct is templated so that we can run backend selection in a dynamic
// shapes context; all of the real kernel selection in eager mode runs with
// int64_t
//
// ConvParams bundles the hyper-parameters of a single convolution call with
// the predicates used to validate them (is_*) and to decide which backend is
// eligible for the call (use_*). None of the predicates mutate the struct;
// view1d_as_2d() is the only mutating member.
template <typename T>
struct ConvParams {
  // stride / padding / dilation / output_padding hold one entry per spatial
  // dimension (they are indexed with `dim - 2` by the shape checks).
  std::vector<T> stride;
  std::vector<T> padding;
  std::vector<T> dilation;
  bool transposed{};
  std::vector<T> output_padding;
  T groups{};
  bool benchmark{};
  bool deterministic{};
  bool cudnn_enabled{};
  bool allow_tf32{};

  // True if any stride entry differs from 1.
  bool is_strided() const {
    bool is_strided = false;
    for (const auto& s : stride) {
      is_strided |= (s != 1);
    }
    return is_strided;
  }

  // True if any dilation entry differs from 1.
  bool is_dilated() const {
    bool is_dilated = false;
    for (const auto& d : dilation) {
      is_dilated |= (d != 1);
    }
    return is_dilated;
  }

  // True if any padding entry is non-zero.
  bool is_padded() const {
    bool is_padded = false;
    // Iterate by reference: T may be a non-trivially-copyable symbolic int.
    for (const auto& p : padding) {
      is_padded |= (p != 0);
    }
    return is_padded;
  }

  // True if any output_padding entry is negative.
  bool is_output_padding_neg() const {
    bool is_neg = false;
    for (const auto& p : output_padding) {
      is_neg |= (p < 0);
    }
    return is_neg;
  }

  // True if, for any index i, output_padding[i] >= stride[i].
  // Indexes `stride` with the indices of `output_padding`, so `stride` must be
  // at least as long as `output_padding`.
  bool is_output_padding_big() const {
    bool is_big = false;
    for (auto i: c10::irange(output_padding.size())) {
      is_big |= (output_padding[i] >= stride[i]);
    }
    return is_big;
  }

  // True if any padding entry is negative.
  bool is_padding_neg() const {
    bool is_neg = false;
    for (const auto& p : padding) {
      is_neg |= (p < 0);
    }
    return is_neg;
  }

  // True if any dilation entry is negative (a dilation of 0 is NOT flagged).
  bool is_dilation_neg() const {
    bool is_neg = false;
    for (const auto& d : dilation) {
      is_neg |= (d < 0);
    }
    return is_neg;
  }

  // True if any stride entry is zero or negative.
  bool is_stride_nonpos() const {
    bool is_nonpos = false;
    for (const auto& s : stride) {
      is_nonpos |= (s <= 0);
    }
    return is_nonpos;
  }

  // Promotes 1-d parameters to 2-d by prepending a neutral leading entry
  // (stride 1, padding 0, dilation 1, output_padding 0). No-op unless
  // `stride` currently has exactly one entry.
  void view1d_as_2d() {
    if (stride.size() == 1) {
      stride.insert(stride.begin(), 1);
      padding.insert(padding.begin(), 0);
      dilation.insert(dilation.begin(), 1);
      output_padding.insert(output_padding.begin(), 0);
    }
  }

  // Whether the CPU 3x3 depthwise winograd kernel applies. Always false unless
  // built for ARM NEON or a sufficiently recent RISC-V vector intrinsic level.
  bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const std::optional<at::Tensor>& bias) const {
#if defined(__ARM_NEON__) || (defined(__riscv_v_intrinsic) && __riscv_v_intrinsic>=12000)
    // Currently only 3x3 depthwise convolutions on tensors of float are supported.
    return (input.ndimension() == 4) &&
           (at::symint::size<T>(input, 1) == groups) &&
           (weight.ndimension() == 4 ) &&
           (at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0) &&
           (at::symint::size<T>(weight, 1) == 1) &&
           (at::symint::size<T>(weight, 2) == 3) &&
           (at::symint::size<T>(weight, 3) == 3) &&
           (input.device().is_cpu()) &&
           (input.scalar_type() == at::kFloat) &&
           input.is_contiguous() &&
           (weight.device().is_cpu()) &&
           (weight.scalar_type() == at::kFloat) &&
           weight.is_contiguous() &&
           (!bias.has_value() || bias->is_contiguous()) &&
           !is_strided() &&
           !is_dilated() &&
           !transposed;
#else
    return false;
#endif
  }

  // Returns true when the problem cannot be brought into 32-bit index range by
  // splitting along dim 0: either the per-sample input element count or the
  // per-sample output element count (product of all output dims except dim 0)
  // exceeds INT_MAX. Empty inputs always return false.
  bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const {
    constexpr int64_t int_max = std::numeric_limits<int>::max();
    auto numel_input = at::symint::numel<T>(input);
    // empty input
    if (numel_input == 0) {
      return false;
    }
    // input size can not be reduced to the range of int by splitting the batch dim
    auto n = at::symint::size<T>(input, 0);
    if (numel_input / n > int_max) {
      return true;
    }
    // output size can not be reduced to the range of int by splitting the batch dim
    T outsize = 1;
    if (transposed) {
      auto o = conv_input_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, output_padding, stride, dilation, groups);
      outsize = c10::multiply_integers(o.begin() + 1, o.end());
    } else {
      auto o = conv_output_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, stride, dilation);
      outsize = c10::multiply_integers(o.begin() + 1, o.end());
    }
    return outsize > int_max;
  }

  // Whether cuDNN is eligible for this call. Always false on mobile builds.
  bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const {
  // Note [Mobile check segfaults]
  // cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest
  // that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how)
#if !defined(C10_MOBILE)
    if (!detail::getCUDAHooks().compiledWithCuDNN()) {
      return false;
    }
    // Cheap eligibility checks come before the 64-bit indexing probe: the
    // answer is `false` either way, and bailing out here avoids computing the
    // output size and emitting the cuDNN warning below for calls (e.g. CPU
    // inputs) that could never have used cuDNN in the first place.
    if (!input.is_cuda() || !cudnn_enabled) {
      return false;
    }
    if (needs_64bit_indexing_no_split(input, weight)) {
      // Queried once and cached for the lifetime of the process.
      static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
      if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
        TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"
                        " if the V8 API is not enabled or before cuDNN version 9.3+."
                        " Consider upgrading cuDNN and/or enabling the V8 API for better efficiency.");
        return false;
      }
    }
    if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) {
      if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) {
        return false;
      }
    }
    if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) {
      // bypass dilation checks for channels_last convolution
      if (deterministic && is_dilated()) {
        // cudnn doesn't support deterministic dilated convolution fully yet
        return false;
      }
      if (is_dilated()) {
        return detail::getCUDAHooks().supportsDilatedConvolutionWithCuDNN() && !is_output_padding_big();
      }
    }
    return !is_output_padding_big();
#else
    return false;
#endif
  }

  // Use cudnn for FP16 depthwise convolutions
  //
  // Non-contiguous suggested memory formats go to cudnn whenever use_cudnn()
  // agrees. Otherwise the decision depends on the cuDNN version: >= 8200 uses
  // the filter-aware heuristic, and everything else falls back to the
  // (>= 7600) workload table.
  bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const  {
    if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) {
      // always use cudnn_depthwise for channels_last format
      return true;
    }
    if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
      long cudnn_version = detail::getCUDAHooks().versionCuDNN();
      if (cudnn_version >= 8200) {
        bool kernel_cond =  (use_cudnn(input, weight) &&
                             input.scalar_type() == kHalf && // only for FP16
                             weight.scalar_type() == kHalf &&
                             is_depthwise(input, weight) &&
                             input.ndimension() == 4 &&   // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
                             !is_dilated() && // no dilation supported
                             (stride[0] == stride[1] || at::symint::size<T>(input, 2) == 1) && // square or 1d
                             at::symint::size<T>(input, 1) >= 32); // min 32 channels supported)
        if (kernel_cond) {
          return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1], weight);
        }
        // Not eligible for the 8.2+ heuristic: fall through to the legacy check.
      }
      // keep (7600 <= cudnn < 8200) code unchanged
      bool kernel_cond =  (cudnn_version >= 7600 &&
                           use_cudnn(input, weight) &&
                           input.scalar_type() == kHalf && // only for FP16
                           weight.scalar_type() == kHalf &&
                           is_depthwise(input, weight) &&
                           input.ndimension() == 4 &&   // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
                           at::symint::size<T>(weight, 2) == at::symint::size<T>(weight, 3) && // only square kernels
                           at::symint::size<T>(input, 2) >= 7 && // min width/height 7
                           !is_dilated() && // no dilation supported
                           stride[0] == stride[1] && // equal strides
                           ((at::symint::size<T>(weight, 3) == 3) || (at::symint::size<T>(weight, 3) == 1)) &&
                           at::symint::size<T>(input, 1) >= 32); // min 32 channels supported)
      if (kernel_cond) {
        return check_cudnn_depthwise_workload<T>(input, stride[0]);
      } else {
        return false;
      }
    } else {
      return false;
    }
  }

  // Whether MIOpen is eligible for this call. `bias_defined` is accepted for
  // interface compatibility but does not influence the result.
  bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const  {
    if (needs_64bit_indexing_no_split(input, weight)) {
      return false;
    }
    return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16))
           && cudnn_enabled
           && input.is_cuda()
           && detail::getCUDAHooks().compiledWithMIOpen()
           && input.dim() <= MIOPEN_DIM_MAX
           && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1
           ;
  }

  // Whether MKLDNN is eligible for this call. Always false when the build has
  // MKLDNN disabled.
  bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const  {
#if AT_MKLDNN_ENABLED()
    if (!at::globalContext().userEnabledMkldnn()) {
      return false;
    }
    if (transposed && is_output_padding_big()) {
      return false;
    }
    // Reduced-precision CPU inputs go to MKLDNN whenever the device supports them.
    if (input.device().is_cpu() &&
        ((input.scalar_type() == at::kBFloat16 && mkldnn_bf16_device_check()) ||
         (input.scalar_type() == at::kHalf && mkldnn_fp16_device_check()))) {
      return true;
    }
    return (input.is_mkldnn()) || // input is mkldnn Tensor
      (input.device().is_cpu() &&
       input.scalar_type() == kFloat && // only on CPU Float Tensors
       // For 1x1 filters, MKLDNN is faster than THNN when multi-threaded,
       // but THNN is faster when single-threaded.
       (is_strided() || is_dilated() || at::symint::size<T>(input, 0) >= 16 ||
        at::symint::size<T>(weight, -1) != 1 || at::symint::size<T>(weight, -2) != 1 || at::get_num_threads() > 1) &&
       (groups > 1
        || (at::symint::size<T>(weight, -1) > 3 && at::symint::size<T>(weight, -2) > 3)
        || at::symint::size<T>(input, 0) > 1
        || at::symint::size<T>(input, 0)*at::symint::size<T>(input, 1)*at::symint::size<T>(input, 2)*at::symint::size<T>(input, 3) > 20480) // for some case, native is faster
        );

#endif
    return false;
  }

  // Whether NNPACK is eligible for this call. Always false when the build has
  // NNPACK disabled.
  bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const  {
#if AT_NNPACK_ENABLED()
    return at::globalContext().userEnabledNNPACK() &&
           at::_nnpack_available() &&
           input.device().is_cpu() &&
           input.scalar_type() == kFloat && // only on CPU Float Tensors
           !is_dilated() && // or dilation
           !transposed &&   // or transposed tensors
           input.ndimension() == 4 && // must be in NCHW format
           weight.ndimension() == 4 &&
           (at::symint::size<T>(weight, 2) < 17) && (at::symint::size<T>(weight, 3) < 17) && // NNPACK only supports kernels up to 16x16
           (padding[0] < at::symint::size<T>(weight, 2)) && (padding[1] < at::symint::size<T>(weight, 3)) // NNPACK only supports padding < kernel_size. See https://github.com/pytorch/pytorch/issues/90142.
#if !defined(C10_MOBILE)
           && at::symint::size<T>(input, 0) >= 16 // ensure large enough batch size to ensure perf, tuneable
#endif
       ;
#endif
    return false;
  }

  // Whether XNNPACK is eligible for this call. Only ever true on mobile
  // builds, for non-transposed convolutions whose channel count equals
  // `groups`.
  bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight,
                   const at::OptionalArrayRef<T> bias_sizes_opt) const {
#if defined(C10_MOBILE)
    if (!transposed) {
      // NB: for the call here, it MATTERS that we are templated. If you
      // untemplate this to always use SymInt, the function
      // xnnpack_use_convolution2d will always return false
      return (at::symint::size<T>(input, 1) == groups) &&
              xnnpack_use_convolution2d(
                  input,
                  weight,
                  bias_sizes_opt,
                  padding,
                  stride,
                  dilation,
                  groups,
                  transposed);
    }
#endif
    return false;
  }

  // Whether the MPS backend is eligible for this call. Always false unless
  // built with USE_MPS.
  bool use_mps(const at::Tensor& input, const at::Tensor& weight) const {
    // These checks need to be expanded. Currently we have very limited set of
    // checks for MPS.
#ifdef USE_MPS
    if (needs_64bit_indexing_no_split(input, weight)) {
      return false;
    }
    if (!input.is_mps()) {
      return false;
    }
    return true;
#else
    return false;
#endif
  }

  // We currently only have depthwise support for the case where groups ==
  // nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of
  // a depthwise multiplier)
  bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const  {
    return input.is_cuda() &&
           !transposed &&
           (input.ndimension() == 4 || input.ndimension() == 5) &&
           at::symint::size<T>(input, 1) == groups &&
           groups > 1 && // no point if there is only a single group
           at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0; // output channels must be a multiple of input channels
  }
};

// Dispatch stub definitions for convolution kernels.
// NOTE(review): DEFINE_DISPATCH and REGISTER_NO_CPU_DISPATCH are macros defined
// outside this file; presumably the former defines the stub object and the
// latter records that the stub has no CPU kernel — confirm against the
// dispatch-stub header.
DEFINE_DISPATCH(conv_depthwise2d_backward_stub);
DEFINE_DISPATCH(conv_depthwise3d_backward_stub);
DEFINE_DISPATCH(cudnn_convolution_backward_stub);
DEFINE_DISPATCH(cudnn_convolution_transpose_backward_stub);
DEFINE_DISPATCH(slow_conv_transpose3d_backward_stub);
DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub);
DEFINE_DISPATCH(miopen_convolution_backward_stub);
DEFINE_DISPATCH(miopen_convolution_transpose_backward_stub);
DEFINE_DISPATCH(miopen_depthwise_convolution_backward_stub);
DEFINE_DISPATCH(mkldnn_convolution_backward_stub);
DEFINE_DISPATCH(mkldnn_convolution_transpose_stub);
DEFINE_DISPATCH(mkldnn_convolution_transpose_backward_stub);
DEFINE_DISPATCH(slow_conv_dilated2d_backward_stub);
DEFINE_DISPATCH(slow_conv_dilated3d_backward_stub);
DEFINE_DISPATCH(slow_conv_transpose2d_backward_stub);
// Only the conv_depthwise*, cudnn_* and miopen_* stubs are listed below; the
// slow_conv_*, mkldnn_* and winograd stubs defined above are not.
REGISTER_NO_CPU_DISPATCH(conv_depthwise2d_backward_stub);
REGISTER_NO_CPU_DISPATCH(conv_depthwise3d_backward_stub);
REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub);
REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub);
REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub);
REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub);
REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub);

// Streams a one-line, human-readable dump of every ConvParams field.
//
// All four vector fields are wrapped in ArrayRef<T> so the operator
// instantiates for every T that ConvParams is used with, not only int64_t.
// For T == int64_t the output is unchanged.
template <typename T>
std::ostream& operator<<(std::ostream & out, const ConvParams<T>& params) {
  out << "ConvParams {"
      << "  stride = " << ArrayRef<T>{params.stride}
      << "  padding = " << ArrayRef<T>{params.padding}
      << "  dilation = " << ArrayRef<T>{params.dilation}
      << "  transposed = " << params.transposed
      << "  output_padding = " << ArrayRef<T>{params.output_padding}
      << "  groups = " << params.groups
      << "  benchmark = " << params.benchmark
      << "  deterministic = " << params.deterministic
      << "  cudnn_enabled = " << params.cudnn_enabled
      << "  allow_tf32 = " << params.allow_tf32
      << "}";
  return out;
}

// Validates that `input`, the weight shape and the optional `bias` are
// mutually consistent for the convolution described by `params`.
//
// Args:
//     input: convolution input; its rank must equal weight_sizes.size().
//     weight_sizes: sizes of the weight tensor.
//     bias: optional bias; an undefined Tensor skips the bias checks.
//     params: hyper-parameters (padding, dilation, stride, groups, transposed).
//
// Raises (via TORCH_CHECK) on: negative padding / output_padding / dilation,
// non-positive stride, rank mismatch, weight dim 0 smaller than or not
// divisible by `groups`, channel-count mismatch, bias shape mismatch and, for
// non-transposed convolutions, an effective (dilated) kernel that is larger
// than the padded input.
template <typename T>
static void check_shape_forward(const at::Tensor& input,
                                const c10::ArrayRef<T>& weight_sizes, const at::Tensor& bias,
                                const ConvParams<T>& params) {
  int64_t k = input.ndimension();
  int64_t weight_dim = static_cast<int64_t>(weight_sizes.size());
  auto groups = params.groups;
  const auto& padding = params.padding;
  const auto& dilation = params.dilation;
  bool transposed = params.transposed;

  TORCH_CHECK(!params.is_padding_neg(), "negative padding is not supported");
  TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
  TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");
  // NOTE(review): is_dilation_neg() only rejects values < 0, so a dilation of
  // exactly 0 passes this check despite the wording of the message.
  TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero");

  TORCH_CHECK(weight_dim == k,
           "Expected ", weight_dim, "-dimensional input for ", weight_dim,
           "-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ",
           at::symint::sizes<T>(input), " instead");
  TORCH_CHECK(weight_sizes[0] >= groups,
           "Given groups=", groups, ", expected weight to be at least ", groups,
           " at dimension 0, but got weight of size ", weight_sizes, " instead");
  TORCH_CHECK(weight_sizes[0] % groups == 0,
           "Given groups=", groups, ", expected weight to be divisible by ",
           groups, " at dimension 0, but got weight of size [", weight_sizes,
           "] instead");

  if (!transposed) {
    std::vector<T> input_shape;
    std::vector<T> kernel_shape;
    bool kernel_size_correct = true;

    TORCH_CHECK(at::symint::size<T>(input, 1) == (weight_sizes[1] * groups),
                "Given groups=", groups, ", weight of size ", weight_sizes,
                ", expected input", at::symint::sizes<T>(input), " to have ",
                (weight_sizes[1] * groups), " channels, but got ", at::symint::size<T>(input, 1),
                " channels instead");

    TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[0]),
             "Given weight of size ", weight_sizes,
             ", expected bias to be 1-dimensional with ", weight_sizes[0], " elements",
             ", but got bias of size ", at::symint::sizes<T>(bias), " instead");

    // Spatial dims start at index 2; params vectors are indexed from 0.
    for (const auto i : c10::irange(2, k)) {
      input_shape.push_back(at::symint::size<T>(input, i) + 2 * padding[i-2]);
      // log new kernel size considering dilation
      kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1);
      if (input_shape.back() < kernel_shape.back()) {
        kernel_size_correct = false;
      }
    }

    TORCH_CHECK(input_shape.size() == kernel_shape.size(), "Inconsistent shape between Input and Kernel");

    if (!kernel_size_correct) {
      // If kernel size is incorrect, render both shapes as "a x b x c" for the
      // error message.
      std::ostringstream input_ss;
      std::ostringstream kernel_ss;
      const char* separator = "";

      for (const auto i : c10::irange(input_shape.size())) {
        input_ss << separator << input_shape[i];
        kernel_ss << separator << kernel_shape[i];
        separator = " x ";
      }

      TORCH_CHECK(false, "Calculated padded input size per channel: (", input_ss.str(), "). "
               "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size");
    }
  } else { // transposed
    TORCH_CHECK(at::symint::size<T>(input, 1) == weight_sizes[0],
             "Given transposed=", transposed, ", weight of size ", weight_sizes,
             ", expected input", at::symint::sizes<T>(input), " to have ", weight_sizes[0],
             " channels, but got ", at::symint::size<T>(input, 1), " channels instead");
    TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[1] * groups),
             "Given transposed=", transposed, ", weight of size ", weight_sizes,
             ", expected bias to be 1-dimensional with ", weight_sizes[1] * groups, " elements",
             ", but got bias of size ", at::symint::sizes<T>(bias), " instead");
  }
}

// Backward-pass shape validation: identical to check_shape_forward, run with
// an undefined bias so that the bias checks are skipped.
template <typename T>
static void check_shape_backward(
    const at::Tensor& input,
    const c10::ArrayRef<T>& weight_sizes,
    const ConvParams<T>& params) {
  const Tensor undefined_bias;
  check_shape_forward<T>(input, weight_sizes, undefined_bias, params);
}

// Normalizes `input` to its batched form after validating its rank against the
// expected number of spatial dimensions.
//
// Args:
//     input (Tensor): Input tensor
//     num_spatial_dims (int): Number of spatial dimensions expected for the input
//     func_name (string): Function name used in the error message for invalid input
//
// Returns a std::tuple containing:
//     batched_input (Tensor): `input` itself if it was already batched (or is
//         nested), otherwise `input` with a new leading dimension of size 1
//     is_batched (bool): Whether the original input was already batched
static std::tuple<Tensor, bool> batchify(
    const Tensor& input,
    const int64_t num_spatial_dims,
    const std::string& func_name) {
  // assume NTs are always batched
  if (input.is_nested()) {
    return std::make_tuple(input, true);
  }

  // Unbatched inputs carry channel + spatial dims; batched ones add one more.
  const int64_t unbatched_rank = num_spatial_dims + 1;
  const int64_t batched_rank = unbatched_rank + 1;
  const int64_t rank = input.dim();
  const bool already_batched = (rank == batched_rank);

  TORCH_CHECK(already_batched || rank == unbatched_rank,
      "Expected ", unbatched_rank, "D (unbatched) or ", batched_rank,
      "D (batched) input to ", func_name, ", but got input of size: ", input.sizes());

  if (already_batched) {
    return std::make_tuple(input, true);
  }
  return std::make_tuple(input.unsqueeze(0), false);
}

// Raises unless `weight` — and `bias`, when defined — have tensor options that
// are type_equal to those of `input`.
static void check_input_same_type_as_parameters(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias) {
  const bool weight_matches = input.options().type_equal(weight.options());
  TORCH_CHECK(weight_matches,
      "Input type (", input.toString(), ") and weight type (", weight.toString(),
      ") should be the same");

  // An undefined bias is always acceptable.
  if (bias.defined()) {
    const bool bias_matches = input.options().type_equal(bias.options());
    TORCH_CHECK(bias_matches,
        "Input type (", input.toString(), ") and bias type (", bias.toString(),
        ") should be the same");
  }
}

static void check_input_same_type_as_parameters(
    const Tensor& input,
    const Tensor& weight) {
  check_input_same_type_as_parameters(input, weight, /*bias=*/ Tensor());
}

#if AT_MKLDNN_ENABLED()
// Backend-aware variant of the type check. For the Mkldnn / MkldnnTranspose
// backends an MKLDNN input may additionally be paired with CPU float weight
// and bias tensors; every other backend uses the strict three-argument check.
static void check_input_same_type_as_parameters(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const ConvBackend backend) {
  const bool mkldnn_backend =
      (backend == ConvBackend::Mkldnn) || (backend == ConvBackend::MkldnnTranspose);
  if (!mkldnn_backend) {
    check_input_same_type_as_parameters(input, weight, bias);
    return;
  }

  const bool weight_ok = input.options().type_equal(weight.options())
      || (input.is_mkldnn() && weight.device().is_cpu() && weight.scalar_type() == kFloat);
  TORCH_CHECK(weight_ok,
      "Input type (", input.toString(), ") and weight type (", weight.toString(),
      ") should be the same or input should be a MKLDNN tensor and weight is a dense tensor");

  const bool bias_ok = !bias.defined() || (input.options().type_equal(bias.options()))
      || (input.is_mkldnn() && bias.device().is_cpu() && bias.scalar_type() == kFloat);
  TORCH_CHECK(bias_ok,
      "Input type (", input.toString(), ") and bias type (", bias.toString(),
      ") should be the same or input should be a MKLDNN tensor and bias is a dense tensor");
}
#endif

// Returns a 4-D view of a 3-D tensor by inserting a size-1 dimension at
// index 2. Raises if `tensor` is not 3-D.
static auto view4d(const at::Tensor& tensor) -> at::Tensor {
  const auto rank = tensor.ndimension();
  TORCH_CHECK(rank == 3,
           "expected 3D tensor, got tensor with ", rank,
           " dimensions instead");
  return tensor.unsqueeze(2);
}

// Inverse of view4d: squeezes dimension 2 of a 4-D tensor. Raises if `tensor`
// is not 4-D.
static auto view3d(const at::Tensor& tensor) -> at::Tensor {
  const auto rank = tensor.ndimension();
  TORCH_CHECK(rank == 4,
           "expected 4D tensor, got tensor with ", rank,
           " dimensions instead");
  return tensor.squeeze(2);
}

// Returns the g-th of `groups` equal slices of `tensor` along `dim`, made
// contiguous in the tensor's suggested memory format. An undefined tensor
// yields an undefined tensor.
static at::Tensor subtensor(at::Tensor& tensor, int64_t dim, int64_t groups, int64_t g) {
  if (tensor.defined()) {
    const int64_t slice_len = tensor.sizes()[dim] / groups;
    const auto layout = tensor.suggest_memory_format();
    return tensor.narrow(dim, slice_len * g, slice_len).contiguous(layout);
  }
  return at::Tensor();
}

namespace {

std::pair<Tensor, Tensor> complex_to_real(const Tensor& inp) {
  auto inp_view_as_complex = at::view_as_real(inp);
  auto dim_i = inp_view_as_complex.dim() - 1;
  auto i_r = inp_view_as_complex.select(dim_i, 0);
  auto i_i = inp_view_as_complex.select(dim_i, 1);
  return std::make_pair(i_r, i_i);
}

// Convolution for complex inputs, expressed through three real convolutions.
//
// [NOTE] Complex Convolution
// conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
// where W, x and b are all complex inputs.
// With Gauss Trick:
// a = conv(Wr, xr, br),
// b = conv(Wi, xi, 0),
// c = conv(Wr + Wi, xr + xi, bi + br)
// conv(W, x, b) = a - b + i(c - a - b)
at::Tensor complex_convolution(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    SymIntArrayRef stride,
    SymIntArrayRef padding,
    SymIntArrayRef dilation,
    bool transposed,
    SymIntArrayRef output_padding,
    const c10::SymInt& groups) {
  check_input_same_type_as_parameters(input, weight, bias);
  auto [i_r, i_i] = complex_to_real(input.resolve_conj());
  auto [w_r, w_i] = complex_to_real(weight.resolve_conj());

  // One real-valued convolution sharing all the geometry arguments.
  const auto real_conv =
      [stride, padding, dilation, transposed, output_padding, &groups](
          const Tensor& x, const Tensor& w, const Tensor& b) {
        return at::convolution_symint(
            x, w, b, stride, padding, dilation, transposed, output_padding, groups);
      };

  // The three terms a, b, c of the Gauss trick described above.
  Tensor term_a, term_b, term_c;
  if (bias.defined()) {
    auto [b_r, b_i] = complex_to_real(bias.resolve_conj());
    term_a = real_conv(i_r, w_r, b_r);
    term_b = real_conv(i_i, w_i, Tensor());
    term_c = real_conv(i_r + i_i, w_r + w_i, b_r + b_i);
  } else {
    // `bias` is undefined here and is forwarded unchanged.
    term_a = real_conv(i_r, w_r, bias);
    term_b = real_conv(i_i, w_i, bias);
    term_c = real_conv(i_r + i_i, w_r + w_i, bias);
  }

  const auto imag_unit = c10::Scalar(c10::complex<double>(0, 1));
  return term_a - term_b + imag_unit * (term_c - term_a - term_b);
}

// String-padding ("same"/"valid") counterpart of complex_convolution:
// computes a complex convolution from three real _convolution_mode calls.
// See [NOTE] Complex Convolution
at::Tensor complex_convolution_mode(
    const at::Tensor& input,
    const at::Tensor& weight,
    const std::optional<at::Tensor>& bias_opt,
    c10::SymIntArrayRef stride,
    c10::string_view padding,
    c10::SymIntArrayRef dilation,
    const c10::SymInt& groups) {
  auto bias = bias_opt.value_or(Tensor());
  check_input_same_type_as_parameters(input, weight, bias);
  auto [i_r, i_i] = complex_to_real(input.resolve_conj());
  auto [w_r, w_i] = complex_to_real(weight.resolve_conj());

  // One real-valued convolution sharing all the geometry arguments.
  const auto real_conv =
      [stride, padding, dilation, &groups](
          const Tensor& x, const Tensor& w, const Tensor& b) {
        return at::_convolution_mode_symint(x, w, b, stride, padding, dilation, groups);
      };

  // Terms a, b, c of the Gauss trick.
  Tensor term_a, term_b, term_c;
  if (bias.defined()) {
    auto [b_r, b_i] = complex_to_real(bias.resolve_conj());
    term_a = real_conv(i_r, w_r, b_r);
    term_b = real_conv(i_i, w_i, Tensor());
    term_c = real_conv(i_r + i_i, w_r + w_i, b_r + b_i);
  } else {
    // `bias` is undefined here and is forwarded unchanged.
    term_a = real_conv(i_r, w_r, bias);
    term_b = real_conv(i_i, w_i, bias);
    term_c = real_conv(i_r + i_i, w_r + w_i, bias);
  }

  const auto imag_unit = c10::Scalar(c10::complex<double>(0, 1));
  return term_a - term_b + imag_unit * (term_c - term_a - term_b);
}

} // namespace

// conv1d entry point with integer (SymInt) padding.
// Checks the bias dtype against the input, batchifies the input, dispatches
// to the complex or the regular convolution, and squeezes dimension 0 of the
// result again when the input was not batched.
at::Tensor conv1d_symint(
    const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_holder = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_holder;

  TORCH_CHECK(
    !bias.defined() || bias.dtype() == input_.dtype(),
    "Input type (",
    input_.dtype().name(),
    ") and bias type (",
    bias.dtype().name(),
    ") should be the same");

  auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
  Tensor output = at::isComplexType(input_.scalar_type())
      ? complex_convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups)
      : at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {0}, groups);

  if (is_batched) {
    return output;
  }
  return output.squeeze(0);
}

// conv2d entry point with integer (SymInt) padding.
// Checks the bias dtype against the input, batchifies the input, dispatches
// to the complex or the regular convolution, and squeezes dimension 0 of the
// result again when the input was not batched.
at::Tensor conv2d_symint(
    const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_holder = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_holder;

  TORCH_CHECK(
    !bias.defined() || bias.dtype() == input_.dtype(),
    "Input type (",
    input_.dtype().name(),
    ") and bias type (",
    bias.dtype().name(),
    ") should be the same");

  auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
  Tensor output = at::isComplexType(input_.scalar_type())
      ? complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups)
      : at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);

  if (is_batched) {
    return output;
  }
  return output.squeeze(0);
}

// conv3d entry point with integer (SymInt) padding.
// Checks the bias dtype against the input, batchifies the input, dispatches
// to the complex or the regular convolution, and squeezes dimension 0 of the
// result again when the input was not batched.
at::Tensor conv3d_symint(
    const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_holder = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_holder;

  TORCH_CHECK(
    !bias.defined() || bias.dtype() == input_.dtype(),
    "Input type (",
    input_.dtype().name(),
    ") and bias type (",
    bias.dtype().name(),
    ") should be the same");

  auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
  Tensor output = at::isComplexType(input_.scalar_type())
      ? complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups)
      : at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);

  if (is_batched) {
    return output;
  }
  return output.squeeze(0);
}


// Implements padding='same' for non-transposed convolution.
//
// Validates the arguments (weight rank > 2, positive groups, matching
// input/weight rank, broadcastable stride/dilation, all strides equal to 1),
// computes per-dimension (left, right) padding with
// pooling_same_mode_padding_lr, and then:
//   * if every dimension is padded symmetrically, calls at::convolution_symint
//     directly with that padding;
//   * otherwise zero-pads the input by the left/right difference with
//     at::constant_pad_nd_symint and convolves with the remaining symmetric
//     padding.
static Tensor convolution_same(
    const Tensor &input, const Tensor &weight, const Tensor &bias,
    SymIntArrayRef stride, SymIntArrayRef dilation, const c10::SymInt& groups) {

  auto k = weight.dim();
  TORCH_CHECK(k > 2, "weight should have at least three dimensions");
  TORCH_CHECK(groups > 0, "non-positive groups is not supported");
  auto dim = static_cast<size_t>(k - 2);
  auto weight_sizes = weight.sym_sizes();
  auto input_sizes = input.sym_sizes();
  // Report the input shape through sym_sizes(), like the weight shape, so the
  // message does not go through the non-symbolic sizes() accessor.
  TORCH_CHECK(k == input.dim(),
              "Expected ", k, "-dimensional input for ",
              k, "-dimensional weight", weight_sizes, ", but got ",
              input.dim(), "-dimensional input of size ",
              input_sizes, " instead");
  TORCH_CHECK(stride.size() == dim || stride.size() == 1U,
              "stride cannot broadcast to ", dim, " dimensions");
  TORCH_CHECK(dilation.size() == dim || dilation.size() == 1U,
              "dilation cannot broadcast to ", dim, " dimensions");
  for (auto i: c10::irange(stride.size())) {
    TORCH_CHECK(stride[i] == 1, "padding='same' is not supported for strided convolutions");
  }

  // Calculate the correct padding
  SymDimVector padding_l, padding_r;
  bool symmetric_padding = true;
  for (auto i: c10::irange(dim)) {
    // A single-element stride/dilation is broadcast to every spatial dim.
    auto s = stride.size() == 1 ? stride[0] : stride[i];
    auto d = dilation.size() == 1 ? dilation[0] : dilation[i];
    auto pad = pooling_same_mode_padding_lr(
        input_sizes[i + 2], weight_sizes[i + 2], s, d);
    padding_l.push_back(pad.first);
    padding_r.push_back(pad.second);
    if (pad.first != pad.second) {
      symmetric_padding = false;
    }
  }

  if (symmetric_padding) {
    // All backends handle symmetric padding natively
    SymDimVector output_padding(static_cast<size_t>(dim));
    return at::convolution_symint(input, weight, bias, stride, padding_l, dilation,
                               false, output_padding, groups);
  }

  TORCH_WARN_ONCE("Using padding='same' with even kernel lengths and odd dilation may"
                  " require a zero-padded copy of the input be created");
  SmallVector<c10::SymInt, kDimVectorStaticSize * 2> pad_nd(static_cast<size_t>(2 * dim));
  for (auto i: c10::irange(dim)) {
    // Apply padding by the difference, leaving only a symmetric padding
    auto delta_pad = padding_r[i] - padding_l[i];
    auto pad_idx = 2 * (dim - 1 - i);  // F.pad goes from last dim to first
    if (delta_pad > 0) {
      // Right side needs more: pad the right edge, keep padding_l symmetric.
      pad_nd[pad_idx + 1] = delta_pad;
    } else {
      // Left side needs at least as much: pad the left edge by the
      // (non-negative) surplus and keep padding_r as the symmetric part.
      // NOTE(review): pooling_same_mode_padding_lr presumably never returns
      // left > right, in which case this amount is always 0 -- confirm.
      pad_nd[pad_idx] = padding_l[i] - padding_r[i];
      padding_l[i] = padding_r[i];
    }
  }
  auto padded_input = at::constant_pad_nd_symint(input, pad_nd, 0);
  SymDimVector output_padding(static_cast<size_t>(dim));
  return at::convolution_symint(padded_input, weight, bias, stride, padding_l,
                                dilation, false, output_padding, groups);
}

// Convolution with a string padding mode.
//   "valid" -> plain convolution with zero padding.
//   "same"  -> delegated to convolution_same.
// Any other string raises an error.
Tensor _convolution_mode_symint(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, c10::string_view padding, SymIntArrayRef dilation,
    c10::SymInt groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_holder = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_holder;

  if (padding == "valid") {
    return at::convolution_symint(
        input, weight, bias, stride, {{0}}, dilation, false, {{0}}, groups);
  }
  if (padding == "same") {
    return at::native::convolution_same(
        input, weight, bias, stride, dilation, groups);
  }
  TORCH_CHECK(false, "Invalid padding string: '", padding, "'");
}

// conv1d entry point with a string padding mode. Batchifies the input,
// dispatches to the complex or the regular mode convolution, and squeezes
// dimension 0 of the result again when the input was not batched.
at::Tensor conv1d_padding_symint(
    const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
    c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
    c10::SymInt groups) {
  auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
  Tensor output = at::isComplexType(input_.scalar_type())
      ? complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups)
      : at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);

  if (is_batched) {
    return output;
  }
  return output.squeeze(0);
}

// conv2d entry point with a string padding mode. Batchifies the input,
// dispatches to the complex or the regular mode convolution, and squeezes
// dimension 0 of the result again when the input was not batched.
at::Tensor conv2d_padding_symint(
    const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
    c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
    c10::SymInt groups) {
  auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
  Tensor output = at::isComplexType(input_.scalar_type())
      ? complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups)
      : at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);

  if (is_batched) {
    return output;
  }
  return output.squeeze(0);
}

// conv3d entry point with a string padding mode. Batchifies the input,
// dispatches to the complex or the regular mode convolution, and squeezes
// dimension 0 of the result again when the input was not batched.
at::Tensor conv3d_padding_symint(
    const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
    c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
    c10::SymInt groups) {
  auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
  Tensor output = at::isComplexType(input_.scalar_type())
      ? complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups)
      : at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);

  if (is_batched) {
    return output;
  }
  return output.squeeze(0);
}

// Transposed 1-D convolution entry point. Batchifies the input, dispatches to
// the complex or the regular convolution with transposed=true, and squeezes
// dimension 0 of the result again when the input was not batched.
at::Tensor conv_transpose1d_symint(
    const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_holder = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_holder;

  auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv_transpose1d");
  Tensor output = at::isComplexType(input_.scalar_type())
      ? complex_convolution(
            input, weight, bias, stride, padding, dilation, true, output_padding, groups)
      : at::convolution_symint(
            input, weight, bias, stride, padding, dilation, true, output_padding, groups);

  if (is_batched) {
    return output;
  }
  return output.squeeze(0);
}

// Transposed 2-D convolution entry point. Batchifies the input, dispatches to
// the complex or the regular convolution with transposed=true, and squeezes
// dimension 0 of the result again when the input was not batched.
at::Tensor conv_transpose2d_symint(
    const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_holder = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_holder;

  auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d");
  Tensor output = at::isComplexType(input_.scalar_type())
      ? complex_convolution(
            input, weight, bias, stride, padding, dilation, true, output_padding, groups)
      : at::convolution_symint(
            input, weight, bias, stride, padding, dilation, true, output_padding, groups);

  if (is_batched) {
    return output;
  }
  return output.squeeze(0);
}

// Transposed 3-D convolution entry point. Batchifies the input, dispatches to
// the complex or the regular convolution with transposed=true, and squeezes
// dimension 0 of the result again when the input was not batched.
at::Tensor conv_transpose3d_symint(
    const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_holder = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_holder;

  auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d");
  Tensor output = at::isComplexType(input_.scalar_type())
      ? complex_convolution(
            input, weight, bias, stride, padding, dilation, true, output_padding, groups)
      : at::convolution_symint(
            input, weight, bias, stride, padding, dilation, true, output_padding, groups);

  if (is_batched) {
    return output;
  }
  return output.squeeze(0);
}

// Generic convolution entry point: reads the cuDNN-related flags from the
// global context and forwards everything to at::_convolution.
at::Tensor convolution(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_holder = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_holder;

  auto& global_ctx = at::globalContext();
  // See Note [Enabling Deterministic Operations]
  const bool deterministic =
      global_ctx.deterministicCuDNN() || global_ctx.deterministicAlgorithms();
  const bool benchmark = global_ctx.benchmarkCuDNN();
  const bool cudnn_enabled = global_ctx.userEnabledCuDNN();
  const bool allow_tf32 = global_ctx.allowTF32CuDNN();

  return at::_convolution(
      input, weight, bias, stride, padding, dilation,
      transposed, output_padding, groups,
      benchmark, deterministic, cudnn_enabled, allow_tf32);
}

// Placeholder implementation of convolution_overrideable: it always raises a
// not-implemented error. As the message says, out-of-tree backends are
// expected to override it via TORCH_LIBRARY_IMPL. None of the arguments are
// used here.
at::Tensor convolution_overrideable(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_overrideable not implemented. You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function ");
}

// Function to select the convolution backend based on the inputs and params.
// This overload is used within the convolution internals but not exposed to Python.
// NB: The forward pass provides a bias tensor while the backward pass provides
// a bool indicating whether the bias is defined. This is done to save memory by
// avoiding saving the full bias tensor for backward.
//
// Arguments:
//   input, weight   - convolution operands; callers in this file pass them
//                     after any 1d -> 2d expansion.
//   bias            - only consulted by the 3x3 depthwise Winograd check.
//   bias_sizes_opt  - bias shape if a bias exists; its has_value() is passed
//                     to the MIOpen checks and the whole optional to the
//                     XNNPACK check.
//   need_backward   - when true, the inference-only backends (XNNPACK,
//                     Winograd 3x3 depthwise) are skipped.
//   params          - stride/padding/dilation/groups etc. plus backend flags.
//
// The else-if chain below is evaluated top to bottom, so its order defines the
// backend priority. Branches marked "unsupported" fall through to the final
// AT_ERROR.
template <typename T>
ConvBackend _select_conv_backend(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const at::OptionalArrayRef<T> bias_sizes_opt,
    const bool need_backward,
    const ConvParams<T>& params) {

  // don't send empty inputs through backends
  // (size 0 in dim 0 or dim 1 is handled by the Empty backends; any other
  // zero-element input is rejected)
  if (at::symint::size<T>(input, 0) == 0 || at::symint::size<T>(input, 1) == 0) {
    return input.is_mkldnn() ? ConvBackend::MkldnnEmpty : ConvBackend::Empty;
  } else if (at::symint::numel<T>(input) == 0) {
    TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", at::symint::sizes<T>(input));
  }

  if (params.is_depthwise(input, weight)) {
    // Depthwise: cuDNN first, then MIOpen, then the dedicated 2d/3d kernels.
    if (params.use_cudnn_depthwise(input, weight)) {
      return ConvBackend::Cudnn;
    } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
      return ConvBackend::MiopenDepthwise;
    } else {
      if (input.ndimension() == 4) {
        return ConvBackend::CudaDepthwise2d;
      } else if (input.ndimension() == 5) {
        return ConvBackend::CudaDepthwise3d;
      } else {
        // unsupported
      }
    }
  } else if (params.use_cudnn(input, weight)) {
    if (params.transposed) {
      return ConvBackend::CudnnTranspose;
    } else {
      return ConvBackend::Cudnn;
    }
  } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
    if (params.transposed) {
      return ConvBackend::MiopenTranspose;
    } else {
      return ConvBackend::Miopen;
    }
  } else if (params.use_mkldnn(input, weight)) {
    if (params.transposed) {
      return ConvBackend::MkldnnTranspose;
    } else {
      return ConvBackend::Mkldnn;
    }
  } else if (!need_backward && params.use_xnnpack(input, weight, bias_sizes_opt)) {
    // Using prepacked conv is preferred, but XNNPACK is still the fastest
    // option for NHWC.
    return ConvBackend::Xnnpack2d;
  // 3x3 depthwise convolutions implementation is inference only
  } else if (!need_backward && params.use_cpu_depthwise3x3_winograd(input, weight, bias)) {
    return ConvBackend::Winograd3x3Depthwise;
  } else if (
      !params.transposed && (input.ndimension() == 5) &&
      (input.device().is_cpu()) &&
      !params.is_dilated()) {
    // fast path for grouped conv3d
    return ConvBackend::Slow3d;
  } else if (input.device().is_cpu() || input.is_cuda()) {
    // backends without support for groups
    if (params.transposed) {
      if (input.ndimension() == 4) {
        return ConvBackend::SlowTranspose2d;
      } else if (input.ndimension() == 5) {
        return ConvBackend::SlowTranspose3d;
      } else {
        // unsupported
      }
    } else {  /* Not transposed */
      if (input.ndimension() == 4) {
        if (params.is_dilated()) {
          return ConvBackend::SlowDilated2d;
        } else {  /* dim == 4, non-dilated */
          if (params.use_nnpack(input, weight)) {
            return ConvBackend::NnpackSpatial;
          } else {
            /* CPU implementation has specialized MM kernels
               for non-dilated case here */
            return ConvBackend::Slow2d;
          }
        }
      } else if (input.ndimension() == 5 && (input.is_cuda() || params.is_dilated())) {
        return ConvBackend::SlowDilated3d;
      } else if (input.ndimension() == 5) { /* dim == 5, CPU, non-dilated */
        /* CPU implementation has specialized MM kernels
           for non-dilated case here */
        return ConvBackend::Slow3d;
      } else {
        // unsupported
      }
    }
  } else if (params.use_mps(input, weight)) {
    if (params.transposed) {
      return ConvBackend::MpsTranspose;
    } else {
      return ConvBackend::Mps;
    }
  } else {
    // Only reach here when input is backend with out-of-source implementation.
    return ConvBackend::Overrideable;
  }

  // Error out if no suitable backend was found.
  AT_ERROR("unsupported ConvNd parameters");
}

// Selects a backend for convolution based on the inputs and params.
//
// Builds a ConvParams<c10::SymInt> from the arguments and the global context
// flags, validates the shapes, views 1d convolutions as 2d where needed, and
// delegates the actual choice to _select_conv_backend.
ConvBackend select_conv_backend(
    const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_opt,
    SymIntArrayRef stride_, SymIntArrayRef padding_, SymIntArrayRef dilation_,
    bool transposed_, SymIntArrayRef output_padding_, c10::SymInt groups_, const at::OptionalSymIntArrayRef bias_sizes_opt) {
  c10::MaybeOwned<Tensor> bias_holder = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_holder;

  auto& global_ctx = at::globalContext();
  const auto weight_rank = weight_r.ndimension();
  const int64_t spatial_dims = weight_rank - 2;

  ConvParams<c10::SymInt> params;
  params.stride = expand_param_if_needed(stride_, "stride", spatial_dims);
  params.padding = expand_param_if_needed(padding_, "padding", spatial_dims);
  params.dilation = expand_param_if_needed(dilation_, "dilation", spatial_dims);
  params.transposed = transposed_;
  params.output_padding = expand_param_if_needed(output_padding_, "output_padding", spatial_dims);
  params.groups = std::move(groups_);
  params.benchmark = global_ctx.benchmarkCuDNN();
  params.deterministic = global_ctx.deterministicCuDNN() || global_ctx.deterministicAlgorithms();
  params.cudnn_enabled = global_ctx.userEnabledCuDNN();
  params.allow_tf32 = global_ctx.allowTF32CuDNN();

  Tensor input = input_r;
  Tensor weight = weight_r;
  check_shape_forward(input, weight.sym_sizes(), bias, params);

  // Expand 1d -> 2d.
  // This is only done for backends that don't natively support 1d spatial input.
  const bool expand_1d_to_2d = (weight_rank == 3) && !input.is_mkldnn() && !input.is_xpu();
  if (expand_1d_to_2d) {
    // avoid accidentally going through NHWC for permuted 3d input.
    input = input.contiguous();
    params.view1d_as_2d();
    input = view4d(input);
    weight = view4d(weight);
  }

  // Prefer the sizes of the actual bias tensor; otherwise use the caller's hint.
  auto bias_sizes = bias.defined() ? std::optional<SymIntArrayRef>(bias.sym_sizes()) : bias_sizes_opt;
  const bool any_requires_grad =
      input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad());
  const bool need_backward = GradMode::is_enabled() && any_requires_grad;
  return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params);
}

// For BC reasons, have a copy that does not require bias_opt.
// Forwards to _select_conv_backend with an empty optional in the bias slot.
static ConvBackend select_conv_backend(
    const Tensor& input,
    const Tensor& weight,
    const at::OptionalIntArrayRef bias_sizes_opt,
    const bool need_backward,
    const ConvParams<int64_t>& params) {
  const std::optional<Tensor> no_bias;
  return _select_conv_backend(
      input, weight, no_bias, bias_sizes_opt, need_backward, params);
}

// Dispatches a convolution to one of the backends handled here (NNPACK
// spatial, slow 2d, slow dilated 2d/3d, slow transposed 2d/3d). Any other
// backend raises an error.
static at::Tensor _convolution_nogroup_backend(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const ConvBackend backend,
    const ConvParams<int64_t>& params) {
  // Spatial extent of the kernel: weight sizes from dimension 2 onwards.
  const auto spatial_kernel = weight.sizes().slice(2);

  if (backend == ConvBackend::NnpackSpatial) {
#if AT_NNPACK_ENABLED()
    return at::_nnpack_spatial_convolution(input, weight, bias, params.padding, params.stride);
#else
    TORCH_INTERNAL_ASSERT(false, "NnpackSpatial backend was selected in PyTorch compiled without nnpack support");
#endif
  }
  if (backend == ConvBackend::Slow2d) {
    return at::thnn_conv2d(input, weight, spatial_kernel, bias, params.stride, params.padding);
  }
  if (backend == ConvBackend::SlowDilated2d) {
    return at::slow_conv_dilated2d(
        input, weight, spatial_kernel, bias, params.stride, params.padding, params.dilation);
  }
  if (backend == ConvBackend::SlowDilated3d) {
    return at::slow_conv_dilated3d(
        input, weight, spatial_kernel, bias, params.stride, params.padding, params.dilation);
  }
  if (backend == ConvBackend::SlowTranspose2d) {
    return at::slow_conv_transpose2d(
        input, weight, spatial_kernel, bias, params.stride, params.padding, params.output_padding, params.dilation);
  }
  if (backend == ConvBackend::SlowTranspose3d) {
    return at::slow_conv_transpose3d(
        input, weight, spatial_kernel, bias, params.stride, params.padding, params.output_padding, params.dilation);
  }
  TORCH_CHECK(false, "Unsupported conv nogroup backend encountered");
}

// Computes the output shape of a convolution: conv_input_size for transposed
// convolutions, conv_output_size otherwise. If the input has zero channels,
// the output channel dimension is forced to zero as well.
static inline std::vector<int64_t> calc_output_size(
    const Tensor& input,
    const Tensor& weight,
    const ConvParams<int64_t>& params) {
  std::vector<int64_t> result;
  if (params.transposed) {
    result = conv_input_size(
        input.sizes(), weight.sizes(), params.padding, params.output_padding,
        params.stride, params.dilation, params.groups);
  } else {
    result = conv_output_size(
        input.sizes(), weight.sizes(), params.padding, params.stride, params.dilation);
  }

  // Handle empty # of channels.
  const bool no_input_channels = (input.size(input_channels_dim) == 0);
  if (no_input_channels) {
    result[output_channels_dim] = 0;
  }
  return result;
}

// Picks the memory format the given backend should run in. The result is
// Contiguous unless the backend-specific predicate asks for a channels-last
// layout. On C10_MOBILE builds the whole selection is compiled out and the
// result is always Contiguous.
static inline at::MemoryFormat determine_backend_memory_format(
    const Tensor& input,
    const Tensor& weight,
    const ConvBackend backend) {
  at::MemoryFormat chosen_format = at::MemoryFormat::Contiguous;
#if !defined(C10_MOBILE)
  // See Note [Mobile check segfaults]
  const auto k = weight.ndimension();
  // Channels-last flavour matching the weight rank (5-D -> ChannelsLast3d).
  const at::MemoryFormat rank_channels_last =
      (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
  switch (backend) {
    case ConvBackend::Cudnn:
    case ConvBackend::CudnnTranspose: {
      if (detail::getCUDAHooks().compiledWithCuDNN()) {
        chosen_format = cudnn_conv_suggest_memory_format(input, weight);
      }
      break;
    }
    case ConvBackend::Miopen:
    case ConvBackend::MiopenDepthwise:
    case ConvBackend::MiopenTranspose: {
      if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
        TORCH_INTERNAL_ASSERT((k == 4 || k == 5),
            "Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()");
        // 5-D deliberately stays Contiguous here (ChannelsLast3d is commented out upstream).
        if (k == 5) {
          chosen_format = at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/;
        } else {
          chosen_format = at::MemoryFormat::ChannelsLast;
        }
      }
      break;
    }
    case ConvBackend::Mkldnn:
    case ConvBackend::MkldnnTranspose: {
      if (mkldnn_conv_use_channels_last(input, weight)) {
        chosen_format = rank_channels_last;
      }
      break;
    }
    case ConvBackend::Slow2d:
    case ConvBackend::SlowDilated2d:
    case ConvBackend::SlowTranspose2d: {
      if (thnn_conv_use_channels_last(input, weight)) {
        chosen_format = at::MemoryFormat::ChannelsLast;
      }
      break;
    }
    case ConvBackend::Overrideable: {
      if (xpu_conv_use_channels_last(input, weight)) {
        chosen_format = rank_channels_last;
      }
      break;
    }
    default:
      chosen_format = at::MemoryFormat::Contiguous;
  }
#endif
  return chosen_format;
}

// Externally linked wrapper around the file-local
// determine_backend_memory_format helper.
at::MemoryFormat _determine_backend_memory_format(
    const Tensor& input,
    const Tensor& weight,
    const ConvBackend backend)  {
  const at::MemoryFormat chosen_format =
      determine_backend_memory_format(input, weight, backend);
  return chosen_format;
}

at::Tensor _convolution(
    const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
  const Tensor& bias_r = *bias_r_maybe_owned;

  auto input = input_r;
  auto weight = weight_r;
  auto bias = bias_r;
  auto k = weight.ndimension();
  c10::IntArrayRef weight_sizes = weight.sizes();
  int64_t dim = k - 2;

  TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
  TORCH_CHECK(groups_ > 0, "non-positive groups is not supported");

  ConvParams<int64_t> params;
  params.stride = expand_param_if_needed(stride_, "stride", dim);
  params.padding = expand_param_if_needed(padding_, "padding", dim);
  params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
  params.transposed = transposed_;
  params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
  params.groups = groups_;
  params.benchmark = benchmark;
  params.deterministic = deterministic;
  params.cudnn_enabled = cudnn_enabled;
  params.allow_tf32 = allow_tf32;

  check_shape_forward(input, weight_sizes, bias, params);

  // Expand 1d -> 2d.
  // This is only done for backends that don't natively support 1d spatial input.
  if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
    // avoid accidentally going through NHWC for permuted 3d input.
    input = input.contiguous();
    params.view1d_as_2d();
    input = view4d(input);
    weight = view4d(weight);
  }

  // Select appropriate backend to use.
  auto bias_sizes_opt = bias.defined() ? std::optional<IntArrayRef>(bias.sizes()) : std::nullopt;
  bool need_backward = GradMode::is_enabled() &&
      (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
  ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params);
  at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);

  // Call the backend.
  Tensor output;
  auto kernel_size = weight.sizes().slice(2);
  switch (backend) {
    case ConvBackend::CudaDepthwise2d:
      output = at::_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias,
          params.stride, params.padding, params.dilation);
      break;
    case ConvBackend::CudaDepthwise3d:
      output = at::conv_depthwise3d(input.contiguous(), weight, kernel_size, bias,
          params.stride, params.padding, params.dilation);
      break;
    case ConvBackend::Cudnn:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::cudnn_convolution(
          input.contiguous(backend_memory_format), weight, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
      if (bias.defined()) {
        output.add_(reshape_bias(input.dim(), bias));
      }
      break;
    case ConvBackend::CudnnTranspose:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::cudnn_convolution_transpose(
          input.contiguous(backend_memory_format), weight, params.padding, params.output_padding,
          params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
      if (bias.defined()) {
        output.add_(reshape_bias(input.dim(), bias));
      }
      break;
    case ConvBackend::Empty:
    {
      Tensor weight_view;
      // Use permute and clone to avoid at::_unsafe_view(weight, -1) failure for non-contiguous cases where
      // view size is not compatible with input tensor's size and stride.
      if(weight.is_contiguous()) {
        weight_view = at::_unsafe_view(weight, -1);
      } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast)) {
        weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 1}), -1);
      } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
        weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 4, 1}), -1);
      } else {
        weight_view = at::_unsafe_view(weight.clone(at::MemoryFormat::Contiguous), -1);
      }

      output = (input.size(1) == 0) ? (input.view(-1) * weight_view) : (input * weight_view[0]);
      if (bias.defined()) {
        output.add_(bias[0]);
      }
      output = output.view(calc_output_size(input, weight, params));
      break;
    }
    case ConvBackend::Miopen:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::miopen_convolution(
          input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic);
      break;
    case ConvBackend::MiopenDepthwise:
      output = at::miopen_depthwise_convolution(
          input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic);
      break;
    case ConvBackend::MiopenTranspose:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::miopen_convolution_transpose(
          input.contiguous(backend_memory_format), weight, bias, params.padding, params.output_padding,
          params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
      break;
    case ConvBackend::Mkldnn:
#if AT_MKLDNN_ENABLED()
      check_input_same_type_as_parameters(input, weight, bias, backend);
      if (!input.is_mkldnn()) {
        // need to ensure contiguous for non-mkldnn tensors
        input = input.contiguous(backend_memory_format);
        weight = weight.contiguous(backend_memory_format);
        bias = bias.defined() ? bias.contiguous() : bias;
      }
      output = at::mkldnn_convolution(
          input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
#else
      TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
      break;
    case ConvBackend::MkldnnTranspose:
#if AT_MKLDNN_ENABLED()
      check_input_same_type_as_parameters(input, weight, bias, backend);
      if (!input.is_mkldnn()) {
        // need to ensure contiguous for non-mkldnn tensors
        input = input.contiguous(backend_memory_format);
        weight = weight.contiguous(backend_memory_format);
        bias = bias.defined() ? bias.contiguous() : bias;
      }
      output = mkldnn_convolution_transpose_stub(input.device().type(),
          input, weight, bias, params.padding, params.output_padding, params.stride, params.dilation, params.groups);
#else
      TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
      break;
    case ConvBackend::MkldnnEmpty:
#if AT_MKLDNN_ENABLED()
      output = empty_mkldnn(
          calc_output_size(input, weight, params), optTypeMetaToScalarType(input.options().dtype_opt()),
          input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt());
#else
      TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
      break;
    case ConvBackend::Overrideable:
      output = at::convolution_overrideable(
          input, weight, bias, params.stride, params.padding, params.dilation, params.transposed,
          params.output_padding, params.groups);
      break;
    case ConvBackend::Slow3d:
      output = at::slow_conv3d(input, weight, kernel_size, bias, params.stride, params.padding);
      break;
    case ConvBackend::Winograd3x3Depthwise:
      output = convolution_depthwise3x3_winograd_stub(
          input.device().type(), input, weight, bias, params.stride, params.padding, params.groups);
      break;
    case ConvBackend::Xnnpack2d:
      output = xnnpack::convolution2d(
          input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
      break;
    // Handle backends that don't natively support groups > 1.
    case ConvBackend::NnpackSpatial:
    case ConvBackend::Slow2d:
    case ConvBackend::SlowDilated2d:
    case ConvBackend::SlowDilated3d:
    case ConvBackend::SlowTranspose2d:
    case ConvBackend::SlowTranspose3d:
      input = input.contiguous(backend_memory_format);
      weight = weight.contiguous(backend_memory_format);
      if (params.groups == 1) {
        output = _convolution_nogroup_backend(input, weight, bias, backend, params);
      } else {
        std::vector<Tensor> outputs(params.groups);
        for (const auto g : c10::irange(params.groups)) {
          auto input_g = subtensor(input, 1, params.groups, g);
          auto weight_g = subtensor(weight, 0, params.groups, g);
          auto bias_g = subtensor(bias, 0, params.groups, g);
          outputs[g] = _convolution_nogroup_backend(input_g, weight_g, bias_g, backend, params);
        }
        output = at::cat(outputs, 1);
      }
      break;
    case ConvBackend::Mps:
#ifdef USE_MPS
      TORCH_CHECK(input.options().type_equal(weight.options()),
               "Input type (", input.toString(), ") and weight type (", weight.toString(),
               ") should be the same");
      TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
               "Input type (", input.toString(), ") and bias type (", bias.toString(),
               ") should be the same");

      output = at::_mps_convolution(input, weight, bias.defined() ? bias.contiguous() : bias,
                                     params.padding, params.stride, params.dilation,
                                     params.groups);
#else
      TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
#endif
      break;
    case ConvBackend::MpsTranspose:
#ifdef USE_MPS
      TORCH_CHECK(input.options().type_equal(weight.options()),
               "Input type (", input.toString(), ") and weight type (", weight.toString(),
               ") should be the same");
      TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
               "Input type (", input.toString(), ") and bias type (", bias.toString(),
               ") should be the same");
      output = at::_mps_convolution_transpose(
          input.contiguous(backend_memory_format), weight,
          params.padding, params.output_padding,
          params.stride, params.dilation, params.groups);
      if (bias.defined()) {
        output.add_(reshape_bias(input.dim(), bias));
      }
#else
      TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
#endif
      break;
  }

  if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
    output = view3d(output);
  }

  return output;
}

// Compatibility overload of _convolution that takes no explicit allow_tf32 argument.
// It unwraps the optional bias, reads the TF32 setting from the global context and
// forwards every other argument unchanged to the at::_convolution overload that
// accepts the extra flag.
at::Tensor _convolution(
    const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    bool benchmark, bool deterministic, bool cudnn_enabled)
{
  // See [Note: hacky wrapper removal for optional tensor]
  const c10::MaybeOwned<Tensor> maybe_bias = at::borrow_from_optional_tensor(bias_r_opt);

  // The TF32 flag is not part of this signature, so take it from the global context.
  const bool allow_tf32 = at::globalContext().allowTF32CuDNN();

  return at::_convolution(
      input_r, weight_r, *maybe_bias,
      stride_, padding_, dilation_,
      transposed_, output_padding_, groups_,
      benchmark, deterministic, cudnn_enabled, allow_tf32);
}

std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
        const Tensor& grad_output, const Tensor& input, const Tensor& weight,
        IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
        bool transposed, IntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) {
   TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_backward_overrideable: You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function ");
  return std::tuple<Tensor, Tensor, Tensor>(
          at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT),
          at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT),
          at::empty({}));
}

// Returns the slice of `var` belonging to group `g` along dimension `dim`.
// The slice length is var.sizes()[dim] / groups (integer division) and the slice
// starts at offset g * length; it is obtained with Tensor::narrow.
static Tensor subvariable(const Tensor& var, int64_t dim, int64_t groups, int64_t g) {
  const int64_t per_group = var.sizes()[dim] / groups;
  return var.narrow(dim, per_group * g, per_group);
}

// Double backward of convolution.
//
// Inputs:
//   ggI_opt / ggW_r_opt / ggb_opt: optional "gradient of gradient" tensors for the input,
//       the weight and the bias; any of them may be absent.
//   gO_r: gO in the formulas below (presumably the gradient w.r.t. the forward output).
//   weight_r, input: the weight and input of the forward convolution.
//   stride_, padding_, dilation_, output_padding_: expanded to one entry per spatial
//       dimension (weight.ndimension() - 2) via expand_param_if_needed.
//   transposed_, groups_: forward convolution settings. groups_ is ignored for
//       non-transposed inputs with more than 4 dims, where groups is inferred instead.
//   output_mask: accepted but not read anywhere in this function body.
//
// Returns the tuple (ggO, gI, gW), where
//   ggO = conv(ggI, weight) + conv(input, ggW) + ggb
//   gW  = conv(ggI, gO), computed group by group with stride and dilation swapped
//   gI  = convT(gO, ggW) if !transposed, conv(gO, ggW) if transposed
// A term is skipped when the tensor it depends on is undefined, and every convolution
// is skipped when input.numel() == 0, so any of the returned tensors may be undefined.
std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const std::optional<Tensor>& ggI_opt, const std::optional<Tensor>& ggW_r_opt, const std::optional<Tensor>& ggb_opt,
    const Tensor& gO_r, const Tensor& weight_r, const Tensor& input,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    std::array<bool, 3> output_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> ggI_maybe_owned = at::borrow_from_optional_tensor(ggI_opt);
  const Tensor& ggI = *ggI_maybe_owned;
  // Fall back to a default-constructed Tensor when the optional is empty; the code
  // below tests these with defined().
  const Tensor& ggW_r = c10::value_or_else(ggW_r_opt, [] {return Tensor();});
  const Tensor& ggb = c10::value_or_else(ggb_opt, [] {return Tensor();});


  // Local handles, because they may be reassigned to contiguous versions below.
  auto ggW = ggW_r;
  auto gO = gO_r;
  auto weight = weight_r;

  // Number of spatial dimensions.
  int64_t dim = weight.ndimension() - 2;
  ConvParams<int64_t> params;
  params.stride = expand_param_if_needed(stride_, "stride", dim);
  params.padding = expand_param_if_needed(padding_, "padding", dim);
  params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
  params.transposed = transposed_;
  params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
  // TODO: hacky way of inferring the groups number for grouped Conv3D
  // See: https://github.com/pytorch/pytorch/pull/36355
  if (!params.transposed && input.dim() > 4) {
    // Avoid undefined behavior when num channels == 0; params are unused for that case.
    params.groups = (weight.size(1) > 0) ? input.size(1) / weight.size(1) : -1;
  } else {
    params.groups = groups_;
  }

  // Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb
  Tensor ggO;
  if (input.numel() != 0) {
    if (ggI.defined()) {
      // CUDA weights are made contiguous before the convolution call.
      if (weight.is_cuda()) {
        weight = weight.contiguous();
      }
      ggO = at::convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
    }

    if (ggW.defined()) {
      if (ggW.is_cuda()) {
        ggW = ggW.contiguous();
      }
      auto ggW_term = at::convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
      // Accumulate into ggO, or start it if the ggI term was skipped.
      if (ggO.defined()) {
        ggO = ggO + ggW_term;
      } else {
        ggO = ggW_term;
      }
    }
  }

  // The bias term is added even when input.numel() == 0.
  if (ggb.defined()) {
    // View as (1, ggb.size(0), 1, 1...)

    // Build the target shape: gO.ndimension() ones, with dim 1 set to ggb's length
    std::vector<int64_t> new_size(gO.ndimension(), 1);
    new_size[1] = ggb.sizes()[0];
    auto ggb_contiguous = ggb.contiguous();
    auto ggb_view = ggb_contiguous.view(new_size);

    // Expand
    auto ggb_expanded = ggb_view.expand(gO.sizes());

    if (ggO.defined()) {
      ggO = ggO + ggb_expanded;
    } else {
      ggO = ggb_expanded;
    }
  }

  // Compute gW = conv(ggI, gO)
  Tensor gW;
  if (ggI.defined()) {

    // Modified params with correct padding
    ConvParams<int64_t> gw_conv_params(params);

    // Disable groups as they are handled separately
    auto groups = gw_conv_params.groups;
    gw_conv_params.groups = 1;
    // The weight-gradient convolution uses the forward stride as its dilation and
    // the forward dilation as its stride.
    std::swap(gw_conv_params.dilation, gw_conv_params.stride);

    // Transpose gO and ggI to accumulate over batch
    auto gOt = gO.transpose(0, 1);
    auto ggIt = ggI.transpose(0, 1);

    Tensor gWt;
    // Compute conv
    if (input.numel() != 0) {
      if (groups == 1) {

        if (gOt.is_cuda()) {
          gOt = gOt.contiguous();
        }
        // Compute conv
        // For a transposed forward conv the operand roles are swapped and a regular
        // (non-transposed) convolution is used.
        if (params.transposed) {
          gw_conv_params.transposed = false;
          gWt = at::convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
        } else {
          gWt = at::convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
        }
      } else {
        // One convolution per group on the matching dim-0 slices of the transposed
        // tensors; the per-group results are concatenated along dim 1 afterwards.
        std::vector<Tensor> gWt_list(groups);
        for (const auto g : c10::irange(groups)) {
          auto ggIt_g = subvariable(ggIt, 0, groups, g);
          auto gOt_g = subvariable(gOt, 0, groups, g);
          if (gOt_g.is_cuda()) {
            gOt_g = gOt_g.contiguous();
          }

          // Compute conv
          if (params.transposed) {
            gw_conv_params.transposed = false;
            gWt_list[g] = at::convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
          } else {
            gWt_list[g] = at::convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
          }
        }

        gWt = at::cat(gWt_list, 1);
      }

      // Transpose gW to match chan_in and chan_out
      gW = gWt.transpose(0, 1);

      // narrow gW to only relevant portion
      // we do it this way instead of narrowing the input itself because
      // the ConvForward kernels don't support asymmetric padding.
      auto gW_size = gW.sizes();
      auto w_size = weight.sizes();
      for (const auto i : c10::irange(2, static_cast<int64_t>(gW_size.size()))) {
        if (gW_size[i] > w_size[i]) {
            gW = gW.narrow(i, 0, w_size[i]);
            // Refresh the cached sizes after gW has been reassigned.
            gW_size = gW.sizes();
        }
      }
    }
  }

  // Compute gI = convT(gO, ggW) if !transposed
  //         gI = conv(gO, ggw)  if transposed
  Tensor gI;
  if (input.numel() != 0) {
    if (ggW.defined()) {
      ConvParams<int64_t> gi_conv_params(params);
      // gI uses the opposite kind of convolution to the forward pass.
      gi_conv_params.transposed = !params.transposed;

      if (params.transposed) {
        if (gO.is_cuda()) {
          gO = gO.contiguous();
        }
        gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);

        // narrow gI to only relevant portion
        // we do it this way because negative output_padding is not supported
        // TODO: figure out if we can narrow gO and save some compute,
        // rather than narrowing the computed gI
        auto gI_size = gI.sizes();
        auto i_size = input.sizes();
        for (const auto i : c10::irange(2, static_cast<int64_t>(gI_size.size()))) {
          if (gI_size[i] > i_size[i]) {
            gI = gI.narrow(i, 0, i_size[i]);
            gI_size = gI.sizes();
          }
        }
      } else {
        // calculate output_padding
        // TODO: figure out why this needs to be computed...
        auto kernel_size = weight.sizes().slice(2);
        auto input_shape = input.sizes().slice(2);
        auto grad_output_shape = gO.sizes().slice(2);

        for (const auto i : c10::irange(kernel_size.size())) {
          // Check if whole input has been used or not
          // expected_input_shape is the spatial size the transposed convolution would
          // produce with zero output_padding; any difference from the real input size
          // is written into output_padding.
          auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.dilation[i]
            - 2 * gi_conv_params.padding[i]
            + (gi_conv_params.stride[i] * (grad_output_shape[i] - 1) + 1);
          if (expected_input_shape != input_shape[i]) {
            gi_conv_params.output_padding[i] = input_shape[i] - expected_input_shape;
          }
        }

        if (gO.is_cuda()) {
          gO = gO.contiguous();
        }

        gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);
      }
    }
  }

  return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
}

// Runs the convolution backward pass on one of the backends handled by this switch
// and returns the tuple produced by the selected kernel.
//
// The kernel size handed to each kernel is the weight's sizes from dimension 2 onwards.
// NnpackSpatial shares the dilated 2d slow kernel with SlowDilated2d. Any backend not
// listed here fails with "Unsupported conv nogroup backend encountered".
static std::tuple<at::Tensor, at::Tensor, at::Tensor> _convolution_backward_nogroup_backend(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    const std::array<bool, 3> output_mask,
    const ConvBackend backend,
    const ConvParams<int64_t>& params) {
  const auto kernel_dims = weight.sizes().slice(2);

  switch (backend) {
    case ConvBackend::Slow2d: {
      return at::_slow_conv2d_backward(
          grad_output, input, weight, kernel_dims,
          params.stride, params.padding, output_mask);
    }

    // NB: nnpack backward does not support strided convolutions; use slow impl instead
    case ConvBackend::NnpackSpatial:
    case ConvBackend::SlowDilated2d: {
      return slow_conv_dilated2d_backward_stub(
          input.device().type(),
          grad_output, input, weight, kernel_dims,
          params.stride, params.padding, params.dilation, output_mask);
    }

    case ConvBackend::SlowDilated3d: {
      return slow_conv_dilated3d_backward_stub(
          input.device().type(),
          grad_output, input, weight, kernel_dims,
          params.stride, params.padding, params.dilation, output_mask);
    }

    case ConvBackend::SlowTranspose2d: {
      return slow_conv_transpose2d_backward_stub(
          input.device().type(),
          grad_output, input, weight, kernel_dims,
          params.stride, params.padding, params.output_padding, params.dilation, output_mask);
    }

    case ConvBackend::SlowTranspose3d: {
      return slow_conv_transpose3d_backward_stub(
          input.device().type(),
          grad_output, input, weight, kernel_dims,
          params.stride, params.padding, params.output_padding, params.dilation, output_mask);
    }

    default: {
      TORCH_CHECK(false, "Unsupported conv nogroup backend encountered");
    }
  }
}

// Backward pass for convolution. Computes gradients for input, weight, and bias depending on the
// output_mask setting. This function supports 1D, 2D, or 3D spatial convolution and currently requires
// a single batch dimension to be present.
//
// Args:
//   grad_output_: tensor of shape (N, C_out, L_out), (N, C_out, H_out, W_out), or (N, C_out, D_out, H_out, W_out)
//   input_: tensor of shape (N, C_in, L_in), (N, C_in, H_in, W_in), or (N, C_in, D_in, H_in, W_in)
//   weight_: tensor of shape (C_out, C_in // groups, *kernel_size); dimension of kernel_size must match the number
//       of input spatial dimensions
//   bias_sizes_opt: if specified, indicates that a bias was used in the forward pass and contains the shape
//       of the bias. While the bias shape can be computed from other inputs, it is provided to this function for
//       ease of use. The bias shape is (weight.shape[0]) for normal convolution and (weight.shape[1] * groups)
//       for transposed convolution.
//   stride: single value or an array with dimension matching the number of input spatial dimensions
//   padding: single value or an array with dimension matching the number of input spatial dimensions
//   dilation: single value or an array with dimension matching the number of input spatial dimensions
//   transposed: boolean indicating whether the convolution is transposed
//   output_padding: single value or dimension == number of input spatial dimensions; only supported when
//       transposed is true
//   groups: number of groups for grouped convolution
//   output_mask: 3-dim boolean array specifying which gradients to compute in input, weight, bias order
//
// Returns:
//   tuple (grad_input, grad_weight, grad_bias). In the Empty and MkldnnEmpty paths an entry is only
//   assigned when its output_mask bit is set; for the other backends the entries are whatever the
//   selected backend call returns. If output_mask[2] is set and the backend left grad_bias undefined,
//   grad_bias is computed here by summing grad_output over every dimension except dim 1.
//   The Empty and MkldnnEmpty paths dereference bias_sizes_opt when output_mask[2] is set, so it must
//   hold a value in that case.
std::tuple<Tensor, Tensor, Tensor> convolution_backward(
    const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
    const at::OptionalIntArrayRef bias_sizes_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
    int64_t groups, std::array<bool, 3> output_mask) {
  // Local handles, because they may be reassigned (4d views, contiguous copies) below.
  auto grad_output = grad_output_;
  auto input = input_;
  auto weight = weight_;

  // k is the weight rank; dim is the number of spatial dimensions.
  auto k = weight.ndimension();
  int64_t dim = k - 2;

  TORCH_CHECK(dim > 0, "weight should have at least three dimensions");

  // Per-dimension parameters come from the arguments; the benchmark / deterministic /
  // cudnn / TF32 settings are read from the global context.
  auto& ctx = at::globalContext();
  ConvParams<int64_t> params;
  params.stride = expand_param_if_needed(stride, "stride", dim);
  params.padding = expand_param_if_needed(padding, "padding", dim);
  params.dilation = expand_param_if_needed(dilation, "dilation", dim);
  params.transposed = transposed;
  params.output_padding = expand_param_if_needed(output_padding, "output_padding", dim);
  params.groups = groups;
  params.benchmark = ctx.benchmarkCuDNN();
  params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
  params.cudnn_enabled = ctx.userEnabledCuDNN();
  params.allow_tf32 = ctx.allowTF32CuDNN();

  // Validate inputs.
  check_shape_backward(input, weight.sizes(), params);
  TORCH_CHECK(input.dim() == grad_output.dim(),
      "Expected input and grad_output to have the same number of dimensions, but got: ",
      input.dim(), " and ", grad_output.dim());

  // output_padding is only supported for transposed convolutions
  if (!params.transposed) {
    for (auto pad : params.output_padding) {
      TORCH_CHECK(pad == 0, "output_padding is not supported for non-transposed convolutions; got: ",
        params.output_padding);
    }
  }

  // Expand 1d -> 2d.
  // This is only done for backends that don't natively support 1d spatial input.
  if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
    // avoid accidentally going through NHWC for permuted 3d input.
    input = input.contiguous();
    params.view1d_as_2d();
    grad_output = view4d(grad_output);
    input = view4d(input);
    weight = view4d(weight);
  }

  // Select appropriate backend to use.
  ConvBackend backend = select_conv_backend(input, weight, bias_sizes_opt, /*need_backward=*/ true, params);
  at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);

  // Call the backend.
  Tensor backend_grad_input, backend_grad_weight, backend_grad_bias;
  // Taken after the optional 1d -> 2d expansion, so it matches the (possibly viewed) weight.
  auto kernel_size = weight.sizes().slice(2);
  switch(backend) {
    case ConvBackend::CudaDepthwise2d:
    {
      // This kernel returns only (grad_input, grad_weight); grad_bias is handled after the switch.
      std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
      std::tie(backend_grad_input, backend_grad_weight) =
        conv_depthwise2d_backward_stub(input.device().type(), grad_output, input,
          weight, kernel_size, params.stride, params.padding, params.dilation, input_weight_output_mask);
      break;
    }
    case ConvBackend::CudaDepthwise3d:
      TORCH_CHECK(input.ndimension() == 5);
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        conv_depthwise3d_backward_stub(
          input.device().type(), grad_output, input, weight, kernel_size, params.stride,
          params.padding, params.dilation, output_mask);
      break;
    case ConvBackend::Cudnn:
    {
      check_input_same_type_as_parameters(input, weight);
      // Two-entry mask: this stub returns only (grad_input, grad_weight).
      std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
      std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_backward_stub(
          input.device().type(),
          // Only make input contiguous when it is necessary for the backwards computation
          output_mask[1] ? input.contiguous(backend_memory_format) : input,
          grad_output, weight, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32,
          input_weight_output_mask);
      break;
    }
    case ConvBackend::Mps:
    {
#ifdef USE_MPS
      check_input_same_type_as_parameters(input, weight);
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        at::mps_convolution_backward(input, grad_output, weight, params.padding,
          params.stride, params.dilation, params.groups, output_mask);
#else
      TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
#endif
      break;
    }
    case ConvBackend::MpsTranspose:
    {
#ifdef USE_MPS
      check_input_same_type_as_parameters(input, weight);
      std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
      std::tie(backend_grad_input, backend_grad_weight) = at::mps_convolution_transpose_backward(
        // Only make input contiguous when it is necessary for the backwards computation
        output_mask[1] ? input.contiguous(backend_memory_format) : input,
        grad_output, weight, params.padding, params.output_padding,
        params.stride, params.dilation, params.groups, input_weight_output_mask);
#else
      TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
#endif
      break;
    }
    case ConvBackend::CudnnTranspose:
    {
      check_input_same_type_as_parameters(input, weight);
      std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
      std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_transpose_backward_stub(
        input.device().type(),
        // Only make input contiguous when it is necessary for the backwards computation
        output_mask[1] ? input.contiguous(backend_memory_format) : input,
        grad_output, weight, params.padding, params.output_padding,
        params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32,
        input_weight_output_mask);
      break;
    }
    case ConvBackend::Empty:
      // The requested gradients are simply zero tensors of the matching shapes.
      if (output_mask[0]) {
        backend_grad_input = at::zeros_like(input);
      }
      if (output_mask[1]) {
        backend_grad_weight = at::zeros_like(weight);
      }
      if (output_mask[2]) {
        // Dereferences bias_sizes_opt: it must hold a value when a bias gradient is requested.
        backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options());
      }
      break;
    case ConvBackend::MkldnnEmpty:
#if AT_MKLDNN_ENABLED()
      if (output_mask[0]) {
        if (input.is_mkldnn()) {
          // Allocate an mkldnn tensor with the input's sizes and options, then zero it in place.
          backend_grad_input = empty_mkldnn(input.sizes(), optTypeMetaToScalarType(input.options().dtype_opt()),
              input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt());
          backend_grad_input.zero_();
        } else {
          backend_grad_input = at::zeros_like(input);
        }
      }
      if (output_mask[1]) {
        // mkldnn weight is not supported during training by the mkldnn backend
        backend_grad_weight = at::zeros_like(weight);
      }
      if (output_mask[2]) {
        // mkldnn bias is not supported during training by the mkldnn backend
        backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options());
      }
#else
      TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
      break;
    case ConvBackend::Miopen:
      check_input_same_type_as_parameters(input, weight);
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        miopen_convolution_backward_stub(
          input.device().type(),
          input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
      break;
    case ConvBackend::MiopenDepthwise:
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
          miopen_depthwise_convolution_backward_stub(
            input.device().type(),
            input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride,
            params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
      break;
    case ConvBackend::MiopenTranspose:
      check_input_same_type_as_parameters(input, weight);
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        miopen_convolution_transpose_backward_stub(
          input.device().type(),
          input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.output_padding,
          params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
      break;
    case ConvBackend::Mkldnn:
      TORCH_CHECK(!weight.is_mkldnn(),
          "The MKLDNN backend does not support weight as an MKLDNN tensor during training");
      // Non-mkldnn tensors are made contiguous in the backend's memory format first.
      if (!input.is_mkldnn()) {
        input = input.contiguous(backend_memory_format);
        weight = weight.contiguous(backend_memory_format);
      }
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        mkldnn_convolution_backward_stub(input.device().type(), input, grad_output, weight, params.padding,
          params.stride, params.dilation, params.groups, output_mask);
      break;
    case ConvBackend::MkldnnTranspose:
      TORCH_CHECK(!weight.is_mkldnn(),
          "The MKLDNN backend does not support weight as an MKLDNN tensor during training");
      if (!input.is_mkldnn()) {
        input = input.contiguous(backend_memory_format);
        weight = weight.contiguous(backend_memory_format);
      }
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        mkldnn_convolution_transpose_backward_stub(input.device().type(), input, grad_output, weight, params.padding,
        params.output_padding, params.stride, params.dilation, params.groups, output_mask);
      break;
    case ConvBackend::Overrideable:
      // Only reach here when input is backend with out-of-source implementation.
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        at::convolution_backward_overrideable(grad_output, input, weight, params.stride, params.padding,
          params.dilation, params.transposed, params.output_padding, params.groups, output_mask);
      break;
    case ConvBackend::Slow3d:
      // Note that no CUDA implementation of this kernel exists currently.
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        slow_conv3d_backward_cpu(
            grad_output, input, weight, kernel_size,
            params.stride, params.padding, output_mask);
      break;
    // Handle backends that don't natively support groups > 1.
    case ConvBackend::NnpackSpatial:
    case ConvBackend::Slow2d:
    case ConvBackend::SlowDilated2d:
    case ConvBackend::SlowDilated3d:
    case ConvBackend::SlowTranspose2d:
    case ConvBackend::SlowTranspose3d:
    {
      input = input.contiguous(backend_memory_format);
      weight = weight.contiguous(backend_memory_format);
      if (params.groups == 1) {
        std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
          _convolution_backward_nogroup_backend(
            grad_output, input, weight, output_mask, backend, params);
      } else {
        // Run the single-group kernel once per group on the matching slices
        // (dim 1 of grad_output and input, dim 0 of weight), then concatenate the
        // per-group results for the gradients that were requested.
        std::vector<Tensor> backend_grad_inputs(params.groups);
        std::vector<Tensor> backend_grad_weights(params.groups);
        std::vector<Tensor> backend_grad_biases(params.groups);
        for (int g = 0; g < params.groups; ++g) {
          auto grad_output_g = subtensor(grad_output, 1, params.groups, g);
          auto input_g = subtensor(input, 1, params.groups, g);
          auto weight_g = subtensor(weight, 0, params.groups, g);
          std::tie(backend_grad_inputs[g], backend_grad_weights[g], backend_grad_biases[g]) =
            _convolution_backward_nogroup_backend(
              grad_output_g, input_g, weight_g, output_mask, backend, params);
        }
        if (output_mask[0]) {
          backend_grad_input = at::cat(backend_grad_inputs, 1);
        }
        if (output_mask[1]) {
          backend_grad_weight = at::cat(backend_grad_weights, 0);
        }
        if (output_mask[2]) {
          backend_grad_bias = at::cat(backend_grad_biases, 0);
        }
      }
      break;
    }
    // Backward is not supported for these backends.
    case ConvBackend::Winograd3x3Depthwise:
      TORCH_CHECK(false, "Backward is not supported for depthwise 3x3 winograd");
      break;
    case ConvBackend::Xnnpack2d:
      TORCH_CHECK(false, "Backward is not supported for xnnpack");
      break;
  }

  // Convert 2D inputs back to 1D for backends that don't natively support 1D
  // spatial inputs.
  if (output_mask[0]) {
    if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
      backend_grad_input = view3d(backend_grad_input);
    }
  }
  if (output_mask[1]) {
    if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
      backend_grad_weight = view3d(backend_grad_weight);
    }
  }
  if (output_mask[2]) {
    if (!backend_grad_bias.defined()) {
      // Calculate bias gradients outside of the backend for those that don't support it.
      // Sum over every dimension except dim 1: {0, 2, 3, 4} for three spatial dims,
      // otherwise {0, 2, 3} (1d inputs were viewed as 4d above on the non-mkldnn/xpu path).
      backend_grad_bias = grad_output.sum((dim == 3) ? IntArrayRef{0, 2, 3, 4} : IntArrayRef{0, 2, 3});
    }
  }

  return std::make_tuple(backend_grad_input, backend_grad_weight, backend_grad_bias);
}

// Stores `enable` into the flag conv_benchmark_empty_cache, which is declared outside
// this part of the file.
// NOTE(review): judging by the name, the flag presumably controls whether cuDNN
// convolution benchmarking empties a cache first; confirm against its definition and users.
void _cudnn_set_conv_benchmark_empty_cache(bool enable) {
  conv_benchmark_empty_cache = enable;
}

// Returns the current value of the flag conv_benchmark_empty_cache, i.e. the value most
// recently stored by _cudnn_set_conv_benchmark_empty_cache or the flag's initial value
// (declared outside this part of the file).
bool _cudnn_get_conv_benchmark_empty_cache() {
  return conv_benchmark_empty_cache;
}



} // namespace at::native
