#pragma once
#include <ATen/OpMathType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/native/cuda/Pow.cuh>

namespace at::native {

namespace {

// TODO(crcrpar): Handle version bump in codegen.
// rel:
// https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
inline void increment_version(TensorList tensors) {
  for (const auto& t : tensors) {
    t.unsafeGetTensorImpl()->bump_version();
  }
}

// Initializes args and checks if all args are aligned
template <int depth, typename T>
__device__ bool init_args(
    T** args,
    TensorListMetadata<depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

// Initializes args and checks if all args are aligned
template <int depth, typename T, typename T2>
__device__ bool init_args(
    T** args,
    TensorListScalarListMetadata<T2, depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

template <int depth, typename T>
__device__ bool init_args(
    T** args,
    FusedOptimizerTensorListMetadata<depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

template <int depth, typename T>
__device__ void load_args(
    T r_args[][kILP],
    T** args,
    const int64_t i_start,
    const int64_t chunk_size,
    const int64_t n) {
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    const auto i = i_start + threadIdx.x + ii * blockDim.x;
    for (int r_index = 0; r_index < depth; r_index++) {
      r_args[r_index][ii] = 0;
      if (i < n && i < chunk_size) {
        r_args[r_index][ii] = args[r_index][i];
      }
    }
  }
}

template <typename T>
__device__ void store_args(
    T* dst,
    T* src,
    const int64_t i_start,
    const int64_t chunk_size,
    const int64_t n) {
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
    if (i < n && i < chunk_size)
      dst[i] = src[ii];
  }
}

template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void binary_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put the aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless of whether depth is 1 (for inplace) or 2 (for out of
      // place), r_args has depth 1
      load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}

template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void pointwise_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put the aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
      load_store(r_args[1], args[1], 0, i_start);
      load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless of whether depth is 3 (for inplace) or 4 (for out of
      // place), r_args has depth 3
      load_args<3>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}
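// Note on the two code paths in the helpers above: when every pointer is
// kILP-aligned and both n and chunk_size are multiples of kILP, the functors
// use load_store (from MultiTensorApply.cuh), which moves kILP elements per
// thread in a single vectorized access. A minimal sketch of that idea,
// assuming an aligned_vector-style type (illustrative only, not this header's
// definition):
//
//   using vec_t = at::native::memory::aligned_vector<T, kILP>;
//   reinterpret_cast<vec_t*>(dst)[dst_offset] =
//       reinterpret_cast<vec_t*>(src)[src_offset];
//
// All other cases fall back to the element-wise, bounds-checked
// load_args/store_args defined above.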
//
// Binary Functors
//
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpListAlphaFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
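// Usage sketch (hypothetical caller; the real launch sites live in the
// ForeachBinaryOp*.cu translation units): an out-of-place foreach scalar op
// would pair BinaryOpScalarFunctor with multi_tensor_apply at depth 2
// (input list, output list), r_args_depth 1, writing results to args[1]:
//
//   using opmath_t = at::opmath_type<scalar_t>;
//   std::vector<std::vector<at::Tensor>> tensor_lists{tensors.vec(), vec_res};
//   multi_tensor_apply<2>(
//       tensor_lists,
//       BinaryOpScalarFunctor<
//           scalar_t,
//           /* depth */ 2,
//           /* r_args_depth */ 1,
//           /* res_arg_index */ 1>(),
//       std::plus<opmath_t>(),
//       scalar.to<opmath_t>());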
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarTensorFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      T* scalar,
      opmath_t alpha) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        // Regardless of whether depth is 1 (for inplace) or 2 (for out of
        // place), r_args has depth 1
        load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

//
// Unary Functors
//
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<1>& tl) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const auto all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        // store
        load_store(args[0], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct UnaryOpFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
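// Usage sketch (hypothetical, modeled on the in-place foreach zero kernel):
// ZeroFunctor runs at depth 1 and writes back into args[0]; in-place foreach
// callers then bump version counters via increment_version (see the TODO at
// the top of this file):
//
//   std::vector<std::vector<at::Tensor>> tensor_lists{tensors.vec()};
//   multi_tensor_apply<1>(
//       tensor_lists,
//       ZeroFunctor<
//           scalar_t,
//           /* depth */ 1,
//           /* r_args_depth */ 1,
//           /* res_arg_index */ 0>());
//   increment_version(tensors);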
//
// Pointwise Functors
//
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth>
struct PointwiseOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[depth - 1][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[2], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<depth - 1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[2], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    static_assert(depth == 3 || depth == 4, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
        load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
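// Usage sketch (hypothetical, in the style of ForeachPointwiseOp.cu): an
// in-place addcmul-like op runs PointwiseOpScalarFunctor at depth 3 (self,
// tensor1, tensor2) with results written back to args[0]; the per-element
// update is r_args[0] += scalar * op(r_args[1], r_args[2]):
//
//   using opmath_t = at::opmath_type<scalar_t>;
//   std::vector<std::vector<at::Tensor>> tensor_lists{
//       input.vec(), tensors1.vec(), tensors2.vec()};
//   multi_tensor_apply<3>(
//       tensor_lists,
//       PointwiseOpScalarFunctor<
//           scalar_t,
//           /* depth */ 3,
//           /* r_args_depth */ 3,
//           /* res_arg_index */ 0>(),
//       std::multiplies<opmath_t>(),
//       scalar.to<opmath_t>());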
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    static_assert(depth == 2 || depth == 3, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put the aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T>
struct power_functor {
  C10_DEVICE T operator()(const T& a, const T& b) const {
    return at::native::pow_(a, b);
  }
};

template <typename T>
struct reverse_power_functor {
  C10_DEVICE T operator()(const T& a, const T& b) const {
    return at::native::pow_(b, a);
  }
};

} // namespace

} // namespace at::native
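// Usage sketch (hypothetical, in the style of ForeachBinaryOpList.cu): the
// power functors above let a foreach pow kernel reuse the binary list
// machinery; e.g. an in-place tensor-list pow could be dispatched as:
//
//   using opmath_t = at::opmath_type<scalar_t>;
//   multi_tensor_apply<2>(
//       tensor_lists,
//       BinaryOpListAlphaFunctor<
//           scalar_t,
//           /* depth */ 2,
//           /* r_args_depth */ 2,
//           /* res_arg_index */ 0>(),
//       power_functor<opmath_t>(),
//       /* alpha */ opmath_t(1));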