ggerganov committed
Commit ddc04a3 · 1 Parent(s): 305dc4e

cmake : fix CUDA build (#0)

Files changed (4)
  1. CMakeLists.txt +52 -1
  2. Makefile +15 -1
  3. ggml-cuda/fattn-vec-f16.cu +0 -430
  4. ggml-cuda/fattn-vec-f32.cu +0 -384
CMakeLists.txt CHANGED
@@ -86,6 +86,7 @@ else()
     option(WHISPER_OPENBLAS             "whisper: prefer OpenBLAS" OFF)
     option(WHISPER_OPENBLAS_INTERFACE64 "whisper: use OpenBLAS w/ 64-bit interface" OFF)
     option(WHISPER_CUDA                 "whisper: support for CUDA" OFF)
+    option(WHISPER_CUDA_FA_ALL_QUANTS   "whisper: compile all quants for FlashAttention" OFF)
     option(WHISPER_CUBLAS               "whisper: support for CUDA (deprecated)" OFF)
     option(WHISPER_HIPBLAS              "whisper: support for hipBLAS" OFF)
     option(WHISPER_CLBLAST              "whisper: use CLBlast" OFF)
@@ -346,19 +347,51 @@ if (WHISPER_CUBLAS)
 endif()
 
 if (WHISPER_CUDA)
-    cmake_minimum_required(VERSION 3.17)
+    cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
 
     find_package(CUDAToolkit)
 
     if (CUDAToolkit_FOUND)
         message(STATUS "cuBLAS found")
 
+        if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+            # 52 == lowest CUDA 12 standard
+            # 60 == f16 CUDA intrinsics
+            # 61 == integer CUDA intrinsics
+            # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
+            if (WHISPER_CUDA_F16 OR WHISPER_CUDA_DMMV_F16)
+                set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
+            else()
+                set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
+                #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work
+            endif()
+        endif()
+        message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
+
         enable_language(CUDA)
 
         file(GLOB   GGML_SOURCES_CUDA "ggml-cuda/*.cu")
         list(APPEND GGML_SOURCES_CUDA ggml-cuda.h)
         list(APPEND GGML_SOURCES_CUDA ggml-cuda.cu)
 
+        file(GLOB   SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        file(GLOB   SRCS "ggml-cuda/template-instances/mmq*.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+
+        if (WHISPER_CUDA_FA_ALL_QUANTS)
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+        else()
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        endif()
+
         add_compile_definitions(GGML_USE_CUDA)
 
         if (WHISPER_STATIC)
@@ -399,6 +432,24 @@ if (WHISPER_HIPBLAS)
         file(GLOB   GGML_SOURCES_ROCM "ggml-cuda/*.cu")
         list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu")
 
+        file(GLOB   SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        file(GLOB   SRCS "ggml-cuda/template-instances/mmq*.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+
+        if (WHISPER_CUDA_FA_ALL_QUANTS)
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+        else()
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        endif()
+
         add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA)
 
         set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
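For reference, a configure-and-build invocation that exercises the options touched by this change could look like the following sketch; the build directory name, job count, and explicit architecture list are illustrative assumptions, not part of the commit:

    cmake -B build -DWHISPER_CUDA=ON -DWHISPER_CUDA_FA_ALL_QUANTS=ON
    cmake --build build -j
    # optionally override the architecture list instead of using the default above
    cmake -B build -DWHISPER_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="60;61;70"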
Makefile CHANGED
@@ -277,6 +277,16 @@ ifdef WHISPER_CUBLAS
 	WHISPER_CUDA := 1
 endif
 
+OBJS_CUDA_TEMP_INST  = $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-wmma*.cu))
+OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/mmq*.cu))
+ifdef WHISPER_CUDA_FA_ALL_QUANTS
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*.cu))
+else
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu))
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu))
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*f16-f16.cu))
+endif # WHISPER_CUDA_FA_ALL_QUANTS
+
 ifdef WHISPER_CUDA
 	ifeq ($(shell expr $(NVCC_VERSION) \>= 11.6), 1)
 		CUDA_ARCH_FLAG ?= native
@@ -289,10 +299,11 @@ ifdef WHISPER_CUDA
 	LDFLAGS     += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lcufft -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
 	WHISPER_OBJ += ggml-cuda.o whisper-mel-cuda.o
 	WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
+	WHISPER_OBJ += $(OBJS_CUDA_TEMP_INST)
 	NVCC        = nvcc
 	NVCCFLAGS   = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
 
-ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
+ggml-cuda/%.o: ggml-cuda/%.cu ggml.h ggml-common.h ggml-cuda/common.cuh
 	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -c $< -o $@
 
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
@@ -313,6 +324,7 @@ ifdef WHISPER_HIPBLAS
 	HIPFLAGS    += $(addprefix --offload-arch=,$(GPU_TARGETS))
 	WHISPER_OBJ += ggml-cuda.o
 	WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
+	WHISPER_OBJ += $(OBJS_CUDA_TEMP_INST)
 
 ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
 	$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
@@ -457,6 +469,8 @@ libwhisper.so: $(WHISPER_OBJ)
 
 clean:
 	rm -f *.o main stream command talk talk-llama bench quantize server lsp libwhisper.a libwhisper.so
+	rm -vrf ggml-cuda/*.o
+	rm -vrf ggml-cuda/template-instances/*.o
 
 #
 # Examples
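The Makefile path takes the same switches as ordinary make variables (they are only checked via ifdef); a usage sketch, with the job count left to taste:

    WHISPER_CUDA=1 WHISPER_CUDA_FA_ALL_QUANTS=1 make -j
    make clean   # now also removes ggml-cuda/*.o and ggml-cuda/template-instances/*.o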
ggml-cuda/fattn-vec-f16.cu DELETED
@@ -1,430 +0,0 @@
-#include "common.cuh"
-#include "fattn-common.cuh"
-#include "fattn-vec-f16.cuh"
-
-template<int D, int ncols, int parallel_blocks> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-static __global__ void flash_attn_vec_ext_f16(
-        const char * __restrict__ Q,
-        const char * __restrict__ K,
-        const char * __restrict__ V,
-        const char * __restrict__ mask,
-        float      * __restrict__ dst,
-        float2     * __restrict__ dst_meta,
-        const float scale,
-        const float max_bias,
-        const float m0,
-        const float m1,
-        const uint32_t n_head_log2,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
-#if FP16_AVAILABLE
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
-
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.
-
-    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half   * V_h   = (const half   *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *)  mask + ne11*ic0;
-
-    const int stride_KV  = nb11 / sizeof(half);
-    const int stride_KV2 = nb11 / sizeof(half2);
-
-    half slopeh = __float2half(1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-    }
-
-    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
-    constexpr int nwarps = D / WARP_SIZE;
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    __builtin_assume(tid < D);
-
-    __shared__ half KQ[ncols*D];
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        KQ[j*D + tid] = -HALF_MAX_HALF;
-    }
-    half2 * KQ2 = (half2 *) KQ;
-
-    half kqmax[ncols];
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        kqmax[j] = -HALF_MAX_HALF;
-    }
-    half kqsum[ncols] = {0.0f};
-
-    __shared__ half kqmax_shared[ncols][WARP_SIZE];
-    __shared__ half kqsum_shared[ncols][WARP_SIZE];
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (threadIdx.y == 0) {
-            kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
-            kqsum_shared[j][threadIdx.x] = 0.0f;
-        }
-    }
-    __syncthreads();
-
-    // Convert Q to half2 and store in registers:
-    half2 Q_h2[ncols][D/(2*WARP_SIZE)];
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-
-            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
-            Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
-        }
-    }
-
-    half2 VKQ[ncols] = {{0.0f, 0.0f}};
-
-    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
-        // Calculate KQ tile and keep track of new maximum KQ values:
-
-        // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
-        // see https://github.com/ggerganov/llama.cpp/pull/7061 .
-        // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
-        half kqmax_new = kqmax[0];
-        half kqmax_new_arr[ncols];
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            kqmax_new_arr[j] = kqmax[j];
-        }
-
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
-            const int i_KQ = i_KQ_0 + threadIdx.y;
-
-            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
-                break;
-            }
-
-            half2 sum2[ncols] = {{0.0f, 0.0f}};
-#pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
-                const int k_KQ = k_KQ_0 + threadIdx.x;
-
-                const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
-#pragma unroll
-                for (int j = 0; j < ncols; ++j) {
-                    sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
-                }
-            }
-
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-                sum2[j] = warp_reduce_sum(sum2[j]);
-                half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
-                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
-
-                if (ncols == 1) {
-                    kqmax_new        = ggml_cuda_hmax(kqmax_new,        sum);
-                } else {
-                    kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
-                }
-
-                if (threadIdx.x == 0) {
-                    KQ[j*D + i_KQ] = sum;
-                }
-            }
-        }
-
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
-
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
-            if (threadIdx.x == 0) {
-                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
-            }
-        }
-
-        __syncthreads();
-
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
-
-            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
-            kqmax[j] = kqmax_new_j;
-
-            const half val = hexp(KQ[j*D + tid] - kqmax[j]);
-            kqsum[j] = kqsum[j]*KQ_max_scale + val;
-            KQ[j*D + tid] = val;
-
-            VKQ[j] *= __half2half2(KQ_max_scale);
-        }
-
-        __syncthreads();
-
-#pragma unroll
-        for (int k0 = 0; k0 < D; k0 += 2) {
-            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
-                break;
-            }
-
-            half2 V_k;
-            reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
-            reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-                VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
-            }
-        }
-
-        __syncthreads();
-    }
-
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        kqsum[j] = warp_reduce_sum(kqsum[j]);
-        if (threadIdx.x == 0) {
-            kqsum_shared[j][threadIdx.y] = kqsum[j];
-        }
-    }
-
-    __syncthreads();
-
-#pragma unroll
-    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
-        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
-        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
-
-        half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
-        if (parallel_blocks == 1) {
-            dst_val /= kqsum[j_VKQ];
-        }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
-    }
-
-    if (parallel_blocks != 1 && tid != 0) {
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
-        }
-    }
-#else
-    NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
-}
-
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-            (const char *) Q->data,
-            (const char *) K->data,
-            (const char *) V->data,
-            mask ? ((const char *) mask->data) : nullptr,
-            parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-            scale, max_bias, m0, m1, n_head_log2,
-            Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            Q->nb[1], Q->nb[2], Q->nb[3],
-            K->nb[1], K->nb[2], K->nb[3],
-            KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-        );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
-}
-
-void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
-
-    const int32_t precision = KQV->op_params[2];
-    GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-
-    constexpr int cols_per_block  = 1;
-    constexpr int parallel_blocks = 4;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 256:
-            launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
-}
-
-void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
-
-    const int32_t precision = KQV->op_params[2];
-    GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
-
-    if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
-        return;
-    }
-
-    if (Q->ne[1] == 2) {
-        constexpr int cols_per_block  = 2;
-        constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block  = 4;
-        constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block  = 8;
-        constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
-        return;
-    }
-
-    constexpr int cols_per_block  = 8;
-    constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
-}
ggml-cuda/fattn-vec-f32.cu DELETED
@@ -1,384 +0,0 @@
-#include "common.cuh"
-#include "fattn-common.cuh"
-#include "fattn-vec-f32.cuh"
-
-template<int D, int ncols, int parallel_blocks> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-static __global__ void flash_attn_vec_ext_f32(
-        const char * __restrict__ Q,
-        const char * __restrict__ K,
-        const char * __restrict__ V,
-        const char * __restrict__ mask,
-        float      * __restrict__ dst,
-        float2     * __restrict__ dst_meta,
-        const float scale,
-        const float max_bias,
-        const float m0,
-        const float m1,
-        const uint32_t n_head_log2,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
-
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.
-
-    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half   * V_h   = (const half   *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *)  mask + ne11*ic0;
-
-    const int stride_KV  = nb11 / sizeof(half);
-    const int stride_KV2 = nb11 / sizeof(half2);
-
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
-
-    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
-    constexpr int nwarps = D / WARP_SIZE;
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    __builtin_assume(tid < D);
-
-    __shared__ float KQ[ncols*D];
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        KQ[j*D + tid] = -FLT_MAX/2.0f;
-    }
-
-    float kqmax[ncols];
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        kqmax[j] = -FLT_MAX/2.0f;
-    }
-    float kqsum[ncols] = {0.0f};
-
-    __shared__ float kqmax_shared[ncols][WARP_SIZE];
-    __shared__ float kqsum_shared[ncols][WARP_SIZE];
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (threadIdx.y == 0) {
-            kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
-            kqsum_shared[j][threadIdx.x] = 0.0f;
-        }
-    }
-    __syncthreads();
-
-    // Convert Q to half2 and store in registers:
-    float2 Q_h2[ncols][D/(2*WARP_SIZE)];
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-
-            Q_h2[j][i0/WARP_SIZE]    = Q_f2[j*(nb01/sizeof(float2)) + i];
-            Q_h2[j][i0/WARP_SIZE].x *= scale;
-            Q_h2[j][i0/WARP_SIZE].y *= scale;
-        }
-    }
-
-    float VKQ[ncols] = {0.0f};
-
-    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
-        // Calculate KQ tile and keep track of new maximum KQ values:
-
-        float kqmax_new_arr[ncols];
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            kqmax_new_arr[j] = kqmax[j];
-        }
-
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
-            const int i_KQ = i_KQ_0 + threadIdx.y;
-
-            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
-                break;
-            }
-
-            float sum[ncols] = {0.0f};
-#pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
-                const int k_KQ = k_KQ_0 + threadIdx.x;
-
-                const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
-#pragma unroll
-                for (int j = 0; j < ncols; ++j) {
-                    sum[j] += __low2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].x;
-                    sum[j] += __high2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].y;
-                }
-            }
-
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-                sum[j] = warp_reduce_sum(sum[j]);
-                sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
-
-                kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]);
-
-                if (threadIdx.x == 0) {
-                    KQ[j*D + i_KQ] = sum[j];
-                }
-            }
-        }
-
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            float kqmax_new_j = kqmax_new_arr[j];
-
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
-            if (threadIdx.x == 0) {
-                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
-            }
-        }
-
-        __syncthreads();
-
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            float kqmax_new_j = kqmax_shared[j][threadIdx.x];
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
-
-            const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
-            kqmax[j] = kqmax_new_j;
-
-            const float val = expf(KQ[j*D + tid] - kqmax[j]);
-            kqsum[j] = kqsum[j]*KQ_max_scale + val;
-            KQ[j*D + tid] = val;
-
-            VKQ[j] *= KQ_max_scale;
-        }
-
-        __syncthreads();
-
-#pragma unroll
-        for (int k = 0; k < D; ++k) {
-            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
-                break;
-            }
-
-            const float V_ki = __half2float(V_h[(k_VKQ_0 + k)*stride_KV + tid]);
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-                VKQ[j] += V_ki*KQ[j*D + k];
-            }
-        }
-
-        __syncthreads();
-    }
-
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        kqsum[j] = warp_reduce_sum(kqsum[j]);
-        if (threadIdx.x == 0) {
-            kqsum_shared[j][threadIdx.y] = kqsum[j];
-        }
-    }
-
-    __syncthreads();
-
-#pragma unroll
-    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
-        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
-        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
-
-        float dst_val = VKQ[j_VKQ];
-        if (parallel_blocks == 1) {
-            dst_val /= kqsum[j_VKQ];
-        }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
-    }
-
-    if (parallel_blocks != 1 && tid != 0) {
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
-        }
-    }
-}
-
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f32(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-            (const char *) Q->data,
-            (const char *) K->data,
-            (const char *) V->data,
-            mask ? ((const char *) mask->data) : nullptr,
-            parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-            scale, max_bias, m0, m1, n_head_log2,
-            Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            Q->nb[1], Q->nb[2], Q->nb[3],
-            K->nb[1], K->nb[2], K->nb[3],
-            KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-        );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
-}
-
-void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
-
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
-
-    if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
-        return;
-    }
-
-    if (Q->ne[1] == 2) {
-        constexpr int cols_per_block  = 2;
-        constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block  = 4;
-        constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block  = 8;
-        constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
-        return;
-    }
-
-    constexpr int cols_per_block  = 8;
-    constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
-}