#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/* Qualifier order used below: layout - storage - precision - memory */

/*
 * Output image: receives the softmax result, written texel by texel.
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
 * Input texture: read with texelFetch at mip level 0.
 */
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
 * Parameter block.
 *   input_shader_extents - dimensions of the Vulkan 3D texture (XYZ), with a
 *                          zero pad at W.
 *   input_tensor_dims    - dimensions of the NCHW PyTorch tensor.
 *   input_dim_stride     - stride used to include elements along the softmax
 *                          dimension.
 *   early_exit           - global-position bound; invocations at or beyond it
 *                          in any axis return without doing work.
 * This shader itself only reads input_tensor_dims.y and early_exit.xyz.
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  ivec4 input_shader_extents;
  ivec4 input_tensor_dims;
  ivec4 input_dim_stride;
  ivec4 early_exit;
}
uBlock;

/*
 * Local work group size, supplied through specialization constants 0, 1, 2.
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Returns the larger of `running_max` and the first `lane_count` components
 * of `texel`. Uses an explicit comparison so a component only replaces the
 * running value when it is strictly greater.
 */
float fold_max(float running_max, vec4 texel, uint lane_count) {
  for (uint lane = 0u; lane < lane_count; ++lane) {
    if (texel[lane] > running_max) {
      running_max = texel[lane];
    }
  }
  return running_max;
}

/*
 * Adds exp(texel[lane] - shift) to `running_sum` for the first `lane_count`
 * components of `texel`, in lane order, and returns the updated sum.
 */
float fold_exp_sum(float running_sum, vec4 texel, uint lane_count, float shift) {
  for (uint lane = 0u; lane < lane_count; ++lane) {
    running_sum += exp(texel[lane] - shift);
  }
  return running_sum;
}

/*
 * Softmax across the channel texels of one batch at a fixed (x, y).
 *
 * Each invocation walks ceil(C / 4) consecutive z-texels three times:
 *   1. find the maximum element,
 *   2. sum exp(element - maximum) to form the denominator,
 *   3. store exp(texel - maximum) / denominator to the output image.
 * Subtracting the maximum relies on the translation invariance of softmax
 * and avoids floating point overflow in exp().
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Invocations outside the requested extent have nothing to do.
  if (any(greaterThanEqual(pos, uBlock.early_exit.xyz))) {
    return;
  }

  // How "wide" a batch is in terms of z: one batch width holds elements from
  // every channel, so a single invocation handles a whole batch at (x, y).
  const int channel_count = uBlock.input_tensor_dims.y;
  const int texels_per_batch = int(ceil(channel_count / 4.0));
  const int last_texel = texels_per_batch - 1;
  const ivec3 src_pos = ivec3(pos.xy, pos.z * texels_per_batch);

  // The last texel is only partially filled when the channel count is not a
  // multiple of 4 (the rest is zero padding), so count its valid lanes.
  const uint channel_remainder = channel_count % 4;
  const uint tail_lanes = (channel_remainder == 0u) ? 4u : channel_remainder;

  // Pass 1: maximum element, seeded with the very first channel value.
  float max_element = texelFetch(uInput, src_pos, 0).x;
  for (int c = 0; c <= last_texel; ++c) {
    const vec4 texel = texelFetch(uInput, src_pos + ivec3(0, 0, c), 0);
    const uint lanes = (c == last_texel) ? tail_lanes : 4u;
    max_element = fold_max(max_element, texel, lanes);
  }

  // Pass 2: denominator, skipping the padded lanes of the last texel.
  float denominator = 0.0;
  for (int c = 0; c <= last_texel; ++c) {
    const vec4 texel = texelFetch(uInput, src_pos + ivec3(0, 0, c), 0);
    const uint lanes = (c == last_texel) ? tail_lanes : 4u;
    denominator = fold_exp_sum(denominator, texel, lanes, max_element);
  }

  // Pass 3: normalise and store every texel of the batch.
  // NOTE(review): all four lanes are written, so padded lanes of the last
  // texel receive exp(pad - max) / denominator rather than zero - confirm
  // that downstream consumers do not rely on zero padding.
  for (int c = 0; c <= last_texel; ++c) {
    const ivec3 dst_pos = src_pos + ivec3(0, 0, c);
    const vec4 shifted = texelFetch(uInput, dst_pos, 0) - max_element;
    imageStore(uOutput, dst_pos, exp(shifted) / denominator);
  }
}