#pragma once

#include <ATen/ceil_div.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/AsmUtils.cuh>
#include <ATen/cuda/DeviceUtils.cuh>
#include <c10/macros/Macros.h>

namespace at {
namespace native {

template <typename scalar_t>
struct TopKTypeConfig {};

template <>
struct TopKTypeConfig<float> {
  typedef uint32_t RadixType;

  // Converts a float to an integer representation with the same
  // sorting; i.e., for floats f1, f2:
  // if f1 < f2 then convert(f1) < convert(f2)
  // We use this to enable radix selection of floating-point values.
  // This also gives a relative order for NaNs, but that's ok, as they
  // will all be adjacent
  // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff..
  // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00..
  // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0
  // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x<ff..
  static inline __device__ RadixType convert(float v) {
    RadixType x = __float_as_int(v);
    RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
    return (v == v) ? (x ^ mask) : 0xffffffff;
  }

  static inline __device__ float deconvert(RadixType v) {
    RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
    return __int_as_float(v ^ mask);
  }
};

template <>
struct TopKTypeConfig<uint8_t> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType convert(uint8_t v) {
    return v;
  }

  static inline __device__ uint8_t deconvert(RadixType v) {
    return v;
  }
};

template <>
struct TopKTypeConfig<int8_t> {
  typedef uint32_t RadixType;

  // Shift the signed range [-128, 127] to the unsigned range
  // [0, 255] so that unsigned bitwise comparison matches signed order
  static inline __device__ RadixType convert(int8_t v) {
    return 128u + v;
  }

  static inline __device__ int8_t deconvert(RadixType v) {
    return v - 128;
  }
};

template <>
struct TopKTypeConfig<int16_t> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType convert(int16_t v) {
    static_assert(sizeof(short) == 2, "");
    return 32768u + v;
  }

  static inline __device__ int16_t deconvert(RadixType v) {
    return v - 32768;
  }
};

template <>
struct TopKTypeConfig<int32_t> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType convert(int32_t v) {
    static_assert(sizeof(int) == 4, "");
    return 2147483648u + v;
  }

  static inline __device__ int32_t deconvert(RadixType v) {
    return v - 2147483648u;
  }
};

template <>
struct TopKTypeConfig<int64_t> {
  typedef uint64_t RadixType;

  static inline __device__ RadixType convert(int64_t v) {
    static_assert(sizeof(int64_t) == 8, "");
    return 9223372036854775808ull + v;
  }

  static inline __device__ int64_t deconvert(RadixType v) {
    return v - 9223372036854775808ull;
  }
};

template <>
struct TopKTypeConfig<double> {
  typedef uint64_t RadixType;

  static inline __device__ RadixType convert(double v) {
    RadixType x = __double_as_longlong(v);
    RadixType mask = -((x >> 63)) | 0x8000000000000000;
    return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
  }

  static inline __device__ double deconvert(RadixType v) {
    RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
    return __longlong_as_double(v ^ mask);
  }
};

template <>
struct TopKTypeConfig<at::Half> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType convert(at::Half v) {
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
    RadixType x = __half_as_ushort(v);
    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
    return (v == v) ? (x ^ mask) : 0xffff;
#else
    CUDA_KERNEL_ASSERT(false);
    return 0u;
#endif
  }

  static inline __device__ at::Half deconvert(RadixType v) {
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
    return __ushort_as_half(v ^ mask);
#else
    CUDA_KERNEL_ASSERT(false);
    return static_cast<at::Half>(0);
#endif
  }
};

template <>
struct TopKTypeConfig<at::BFloat16> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType convert(at::BFloat16 v) {
    RadixType x = v.x;
    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
    return (v == v) ? (x ^ mask) : 0xffff;
  }

  static inline __device__ at::BFloat16 deconvert(RadixType v) {
    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
    at::BFloat16 r;
    r.x = (v ^ mask);
    return r;
  }
};
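// Illustrative sketch (not part of the original header): the radix
// machinery below only relies on convert() being an order-preserving
// bijection into an unsigned integer space, with deconvert() as its
// inverse. A hypothetical device-side spot check of the float
// specialization could look like:
//
//   using Cfg = TopKTypeConfig<float>;
//   // Order is preserved across the conversion...
//   CUDA_KERNEL_ASSERT(Cfg::convert(-2.0f) < Cfg::convert(-1.0f));
//   CUDA_KERNEL_ASSERT(Cfg::convert(-1.0f) < Cfg::convert(1.0f));
//   // ...and it round-trips exactly, with NaNs sorted past +inf.
//   CUDA_KERNEL_ASSERT(Cfg::deconvert(Cfg::convert(1.5f)) == 1.5f);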
// This function counts the distribution of all input values in a
// slice we are selecting by radix digit at `radixDigitPos`, but only
// those that pass the filter `((v & desiredMask) == desired)`.
// This produces and broadcasts the seen counts for a single block only.
// `smem` must have at least `RadixSize` elements.
template <
    typename scalar_t,
    typename bitwise_t,
    typename index_t,
    typename CountType,
    int RadixSize,
    int RadixBits>
__device__ void countRadixUsingMask(
    CountType counts[RadixSize],
    CountType* smem,
    bitwise_t desired,
    bitwise_t desiredMask,
    int radixDigitPos,
    index_t sliceSize,
    index_t withinSliceStride,
    const scalar_t* data) {
  // Clear out per-thread counts from a previous round
#pragma unroll
  for (int i = 0; i < RadixSize; ++i) {
    counts[i] = 0;
  }

  if (threadIdx.x < RadixSize) {
    smem[threadIdx.x] = 0;
  }
  __syncthreads();

  // Scan over all the data. Upon a read, the warp will accumulate
  // counts per each digit in the radix using warp voting.
#if !defined(USE_ROCM)
  // Must be called outside of loop to ensure all threads participate
  unsigned mask = WARP_BALLOT(threadIdx.x < sliceSize);
#endif
  for (index_t i = threadIdx.x; i < sliceSize;) {
    bitwise_t val =
        TopKTypeConfig<scalar_t>::convert(doLdg(&data[i * withinSliceStride]));

    bool hasVal = ((val & desiredMask) == desired);
    bitwise_t digitInRadix = at::cuda::Bitfield<bitwise_t>::getBitfield(
        val, radixDigitPos, RadixBits);

#pragma unroll
    for (uint32_t j = 0; j < RadixSize; ++j) {
      bool vote = hasVal && (digitInRadix == j);
#if defined(USE_ROCM)
      counts[j] += __popcll(WARP_BALLOT(vote));
#else
      counts[j] += __popc(WARP_BALLOT(vote, mask));
#endif
    }
    i += blockDim.x;
#if !defined(USE_ROCM)
    mask = WARP_BALLOT(i < sliceSize, mask);
#endif
  }

  // Now, for each warp, sum values
  if (at::cuda::getLaneId() == 0) {
#pragma unroll
    for (uint32_t i = 0; i < RadixSize; ++i) {
      gpuAtomicAddNoReturn(&smem[i], counts[i]);
    }
  }

  __syncthreads();

  // For each thread, read in the total counts
#pragma unroll
  for (uint32_t i = 0; i < RadixSize; ++i) {
    counts[i] = smem[i];
  }

  __syncthreads();
}

// Over what radix we are selecting values
constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS)
constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS
constexpr int RADIX_MASK = (RADIX_SIZE - 1);
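// Worked example (illustrative, not from the original header): with
// RADIX_BITS = 2, keys are scanned as base-4 digits from most to least
// significant. Selecting the largest element (k = 1) among converted
// 8-bit keys {0b11010010, 0b11100101, 0b01110000}, the first pass
// histograms bits [7:6]: count(0b01) = 1, count(0b11) = 2. Scanning
// digit values downward, count(0b11) >= kToFind, so the answer's top
// digit must be 0b11; desired/desiredMask lock bits [7:6], and the
// next pass histograms bits [5:4] over the two keys that still match.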
// This finds the unique value `v` that matches the pattern
// ((v & desiredMask) == desired) in our sorted int format
template <typename scalar_t, typename bitwise_t, typename index_t>
__device__ scalar_t findPattern(
    scalar_t* smem,
    const scalar_t* data,
    index_t sliceSize,
    index_t withinSliceStride,
    bitwise_t desired,
    bitwise_t desiredMask) {
  if (threadIdx.x < 2) {
    smem[threadIdx.x] = static_cast<scalar_t>(0);
  }
  __syncthreads();

  // All threads participate in the loop, in order to sync on the flag
  index_t numIterations =
      round_up(sliceSize, static_cast<index_t>(blockDim.x));
  for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) {
    bool inRange = (i < sliceSize);
    scalar_t v = inRange ? doLdg(&data[i * withinSliceStride])
                         : static_cast<scalar_t>(0);

    if (inRange &&
        ((TopKTypeConfig<scalar_t>::convert(v) & desiredMask) == desired)) {
      // There should not be conflicts if we are using findPattern,
      // since the result is unique
      smem[0] = static_cast<scalar_t>(1);
      smem[1] = v; // can't use val as the flag, since it could be 0
    }

    __syncthreads();

    scalar_t found = smem[0];
    scalar_t val = smem[1];

    __syncthreads();

    // Check to see if a thread found the value
    if (found != static_cast<scalar_t>(0)) {
      // all threads return this value
      return val;
    }
  }

  // should not get here
  CUDA_KERNEL_ASSERT(false);
  return static_cast<scalar_t>(0);
}
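// Illustrative example (not from the original header): suppose two
// radix passes over a 32-bit key have fixed the leading digits 0b10
// and then 0b01. The search state would then be
//   desired     = 0b1001u << 28
//   desiredMask = 0b1111u << 28
// and findPattern() scans the slice for the unique element whose
// converted key v satisfies ((v & desiredMask) == desired).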
// Returns the top-Kth element found in the data using radix selection
template <typename scalar_t, typename bitwise_t, typename index_t>
__device__ void radixSelect(
    const scalar_t* data,
    index_t k,
    bool largest,
    index_t sliceSize,
    index_t withinSliceStride,
    int* smem,
    scalar_t* topK) {
  // Per-thread buckets into which we accumulate digit counts in our
  // radix
  int counts[RADIX_SIZE];

  // We only consider elements x such that (x & desiredMask) == desired
  // Initially, we consider all elements of the array, so the above
  // statement is true regardless of input.
  bitwise_t desired = 0;
  bitwise_t desiredMask = 0;

  // We are looking for the top kToFind-th element when iterating over
  // digits; this count gets reduced by elimination when counting
  // successive digits
  int kToFind = k;

  // We start at the most significant digit in our radix, scanning
  // through to the least significant digit
  for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0;
       digitPos -= RADIX_BITS) {
    // Count radix distribution for the current position and reduce
    // across all threads
    countRadixUsingMask<
        scalar_t,
        bitwise_t,
        index_t,
        int,
        RADIX_SIZE,
        RADIX_BITS>(
        counts,
        smem,
        desired,
        desiredMask,
        digitPos,
        sliceSize,
        withinSliceStride,
        data);

    auto found_unique = [&](int i, int count) -> bool {
      /* All threads have the same value in counts here, so all */
      /* threads will return from the function. */
      if (count == 1 && kToFind == 1) {
        /* There is a unique answer. */
        desired = at::cuda::Bitfield<bitwise_t>::setBitfield(
            desired, i, digitPos, RADIX_BITS);
        desiredMask = at::cuda::Bitfield<bitwise_t>::setBitfield(
            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);

        /* The answer is now the unique element v such that: */
        /* (v & desiredMask) == desired */
        /* However, we do not yet know what the actual element is. We */
        /* need to perform a search through the data to find the */
        /* element that matches this pattern. */
        *topK = findPattern<scalar_t, bitwise_t, index_t>(
            (scalar_t*)smem,
            data,
            sliceSize,
            withinSliceStride,
            desired,
            desiredMask);
        return true;
      }
      return false;
    };
    auto found_non_unique = [&](int i, int count) -> bool {
      if (count >= kToFind) {
        desired = at::cuda::Bitfield<bitwise_t>::setBitfield(
            desired, i, digitPos, RADIX_BITS);
        desiredMask = at::cuda::Bitfield<bitwise_t>::setBitfield(
            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);

        /* The top-Kth element v must now be one such that: */
        /* ((v & desiredMask) == desired) */
        /* but we haven't narrowed it down; we must check the next */
        /* least-significant digit */
        return true;
      }
      kToFind -= count;
      return false; // continue the loop
    };

    // All threads participate in the comparisons below to know the
    // final result
    if (largest) {
      // Process in descending order
#pragma unroll
      for (int i = RADIX_SIZE - 1; i >= 0; --i) {
        int count = counts[i];
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    } else {
      // Process in ascending order
#pragma unroll
      for (int i = 0; i < RADIX_SIZE; ++i) {
        int count = counts[i];
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    }
  } // end digitPos for

  // There is no unique result, but there is a non-unique result
  // matching `desired` exactly
  *topK = TopKTypeConfig<scalar_t>::deconvert(desired);
}

} // namespace native
} // namespace at
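// Usage sketch (illustrative; every name other than radixSelect itself
// is hypothetical): a one-block-per-slice kernel would hand radixSelect
// a shared scratch buffer of at least RADIX_SIZE ints, e.g.
//
//   __global__ void kthLargestSketch(
//       const float* slices, // [numSlices, sliceSize], contiguous
//       int64_t sliceSize,
//       int64_t k,
//       float* out) {
//     __shared__ int smem[32]; // >= RADIX_SIZE; reused by findPattern
//     float kthValue;
//     // All threads in the block must participate in this call.
//     at::native::radixSelect<float, uint32_t, int64_t>(
//         slices + blockIdx.x * sliceSize,
//         k,
//         /*largest=*/true,
//         sliceSize,
//         /*withinSliceStride=*/1,
//         smem,
//         &kthValue);
//     if (threadIdx.x == 0) {
//       out[blockIdx.x] = kthValue;
//     }
//   }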