mahorozte mahorozte commited on
Commit
9a8c238
·
1 Parent(s): c7e59ef

CUDA: remove unnecessary warp reduce in FA (ggml/1032)

Browse files

* kqmax_new_j in every thread within warp is same after operate at line 199,this reduce can be omit

* same problem in vec32

---------

Co-authored-by: ZhaoXiaoYu <[email protected]>

ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -220,7 +220,6 @@ static __global__ void flash_attn_vec_ext_f16(
220
  for (int j = 0; j < ncols; ++j) {
221
  half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
222
 
223
- kqmax_new_j = warp_reduce_max(kqmax_new_j);
224
  if (threadIdx.x == 0) {
225
  kqmax_shared[j][threadIdx.y] = kqmax_new_j;
226
  }
 
220
  for (int j = 0; j < ncols; ++j) {
221
  half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
222
 
 
223
  if (threadIdx.x == 0) {
224
  kqmax_shared[j][threadIdx.y] = kqmax_new_j;
225
  }
ggml/src/ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -206,7 +206,6 @@ static __global__ void flash_attn_vec_ext_f32(
206
  for (int j = 0; j < ncols; ++j) {
207
  float kqmax_new_j = kqmax_new_arr[j];
208
 
209
- kqmax_new_j = warp_reduce_max(kqmax_new_j);
210
  if (threadIdx.x == 0) {
211
  kqmax_shared[j][threadIdx.y] = kqmax_new_j;
212
  }
 
206
  for (int j = 0; j < ncols; ++j) {
207
  float kqmax_new_j = kqmax_new_arr[j];
208
 
 
209
  if (threadIdx.x == 0) {
210
  kqmax_shared[j][threadIdx.y] = kqmax_new_j;
211
  }