Spaces:
Running
Running
CUDA: remove unnecessary warp reduce in FA (ggml/1032)
Browse files* kqmax_new_j in every thread within warp is same after operate at line 199,this reduce can be omit
* same problem in vec32
---------
Co-authored-by: ZhaoXiaoYu <[email protected]>
ggml/src/ggml-cuda/fattn-vec-f16.cuh
CHANGED
|
@@ -220,7 +220,6 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
| 220 |
for (int j = 0; j < ncols; ++j) {
|
| 221 |
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
| 222 |
|
| 223 |
-
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
| 224 |
if (threadIdx.x == 0) {
|
| 225 |
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
| 226 |
}
|
|
|
|
| 220 |
for (int j = 0; j < ncols; ++j) {
|
| 221 |
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
| 222 |
|
|
|
|
| 223 |
if (threadIdx.x == 0) {
|
| 224 |
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
| 225 |
}
|
ggml/src/ggml-cuda/fattn-vec-f32.cuh
CHANGED
|
@@ -206,7 +206,6 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 206 |
for (int j = 0; j < ncols; ++j) {
|
| 207 |
float kqmax_new_j = kqmax_new_arr[j];
|
| 208 |
|
| 209 |
-
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
| 210 |
if (threadIdx.x == 0) {
|
| 211 |
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
| 212 |
}
|
|
|
|
| 206 |
for (int j = 0; j < ncols; ++j) {
|
| 207 |
float kqmax_new_j = kqmax_new_arr[j];
|
| 208 |
|
|
|
|
| 209 |
if (threadIdx.x == 0) {
|
| 210 |
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
| 211 |
}
|