ggerganov committed
Commit f1ea157 · Parent: 09e4a9b

metal : add quantized FA support (llama/10149)


* metal : add quantized FA (vec) support

ggml-ci

* metal : add quantized FA (non-vec) support

* metal : fix support check

ggml-ci

* metal : clean-up

* metal : clean-up (cont)

* metal : fix shared memory calc + reduce smem + comments

* metal : float-correctness

* metal : minor [no ci]
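
The mechanism behind these bullets is the same throughout: the flash-attention kernels are templated over the KV-cache storage type, so one kernel body is instantiated for F16 and for each supported quant format by binding a block type, the number of 16-element chunks per block (nl), and a dequantize function. A minimal C++ sketch of that dispatch pattern, using hypothetical toy types (illustrative only; the real kernels are Metal Shading Language and appear in the diff below):

    // toy analogue of the template-over-dequantizer scheme used by the new kernels
    using chunk16 = float[16]; // stands in for Metal's half4x4 (16 values)

    // hypothetical quant block: 32 values as int8 plus one scale => nl = 2 chunks
    struct toy_block_q8 { float d; signed char qs[32]; };

    // expand chunk il (0..nl-1) of one block into 16 floats
    void dequantize_toy_q8(const toy_block_q8 * b, short il, chunk16 & out) {
        for (int j = 0; j < 16; ++j) out[j] = b->d * b->qs[16*il + j];
    }

    // one kernel body, instantiated per (block type, nl, dequantizer)
    template<typename block_q, short nl, void (*deq)(const block_q *, short, chunk16 &)>
    float dot_row(const block_q * row, const float * q, int n) {
        float acc = 0.0f;
        chunk16 tmp;
        for (int i = 0; i < n/16; ++i) {       // chunk i lives in block i/nl, sub-chunk i%nl
            deq(row + i/nl, i%nl, tmp);        // dequantize 16 head elements at a time
            for (int j = 0; j < 16; ++j) acc += q[16*i + j]*tmp[j];
        }
        return acc;
    }

    // usage (illustrative): dot_row<toy_block_q8, 2, dequantize_toy_q8>(k_row, q, 128);

An F16 "block" with nl = 1 then makes the dequantizer a plain load, which is why the same body can serve the non-quantized path with no runtime dispatch.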

Files changed (2)
  1. ggml/src/ggml-metal.m +265 -37
  2. ggml/src/ggml-metal.metal +303 -155
ggml/src/ggml-metal.m CHANGED
@@ -255,9 +255,49 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-    //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-    //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@@ -710,9 +750,49 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
-    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
-    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
@@ -869,13 +949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[1]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[2]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 256) {
+            if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
             return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
@@ -2822,6 +2896,7 @@ static void ggml_metal_encode_node(
             GGML_ASSERT(ne11 % 32 == 0);
 
             GGML_ASSERT(src0->type == GGML_TYPE_F32);
+            GGML_ASSERT(src1->type == src2->type);
 
             GGML_ASSERT(ggml_are_same_shape (src1, src2));
 
@@ -2869,26 +2944,154 @@ static void ggml_metal_encode_node(
             bool use_vec_kernel = false;
 
             if (ne01 >= 4 || (ne00%128 != 0)) {
-                switch (ne00) {
-                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
-                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
-                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
-                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
-                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                    //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                switch (src1->type) {
+                    case GGML_TYPE_F16:
+                        {
+                            switch (ne00) {
+                                case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+                                case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+                                case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+                                case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                                default:
+                                    {
+                                        GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                        GGML_LOG_ERROR("add template specialization for this size\n");
+                                        GGML_ABORT("add template specialization for this size");
+                                    }
+                            }
+                        } break;
+                    case GGML_TYPE_Q4_0:
+                        {
+                            switch (ne00) {
+                                case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
+                                case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
+                                case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
+                                case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
+                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
+                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
+                                default:
+                                    {
+                                        GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                        GGML_LOG_ERROR("add template specialization for this size\n");
+                                        GGML_ABORT("add template specialization for this size");
+                                    }
+                            }
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            switch (ne00) {
+                                case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
+                                case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
+                                case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
+                                case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
+                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
+                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
+                                default:
+                                    {
+                                        GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                        GGML_LOG_ERROR("add template specialization for this size\n");
+                                        GGML_ABORT("add template specialization for this size");
+                                    }
+                            }
+                        } break;
+                    case GGML_TYPE_Q5_0:
+                        {
+                            switch (ne00) {
+                                case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
+                                case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
+                                case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
+                                case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
+                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
+                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
+                                default:
+                                    {
+                                        GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                        GGML_LOG_ERROR("add template specialization for this size\n");
+                                        GGML_ABORT("add template specialization for this size");
+                                    }
+                            }
+                        } break;
+                    case GGML_TYPE_Q5_1:
+                        {
+                            switch (ne00) {
+                                case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
+                                case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
+                                case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
+                                case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
+                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
+                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
+                                default:
+                                    {
+                                        GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                        GGML_LOG_ERROR("add template specialization for this size\n");
+                                        GGML_ABORT("add template specialization for this size");
+                                    }
+                            }
+                        } break;
+                    case GGML_TYPE_Q8_0:
+                        {
+                            switch (ne00) {
+                                case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
+                                case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
+                                case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
+                                case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
+                                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
+                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
+                                default:
+                                    {
+                                        GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                        GGML_LOG_ERROR("add template specialization for this size\n");
+                                        GGML_ABORT("add template specialization for this size");
+                                    }
+                            }
+                        } break;
                     default:
-                        {
-                            GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                            GGML_LOG_ERROR("add template specialization for this size\n");
-                            GGML_ABORT("add template specialization for this size");
-                        }
+                        {
+                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                            GGML_LOG_ERROR("add template specialization for this type\n");
+                            GGML_ABORT("add template specialization for this type");
+                        }
                 }
             } else {
                 use_vec_kernel = true;
 
                 switch (ne00) {
-                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                    //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                    case 128:
+                        {
+                            switch (src1->type) {
+                                case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128 ].pipeline; break;
+                                case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
+                                case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
+                                case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
+                                case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
+                                case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
+                                default:
+                                    {
+                                        GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                        GGML_LOG_ERROR("add template specialization for this type\n");
+                                        GGML_ABORT("add template specialization for this type");
+                                    }
+                            }
+                        } break;
+                    case 256:
+                        {
+                            switch (src1->type) {
+                                case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256 ].pipeline; break;
+                                case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
+                                case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
+                                case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
+                                case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
+                                case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
+                                default:
+                                    {
+                                        GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                        GGML_LOG_ERROR("add template specialization for this type\n");
+                                        GGML_ABORT("add template specialization for this type");
+                                    }
+                            }
+                        } break;
                     default:
                         {
                             GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2942,10 +3145,16 @@ static void ggml_metal_encode_node(
                 GGML_ASSERT(nqptg % 8 == 0);
                 GGML_ASSERT(ncpsg % 32 == 0);
 
+                // 16*32*(nsg)
+                // the shared memory needed for the simdgroups to load the KV cache
+                // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
+                //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+
                 int64_t nsgmax = 2;
 
                 while (true) {
-                    const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+                    const size_t smem = FATTN_SMEM(nsgmax);
                     if (smem > device.maxThreadgroupMemoryLength) {
                         break;
                     }
@@ -2956,16 +3165,15 @@ static void ggml_metal_encode_node(
                 // simdgroups per threadgroup (a.k.a. warps)
                 const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
 
-                const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+                const size_t smem = FATTN_SMEM(nsg);
 
-                //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
                 GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-
-                [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+                [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
             } else {
-                // half1x4 kernel
+                // half4x4 kernel
                 const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
                 const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
 
@@ -2973,8 +3181,28 @@ static void ggml_metal_encode_node(
                 GGML_ASSERT(nqptg % 1 == 0);
                 GGML_ASSERT(ncpsg % 32 == 0);
 
+                // ne00 + 2*ncpsg*(nsg)
+                // for each query, we load it as f16 in shared memory (ne00)
+                // and store the attention scores (nqptg x ncpsg) as f32
+                //
+                // 2*ne00*(nsg)
+                // each simdgroup has a full f32 head vector in shared mem to accumulate results
+                //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+
+                int64_t nsgmax = 2;
+
+                while (true) {
+                    const size_t smem = FATTN_SMEM(nsgmax);
+                    if (smem > device.maxThreadgroupMemoryLength) {
+                        break;
+                    }
+                    nsgmax *= 2;
+                }
+                nsgmax /= 2;
+
                 // simdgroups per threadgroup (a.k.a. warps)
-                const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+                const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
 
                 int64_t nsg = 1;
                 while (nsg <= nsgt) {
@@ -2982,12 +3210,12 @@ static void ggml_metal_encode_node(
                 }
                 nsg /= 2;
 
-                const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+                const size_t smem = FATTN_SMEM(nsg);
 
-                //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
                 GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-                [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+                [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
             }
         } break;
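
The new FATTN_SMEM sizing is easy to sanity-check by hand. A standalone C++ sketch of the non-vec formula above (the values are illustrative: ne00 = 128, nqptg = 8, ncpsg = 32, nsg = 4; pad() stands in for GGML_PAD):

    #include <cstdio>
    #include <cstddef>

    // same rounding-up behavior as ggml's GGML_PAD(x, n)
    static size_t pad(size_t x, size_t n) { return ((x + n - 1)/n)*n; }

    int main() {
        const size_t ne00 = 128, nqptg = 8, ncpsg = 32, nsg = 4; // assumed example values
        // mirrors FATTN_SMEM(nsg) from the non-vec branch of the diff above
        const size_t smem = pad((nqptg*(ne00 + 2*(ncpsg + nqptg)*nsg) + 16*32*nsg)*(sizeof(float)/2), 16);
        printf("smem = %zu bytes\n", smem); // prints 11264
        return 0;
    }

At 11264 bytes this stays well under the 32 KiB maxThreadgroupMemoryLength commonly available on Apple GPUs, which is what the nsgmax probing loop in the diff relies on when it doubles nsgmax until the budget is exceeded.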
ggml/src/ggml-metal.metal CHANGED
@@ -2723,46 +2723,10 @@ kernel void kernel_leaky_relu_f32(
     dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
 }
 
-typedef void (flash_attn_ext_f16_t)(
-        device const char * q,
-        device const char * k,
-        device const char * v,
-        device const char * mask,
-        device float * dst,
-        constant int64_t & ne01,
-        constant int64_t & ne02,
-        constant int64_t & ne03,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant int64_t & ne11,
-        constant int64_t & ne12,
-        constant int64_t & ne13,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant uint64_t & nb21,
-        constant uint64_t & nb22,
-        constant uint64_t & nb23,
-        constant uint64_t & nb31,
-        constant int64_t & ne1,
-        constant int64_t & ne2,
-        constant float & scale,
-        constant float & max_bias,
-        constant float & m0,
-        constant float & m1,
-        constant uint32_t & n_head_log2,
-        constant float & logit_softcap,
-        threadgroup half * shared,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]);
-
 // ref: https://arxiv.org/pdf/2307.08691.pdf
-template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
-kernel void kernel_flash_attn_ext_f16(
+// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
+kernel void kernel_flash_attn_ext(
         device const char * q,
         device const char * k,
         device const char * v,
@@ -2800,15 +2764,15 @@ kernel void kernel_flash_attn_ext_f16(
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
 
-    const short iq3 = tgpig[2];
-    const short iq2 = tgpig[1];
-    const short iq1 = tgpig[0]*Q;
+    const int iq3 = tgpig[2];
+    const int iq2 = tgpig[1];
+    const int iq1 = tgpig[0]*Q;
 
-    const short D4 = D/4;
-    const short D8 = D/8;
-    //const short Q8 = Q/8;
-    const short NW = N_SIMDWIDTH;
-    const short SH = (C + Q); // shared memory per simdgroup in (half)
+    const short D4  = D/4;
+    const short D8  = D/8;
+    const short D16 = D/16;
+    const short NW  = N_SIMDWIDTH;
+    const short SH  = (C + Q); // shared memory per simdgroup in (half)
 
     const short T = D + 2*nsg*SH; // shared memory size per query in (half)
     const short TF = T/2; // shared memory size per query in (float)
@@ -2818,6 +2782,9 @@ kernel void kernel_flash_attn_ext_f16(
     threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
     threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
 
+    threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
+    threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
+
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
     simdgroup_half8x8 lo[D8];
 
@@ -2849,25 +2816,28 @@ kernel void kernel_flash_attn_ext_f16(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        float S[Q] = { [0 ... Q-1] = 0.0h };
+        float S[Q] = { [0 ... Q-1] = 0.0f };
        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
 
+        // thread indices inside the simdgroup
+        const short tx = tiisg%4;
+        const short ty = tiisg/4;
+
         // assume K and V are same shape
         const short ne22 = ne12;
         const short ne23 = ne13;
 
-        // broadcast
+        // broadcast k
         const short rk2 = ne02/ne12;
         const short rk3 = ne03/ne13;
 
-        const short rv2 = ne02/ne22;
-        const short rv3 = ne03/ne23;
-
-        // k indices
         const short ik2 = iq2/rk2;
         const short ik3 = iq3/rk3;
 
-        // v indices
+        // broadcast v
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
         const short iv2 = iq2/rv2;
         const short iv3 = iq3/rv3;
 
@@ -2906,13 +2876,59 @@ kernel void kernel_flash_attn_ext_f16(
                 for (short cc = 0; cc < C/8; ++cc) {
                     simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
 
-                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+                    // this is a compile-time check, so it does not have runtime overhead
+                    if (is_same<block_q, half4x4>::value) {
+                        // we can read directly from global memory
+                        device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
 
-                    for (short i = 0; i < D8; ++i) {
-                        simdgroup_half8x8 mk;
-                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+                        for (short i = 0; i < D8; ++i) {
+                            simdgroup_half8x8 mk;
+                            simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
 
-                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                            simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                        }
+                    } else {
+                        for (short ii = 0; ii < D16; ii += 4) {
+                            device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
+
+                            if (D16%4 == 0) {
+                                // the head is evenly divisible by 4*16 = 64, so no need for bound checks
+                                half4x4 tmp;
+                                dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                                skv4[4*ty + tx] = tmp;
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                #pragma unroll
+                                for (short k = 0; k < 4; ++k) {
+                                    simdgroup_half8x8 mk;
+
+                                    simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+                                    simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                }
+                            } else {
+                                if (ii + tx < D16) {
+                                    half4x4 tmp;
+                                    dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                                    skv4[4*ty + tx] = tmp;
+                                }
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                    simdgroup_half8x8 mk;
+
+                                    simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+                                    simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                }
+                            }
+                        }
                     }
 
                     simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
@@ -2977,16 +2993,61 @@ kernel void kernel_flash_attn_ext_f16(
                 // O = O + (Q*K^T)*V
                 {
                     for (short cc = 0; cc < C/8; ++cc) {
-                        device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+                        simdgroup_float8x8 ms;
+                        simdgroup_load(ms, ss + 8*cc, TF, 0, false);
+
+                        if (is_same<block_q, half4x4>::value) {
+                            // we can read directly from global memory
+                            device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
 
-                        for (short i = 0; i < D8; ++i) {
-                            simdgroup_half8x8 mk;
-                            simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
-
-                            simdgroup_float8x8 mv;
-                            simdgroup_load(mv, ss + 8*cc, TF, 0, false);
-
-                            simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+                            #pragma unroll
+                            for (short i = 0; i < D8; ++i) {
+                                simdgroup_half8x8 mv;
+                                simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
+
+                                simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
+                            }
+                        } else {
+                            for (short ii = 0; ii < D16; ii += 4) {
+                                device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+
+                                if (D16%4 == 0) {
+                                    // no need for bound checks
+                                    half4x4 tmp;
+                                    dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                                    skv4[4*ty + tx] = tmp;
+
+                                    simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                    #pragma unroll
+                                    for (short k = 0; k < 4; ++k) {
+                                        simdgroup_half8x8 mv;
+
+                                        simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
+                                        simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+                                        simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
+                                        simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                    }
+                                } else {
+                                    if (ii + tx < D16) {
+                                        half4x4 tmp;
+                                        dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
+                                        skv4[4*ty + tx] = tmp;
+                                    }
+
+                                    simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                    for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                        simdgroup_half8x8 mv;
+
+                                        simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
+                                        simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+                                        simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
+                                        simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                    }
+                                }
+                            }
+                        }
                     }
                 }
             }
@@ -3003,7 +3064,7 @@ kernel void kernel_flash_attn_ext_f16(
 
     // reduce the warps sequentially
     for (short sg = 1; sg < nsg; ++sg) {
-        float S = { 0.0h };
+        float S = { 0.0f };
         float M = { -FLT_MAX/2 };
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3082,15 +3143,54 @@ kernel void kernel_flash_attn_ext_f16(
         }
     }
 }
 
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
-//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
-
-template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec_f16(
+typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
+
+// NOTE: can use half instead of float precision for some extra perf
+// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
+kernel void kernel_flash_attn_ext_vec(
         device const char * q,
         device const char * k,
         device const char * v,
@@ -3128,36 +3228,27 @@ kernel void kernel_flash_attn_ext_vec_f16(
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
 
-    const short iq3 = tgpig[2];
-    const short iq2 = tgpig[1];
-    const short iq1 = tgpig[0];
+    const int iq3 = tgpig[2];
+    const int iq2 = tgpig[1];
+    const int iq1 = tgpig[0];
 
-    const short D4 = D/4;
-    const short NW = N_SIMDWIDTH;
-    const short SH = (C + Q); // shared memory per simdgroup in (half)
+    const short D4  = D/4;
+    const short D16 = D/16;
+    const short NW  = N_SIMDWIDTH;
+    const short NW4 = NW/4;
+    const short SH  = C; // shared memory per simdgroup in (half)
 
     const short T = D + 2*nsg*SH; // shared memory size per query in (half)
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = iq2;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = pow(base, exp);
-    }
-
-    //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
-    threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
-    threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
-    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in float4
-    threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
+    //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+    threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
+    threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
+    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in float4
+    threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    half4 lo[D4/NW];
+    float4x4 lo[D16/NW4];
 
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@@ -3171,8 +3262,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
     }
 
     // zero out lo
-    for (short i = tiisg; i < D4; i += NW) {
-        lo[i/NW] = 0.0h;
+    for (short i = 0; i < D16/NW4; i += NW4) {
+        lo[i] = float4x4(0.0f);
     }
 
     // zero out shared memory SH
@@ -3183,38 +3274,52 @@ kernel void kernel_flash_attn_ext_vec_f16(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        float S = { 0.0h };
-        float M = { -FLT_MAX/2 };
+        float S = 0.0f;
+        float M = -FLT_MAX/2;
+
+        // thread indices inside the simdgroup
+        const short tx = tiisg%8;
+        const short ty = tiisg/8;
 
         // assume K and V are same shape
         const short ne22 = ne12;
         const short ne23 = ne13;
 
-        // broadcast
+        // broadcast k
         const short rk2 = ne02/ne12;
         const short rk3 = ne03/ne13;
 
+        const short ik2 = iq2/rk2;
+        const short ik3 = iq3/rk3;
+
+        // broadcast v
         const short rv2 = ne02/ne22;
        const short rv3 = ne03/ne23;
 
-        // k indices
-        const short ik2 = iq2 / rk2;
-        const short ik3 = iq3 / rk3;
-
-        // v indices
-        const short iv2 = iq2 / rv2;
-        const short iv3 = iq3 / rv3;
+        const short iv2 = iq2/rv2;
+        const short iv3 = iq3/rv3;
 
         // load the queries from shared memory into local memory
-        float4 mq[D4/NW];
+        float4x4 mq[D16/NW4];
 
-        for (short ii = 0; ii < D4; ii += NW) {
-            short i = ii + tiisg;
-            mq[ii/NW] = (float4) sq4[i];
+        for (short ii = 0; ii < D16; ii += NW4) {
+            mq[ii/NW4] = (float4x4) sq44[ii + tx];
        }
 
         // pointer to the mask
-        device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        float slope = 1.0f;
+
+        // ALiBi
+        if (max_bias > 0.0f) {
+            const uint32_t h = iq2;
+
+            const float base = h < n_head_log2 ? m0 : m1;
+            const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+            slope = pow(base, exp);
+        }
 
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
@@ -3226,47 +3331,54 @@ kernel void kernel_flash_attn_ext_vec_f16(
 
             // Q*K^T
             {
-                #pragma unroll
+                // each simdgroup processes 1 query and 4 keys
                 for (short cc = 0; cc < C/4; ++cc) {
-                    float4 mqk = { 0.0h };
+                    float mqk = 0.0;
 
-                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+                    device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
 
                     #pragma unroll
-                    for (short ii = 0; ii < D4; ii += NW) {
-                        const short i = ii + tiisg;
+                    for (short ii = 0; ii < D16; ii += NW4) {
+                        const short i = ii + tx;
 
                         float4x4 mk;
-                        mk[0] = (float4) pk4[i + 0*(nb11/8)];
-                        mk[1] = (float4) pk4[i + 1*(nb11/8)];
-                        mk[2] = (float4) pk4[i + 2*(nb11/8)];
-                        mk[3] = (float4) pk4[i + 3*(nb11/8)];
+                        dequantize_func(pk + i/nl, i%nl, mk);
 
-                        mqk += (float4) (mq[ii/NW] * mk);
+                        mqk +=
+                            dot(mq[ii/NW4][0], mk[0]) +
+                            dot(mq[ii/NW4][1], mk[1]) +
+                            dot(mq[ii/NW4][2], mk[2]) +
+                            dot(mq[ii/NW4][3], mk[3]);
                    }
 
-                    // reduce the results from the threads in the simdgroup
-                    mqk += simd_shuffle_down(mqk, 16);
-                    mqk += simd_shuffle_down(mqk, 8);
+                    // simdgroup reduce
+                    // [ 0 ..  7] -> [ 0]
+                    // [ 8 .. 15] -> [ 8]
+                    // [16 .. 23] -> [16]
+                    // [24 .. 31] -> [24]
+                    //mqk += simd_shuffle_down(mqk, 16);
+                    //mqk += simd_shuffle_down(mqk, 8);
                    mqk += simd_shuffle_down(mqk, 4);
                     mqk += simd_shuffle_down(mqk, 2);
                     mqk += simd_shuffle_down(mqk, 1);
 
                     // mqk = mqk*scale + mask*slope
-                    if (tiisg == 0) {
+                    if (tx == 0) {
                         mqk *= scale;
 
                         if (logit_softcap != 0.0f) {
                             mqk = logit_softcap*precise::tanh(mqk);
                         }
 
-                        mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
+                        mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
 
-                        ss4[cc] = mqk;
+                        ss[4*cc + ty] = mqk;
                     }
                 }
             }
 
+            simdgroup_barrier(mem_flags::mem_threadgroup);
+
             // online softmax
             {
                 const short p = tiisg;
@@ -3286,29 +3398,32 @@ kernel void kernel_flash_attn_ext_vec_f16(
 
                 // O = diag(ms)*O
                 #pragma unroll
-                for (short ii = 0; ii < D4; ii += NW) {
-                    lo[ii/NW] *= ms;
+                for (short ii = 0; ii < D16; ii += NW4) {
+                    lo[ii/NW4] *= ms;
                 }
             }
 
+            simdgroup_barrier(mem_flags::mem_threadgroup);
+
             // O = O + (Q*K^T)*V
             {
                 #pragma unroll
                 for (short cc = 0; cc < C/4; ++cc) {
-                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+                    device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    const float4x4 lss(ss[4*cc + ty]);
 
                     #pragma unroll
-                    for (short ii = 0; ii < D4; ii += NW) {
-                        const short i = ii + tiisg;
+                    for (short ii = 0; ii < D16; ii += NW4) {
+                        const short i = ii + tx;
+
+                        float4x4 mv;
+                        dequantize_func(pv4 + i/nl, i%nl, mv);
 
-                        lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
-                        lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
-                        lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
-                        lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+                        lo[ii/NW4] += mv*lss;
                     }
                 }
             }
-
         }
 
        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
@@ -3318,10 +3433,32 @@ kernel void kernel_flash_attn_ext_vec_f16(
         }
     }
 
+    // simdgroup reduce
+    // [ 0,  8, 16, 24] -> [ 0]
+    // [ 1,  9, 17, 25] -> [ 1]
+    // [ 2, 10, 18, 26] -> [ 2]
+    // [ 3, 11, 19, 27] -> [ 3]
+    // [ 4, 12, 20, 28] -> [ 4]
+    // [ 5, 13, 21, 29] -> [ 5]
+    // [ 6, 14, 22, 30] -> [ 6]
+    // [ 7, 15, 23, 31] -> [ 7]
+    for (short ii = 0; ii < D16; ii += NW4) {
+        lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16);
+        lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 8);
+
+        lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16);
+        lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 8);
+
+        lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16);
+        lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 8);
+
+        lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16);
+        lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 8);
+    }
+
     // store results to shared memory
-    for (short ii = 0; ii < D4; ii += NW) {
-        short i = ii + tiisg;
-        sr4[i] = lo[ii/NW];
+    for (short i = tiisg; i < D16; i += NW4) {
+        sr44[i] = lo[i/NW4];
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3348,30 +3485,41 @@ kernel void kernel_flash_attn_ext_vec_f16(
         }
 
         // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-        for (short ii = 0; ii < D4; ii += NW) {
-            short i = ii + tiisg;
-            sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+        for (short i = tiisg; i < D16; i += NW) {
+            sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
        }
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    device float4 * dst4 = (device float4 *) dst;
+    device float4x4 * dst44 = (device float4x4 *) dst;
 
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
         const float S = ss[0];
 
-        for (short ii = 0; ii < D4; ii += NW) {
-            short i = ii + tiisg;
-            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+        for (short i = tiisg; i < D16; i += NW) {
+            dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
        }
     }
 }
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
-//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+typedef decltype(kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
 
 template<typename T0, typename T1>
 kernel void kernel_cpy(
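
One detail worth noting in the instantiation lists above: typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) names the shared function type once, so every [[host_name]] alias is declared against the same signature. A self-contained sketch of the same decltype trick in plain C++ (not Metal; all names here are illustrative):

    #include <cstdio>

    template<typename T, int D>
    void kernel_impl(const T * x) { printf("head size D = %d, x = %g\n", D, (double) *x); }

    // derive the common function type from one instantiation
    typedef decltype(kernel_impl<float, 64>) kernel_t; // i.e. void(const float *)

    // other instantiations can then be stored and type-checked against that one type
    kernel_t * const table[] = { kernel_impl<float, 64>, kernel_impl<float, 128> };

    int main() {
        const float x = 1.0f;
        for (kernel_t * const f : table) f(&x);
        return 0;
    }

The payoff is the same in both settings: if an instantiation's signature ever drifts from the canonical one, the declaration fails to compile instead of failing at pipeline-creation time.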