#pragma once

#include <ATen/ceil_div.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/StridedRandomAccessor.h>

#define HAS_WARP_MERGE_SORT() (CUDA_VERSION >= 110600)

namespace at { namespace native {

template <typename T>
__device__ inline void swapVars(T& t1, T& t2) {
  T tmp = t1;
  t1 = t2;
  t2 = tmp;
}

template <typename Comparator, typename K, typename V>
__device__ inline void bitonicSwap(K& kA, V& vA, bool& validA,
                                   K& kB, V& vB, bool& validB,
                                   bool dir,
                                   const Comparator& comp) {
  // Invalid entries always sort to the end
  bool swap = (comp(kA, kB) && validA) || !validB;
  if (swap == dir) {
    swapVars(kA, kB);
    swapVars(vA, vB);
    swapVars(validA, validB);
  }
};

template <int Power2SortSize, typename IndexType, typename Comparator,
          typename K, typename V>
__device__ inline void bitonicSort(K *keys,
                                   V *values,
                                   bool *valid,
                                   const Comparator& comp) {
#if !defined(USE_ROCM)
#pragma unroll
#endif
  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
    bool flag = ((threadIdx.x & (size / 2)) != 0);

#if !defined(USE_ROCM)
#pragma unroll
#endif
    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
      __syncthreads();

      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      bitonicSwap<Comparator, K, V>(
          keys[pos], values[pos], valid[pos],
          keys[pos + stride], values[pos + stride], valid[pos + stride],
          flag, comp);
    }
  }

#if !defined(USE_ROCM)
#pragma unroll
#endif
  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
    __syncthreads();

    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    bitonicSwap<Comparator, K, V>(
        keys[pos], values[pos], valid[pos],
        keys[pos + stride], values[pos + stride], valid[pos + stride],
        false, comp);
  }

  __syncthreads();
}

// at::cuda::detail::TensorInfo version
// Sorts (key, value) pairs (in different tensors) in-place; i.e.,
// modifies the input `keys` and `values`
template <int KeyDims, int ValueDims, int block_dim_x, int max_block_dim_y,
          typename K, typename V, typename Comparator, typename IndexType>
C10_LAUNCH_BOUNDS_1(block_dim_x * max_block_dim_y)
__global__ void
bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
                     IndexType keySlices,
                     IndexType keySliceSize,
                     IndexType keySliceStride,
                     at::cuda::detail::TensorInfo<V, IndexType> values,
                     IndexType valueSliceStride,
                     Comparator comp) {
  // Find the slice of the tensor that we are sorting
  // NOTE: blockDim.y may be less than max_block_dim_y
  const IndexType blockIndex = getLinearBlockId<IndexType>();
  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;

  // If the entire block is out of bounds, exit early
  if (blockIndex * blockDim.y >= keySlices) {
    return;
  }
  // It's also possible for some rows of a block to be out of bounds,
  // but all threads need to run for __syncthreads to work.
  const bool row_valid = linearIndex < keySlices;

  constexpr int items_per_thread = 2;
  constexpr int Power2SortSize = block_dim_x * items_per_thread;

  // Storage for max_block_dim_y sorts performed in parallel
  __shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize];
  __shared__ V blockSharedValues[max_block_dim_y][Power2SortSize];
  __shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize];

  auto sharedKeys = blockSharedKeys[threadIdx.y];
  auto sharedValues = blockSharedValues[threadIdx.y];
  auto sharedValid = blockSharedValid[threadIdx.y];

  const IndexType keyStartOffset =
      at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
  const IndexType valueStartOffset =
      at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

  // Load 2 values per thread into the shared workspace
  #pragma unroll
  for (int k = 0; k < items_per_thread; ++k) {
    auto idx = threadIdx.x + k * blockDim.x;
    bool valid = row_valid && idx < keySliceSize;

    sharedKeys[idx] = valid ?
        keys.data[idx * keySliceStride + keyStartOffset] : K{};
    sharedValues[idx] = valid ?
        values.data[idx * valueSliceStride + valueStartOffset] : V{};
    sharedValid[idx] = valid;
  }

  // Sort!
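  // Each thread works on the pair of shared-memory elements at `pos` and
  // `pos + stride` in every pass of the bitonic network. Rows shorter than
  // Power2SortSize were padded above with default-constructed keys marked
  // invalid, and bitonicSwap always pushes invalid entries toward the end,
  // so the padding never displaces real data.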
  bitonicSort<Power2SortSize, IndexType>(
      sharedKeys, sharedValues, sharedValid, comp);

  if (!row_valid) {
    return;
  }

  // Store outputs
  #pragma unroll
  for (int k = 0; k < items_per_thread; ++k) {
    auto idx = threadIdx.x + k * blockDim.x;
    if (idx < keySliceSize) {
      keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx];
      values.data[idx * valueSliceStride + valueStartOffset] = sharedValues[idx];
    }
  }
}

#if HAS_WARP_MERGE_SORT()

template <int KeyDims, int ValueDims, int sort_size, int max_block_dim_y,
          typename K, typename V, typename Comparator, typename IndexType>
C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y)
__global__ void
warpMergeSortKVInPlace(
    at::cuda::detail::TensorInfo<K, IndexType> keys,
    IndexType keySlices,
    IndexType keySliceSize,
    IndexType keySliceStride,
    at::cuda::detail::TensorInfo<V, IndexType> values,
    IndexType valueSliceStride,
    Comparator comp,
    K invalid_key) {
  // Find the slice of the tensor that we are sorting
  // NOTE: blockDim.y may be less than max_block_dim_y
  const IndexType blockIndex = getLinearBlockId<IndexType>();
  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;

  // If this row is out of bounds, exit early
  if (linearIndex >= keySlices) {
    return;
  }

  const IndexType keyStartOffset =
      at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
  const IndexType valueStartOffset =
      at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

  K *keys_slice = &keys.data[keyStartOffset];
  V *values_slice = &values.data[valueStartOffset];

  StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
  StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);

  namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);

  CUDA_KERNEL_ASSERT(blockDim.x == C10_WARP_SIZE);
  CUDA_KERNEL_ASSERT(blockDim.y <= max_block_dim_y);

  constexpr int items_per_thread = sort_size / C10_WARP_SIZE;
  static_assert(
      items_per_thread * C10_WARP_SIZE == sort_size,
      "sort_size must be a multiple of C10_WARP_SIZE");

  using LoadKeys = cub::WarpLoad<K, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
  using LoadValues = cub::WarpLoad<V, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
  using Sort = cub::WarpMergeSort<K, items_per_thread, C10_WARP_SIZE, V>;
  using StoreKeys = cub::WarpStore<K, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
  using StoreValues = cub::WarpStore<V, items_per_thread, cub::WARP_STORE_TRANSPOSE>;

  __shared__ union {
    typename LoadKeys::TempStorage load_keys;
    typename LoadValues::TempStorage load_values;
    typename Sort::TempStorage sort;
    typename StoreKeys::TempStorage store_keys;
    typename StoreValues::TempStorage store_values;
  } tmp_storage[max_block_dim_y];

  auto& warp_storage = tmp_storage[threadIdx.y];

  // Load inputs
  K local_keys[items_per_thread];
  V local_values[items_per_thread];

  const auto invalid_value = V{};
  LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
  WARP_SYNC();
  LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
  WARP_SYNC();

  // Sort! We use stable sort to ensure that invalid values are never
  // sorted before valid values. In testing it performed the same as
  // .Sort, so there is no downside.
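  // Only the first keySliceSize of the sort_size register-resident items are
  // real data; the Load calls above filled the tail with invalid_key /
  // invalid_value, and keySliceSize is forwarded to StableSort so cub treats
  // that tail as out-of-bounds padding.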
  Sort(warp_storage.sort).StableSort(
      local_keys, local_values, comp, keySliceSize, invalid_key);
  WARP_SYNC();

  // Store outputs
  StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
  WARP_SYNC();
  StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize);
}

#endif // HAS_WARP_MERGE_SORT()

template <int KeyDims, int ValueDims, int block_size, int items_per_thread,
          typename K, typename V, typename IndexType>
C10_LAUNCH_BOUNDS_1(block_size)
__global__ void
radixSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
                   IndexType keySlices,
                   IndexType keySliceSize,
                   IndexType keySliceStride,
                   at::cuda::detail::TensorInfo<V, IndexType> values,
                   IndexType valueSliceStride,
                   bool descending) {
  static_assert(block_size > 0, "");

  // Find the slice of the tensor that we are sorting
  const IndexType linearIndex = getLinearBlockId<IndexType>();
  // Tiling the slices could have us be out of bounds, if there are a
  // lot of slices to sort
  if (linearIndex >= keySlices) {
    return;
  }

  const IndexType keyStartOffset =
      at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
  const IndexType valueStartOffset =
      at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

  K *keys_slice = &keys.data[keyStartOffset];
  V *values_slice = &values.data[valueStartOffset];

  StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
  StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);

  namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);

  using key_t = typename at::cuda::cub::detail::cuda_type<K>::type;
  using LoadKeys = cub::BlockLoad<K, block_size, items_per_thread,
                                  cub::BLOCK_LOAD_TRANSPOSE>;
  using LoadValues = cub::BlockLoad<V, block_size, items_per_thread,
                                    cub::BLOCK_LOAD_TRANSPOSE>;
  using Sort = cub::BlockRadixSort<key_t, block_size, items_per_thread, V>;
  using StoreKeys = cub::BlockStore<K, block_size, items_per_thread,
                                    cub::BLOCK_STORE_TRANSPOSE>;
  using StoreValues = cub::BlockStore<V, block_size, items_per_thread,
                                      cub::BLOCK_STORE_TRANSPOSE>;

  __shared__ union {
    typename LoadKeys::TempStorage load_keys;
    typename LoadValues::TempStorage load_values;
    typename Sort::TempStorage sort;
    typename StoreKeys::TempStorage store_keys;
    typename StoreValues::TempStorage store_values;
  } tmp_storage;

  // cub's Block operations operate on a fixed number of items, but the
  // actual slice we are sorting might be smaller. So, we need to make
  // up the difference with keys that will always sort higher.
  const K invalid_key = [descending] {
    using radix_t = typename cub::Traits<key_t>::UnsignedBits;
    union {
      K key;
      radix_t radix;
    } tmp;
    tmp.radix = descending ?
        cub::Traits<key_t>::LOWEST_KEY :
        cub::Traits<key_t>::MAX_KEY;
    return tmp.key;
  }();
  const V invalid_value = static_cast<V>(0);

  // Load inputs
  K local_keys[items_per_thread];
  V local_values[items_per_thread];

  LoadKeys(tmp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
  __syncthreads();
  LoadValues(tmp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
  __syncthreads();

  // Sort!
  if (descending) {
    Sort(tmp_storage.sort).SortDescending(
        reinterpret_cast<key_t (&)[items_per_thread]>(local_keys),
        local_values);
  } else {
    Sort(tmp_storage.sort).Sort(
        reinterpret_cast<key_t (&)[items_per_thread]>(local_keys),
        local_values);
  }
  __syncthreads();

  // Store outputs
  StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
  __syncthreads();
  StoreValues(tmp_storage.store_values).Store(values_iter, local_values, keySliceSize);
}

}} // at::native
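
// Illustrative usage sketch (not part of this header): a hypothetical host-side
// launch of radixSortKVInPlace for contiguous 2-d key/value tensors sorted along
// the last dimension. The template arguments, TensorInfo setup, and launch
// geometry below are assumptions for the example only; the real dispatch lives
// in the sorting implementation that includes this header.
//
//   // keyInfo:   at::cuda::detail::TensorInfo<float, uint32_t> over the keys
//   // valueInfo: at::cuda::detail::TensorInfo<int64_t, uint32_t> over the indices
//   constexpr int block_size = 256;
//   constexpr int items_per_thread = 4;   // block_size * items_per_thread >= sliceSize
//   const dim3 grid(nSlices);
//   radixSortKVInPlace<2, 2, block_size, items_per_thread>
//       <<<grid, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
//           keyInfo, nSlices, sliceSize, /*keySliceStride=*/1,
//           valueInfo, /*valueSliceStride=*/1, /*descending=*/false);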