Spaces:
Running
Running
// Combines the partial results of `parallel_blocks` independently processed
// chunks into one output row of D floats.
//
// For every partial result l the kernel reads a float2 from VKQ_meta whose .x
// component is treated as a running maximum and whose .y component is treated
// as a partial denominator. All partial numerators (VKQ_parts) and denominators
// are rescaled by exp(meta[l].x - max_l meta[l].x), summed, and the quotient is
// written to dst.
//
// Launch layout as used by the indexing below:
//   - threadIdx.x selects one of the D elements of a row
//     (assumes blockDim.x == D -- TODO confirm against the launch site),
//   - blockIdx.y selects the row inside a slice, gridDim.y rows per slice,
//   - blockIdx.x selects the slice.
// Shared memory: parallel_blocks * sizeof(float2) bytes, statically allocated.
template<int D, int parallel_blocks> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_combine_results(
        const float  * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
        float        * __restrict__ dst) {
    // Move all pointers to the slice owned by blockIdx.x. The offsets are
    // computed in 64 bit so that very large launches cannot wrap around.
    const size_t slice_rows = (size_t) gridDim.y * blockIdx.x;
    VKQ_parts += slice_rows * parallel_blocks * D;
    VKQ_meta  += slice_rows * parallel_blocks;
    dst       += slice_rows * D;

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);

    // Stage the metadata of this row in shared memory, viewed as a flat array
    // of 2*parallel_blocks floats. A block-strided loop is used so that every
    // entry is loaded even when 2*parallel_blocks exceeds the thread count.
    __shared__ float2 meta[parallel_blocks];
    const float * meta_src = (const float *) VKQ_meta + (size_t) blockIdx.y * (2*parallel_blocks);
    for (int i = tid; i < 2*parallel_blocks; i += blockDim.x) {
        ((float *) meta)[i] = meta_src[i];
    }
    __syncthreads(); // meta[] is read by all threads below

    // Global maximum over all partial results.
    float kqmax = meta[0].x;
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = fmaxf(kqmax, meta[l].x);
    }

    float VKQ_numerator   = 0.0f;
    float VKQ_denominator = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float diff = meta[l].x - kqmax;

        // Rescale factor for partial result l. Contributions whose exponent is
        // not above SOFTMAX_FTZ_THRESHOLD are flushed to exactly +0.0f. This
        // select yields the same values as masking the bit pattern of
        // expf(diff), without writing to a const object through a cast pointer.
        const float KQ_max_scale = diff > SOFTMAX_FTZ_THRESHOLD ? expf(diff) : 0.0f;

        VKQ_numerator   += KQ_max_scale * VKQ_parts[((size_t) l * gridDim.y + blockIdx.y) * D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }

    dst[(size_t) blockIdx.y * D + tid] = VKQ_numerator / VKQ_denominator;
}