ggerganov commited on
Commit
847669b
·
1 Parent(s): c4d59b9

metal : add BF16 support (llama/8439)

Browse files

* ggml : add initial BF16 support

ggml-ci

* metal : add mul_mat_id BF16 support

ggml-ci

* metal : check for bfloat support on the Metal device

ggml-ci

* metal : better var names [no ci]

* metal : do not build bfloat kernels when not supported

ggml-ci

* metal : try to fix BF16 support check

ggml-ci

* metal : this should correctly check bfloat support

Files changed (2) hide show
  1. ggml/src/ggml-metal.m +259 -179
  2. ggml/src/ggml-metal.metal +54 -4
ggml/src/ggml-metal.m CHANGED
@@ -36,16 +36,18 @@ static struct ggml_backend_metal_device_context {
36
  id<MTLDevice> mtl_device;
37
  int mtl_device_ref_count;
38
 
39
- bool support_simdgroup_reduction;
40
- bool support_simdgroup_mm;
 
41
 
42
  char name[128];
43
  } g_ggml_ctx_dev_main = {
44
- /*.mtl_device =*/ nil,
45
- /*.mtl_device_ref_count =*/ 0,
46
- /*.support_simdgroup_reduction =*/ false,
47
- /*.support_simdgroup_mm =*/ false,
48
- /*.name =*/ "",
 
49
  };
50
 
51
  // acquire
@@ -55,10 +57,13 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
55
  if (ctx->mtl_device == nil) {
56
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
57
 
58
- ctx->support_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
59
- ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
60
 
61
- ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
 
 
 
62
 
63
  strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
64
  }
@@ -120,6 +125,7 @@ enum ggml_metal_kernel_type {
120
  GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
121
  GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
122
  GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
 
123
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
124
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
125
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
@@ -146,10 +152,14 @@ enum ggml_metal_kernel_type {
146
  GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
147
  GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
148
  GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
149
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
150
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
151
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
152
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
 
 
 
 
 
153
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
154
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
155
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
@@ -170,10 +180,11 @@ enum ggml_metal_kernel_type {
170
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
171
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
172
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
173
- //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
174
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
175
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
176
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
 
 
177
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
178
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
179
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
@@ -195,6 +206,7 @@ enum ggml_metal_kernel_type {
195
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
196
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
197
  GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
 
198
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
199
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
200
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
@@ -216,6 +228,7 @@ enum ggml_metal_kernel_type {
216
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
217
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
218
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
 
219
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
220
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
221
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
@@ -300,8 +313,11 @@ enum ggml_metal_kernel_type {
300
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
301
  GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
302
  GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
 
303
  GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
304
  GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
 
 
305
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
306
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
307
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -480,7 +496,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
480
  // dictionary of preprocessor macros
481
  NSMutableDictionary * prep = [NSMutableDictionary dictionary];
482
 
483
- MTLCompileOptions* options = [MTLCompileOptions new];
484
  options.preprocessorMacros = prep;
485
 
486
  //[options setFastMathEnabled:false];
@@ -530,9 +546,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
530
  }
531
  }
532
 
533
- GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
534
- GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
535
- GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
 
536
 
537
  ctx->capture_next_compute = false;
538
  ctx->capture_started = false;
@@ -578,8 +595,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
578
  GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
579
  }
580
 
581
- const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
582
- const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
 
583
 
584
  // simd_sum and simd_max requires MTLGPUFamilyApple7
585
 
@@ -607,14 +625,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
607
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
608
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
609
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
610
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, support_simdgroup_reduction);
611
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, support_simdgroup_reduction);
612
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, support_simdgroup_reduction);
613
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, support_simdgroup_reduction);
614
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
615
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
616
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
617
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
 
618
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
619
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
620
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
@@ -635,101 +654,108 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
636
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
637
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
638
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, support_simdgroup_reduction);
639
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, support_simdgroup_reduction);
640
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
641
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
642
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
643
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
644
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
645
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
646
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
647
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
648
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
649
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
650
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);
651
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, support_simdgroup_reduction);
652
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, support_simdgroup_reduction);
653
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, support_simdgroup_reduction);
654
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, support_simdgroup_reduction);
655
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, support_simdgroup_reduction);
656
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, support_simdgroup_reduction);
657
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, support_simdgroup_reduction);
658
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, support_simdgroup_reduction);
659
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, support_simdgroup_reduction);
660
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, support_simdgroup_reduction);
661
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, support_simdgroup_reduction);
662
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, support_simdgroup_reduction);
663
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, support_simdgroup_reduction);
664
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, support_simdgroup_reduction);
665
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction);
666
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction);
667
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction);
668
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
669
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction);
670
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
671
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
672
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
673
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
674
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
675
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, support_simdgroup_reduction);
676
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, support_simdgroup_reduction);
677
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, support_simdgroup_reduction);
678
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, support_simdgroup_reduction);
679
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, support_simdgroup_reduction);
680
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, support_simdgroup_reduction);
681
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, support_simdgroup_reduction);
682
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, support_simdgroup_reduction);
683
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, support_simdgroup_reduction);
684
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, support_simdgroup_reduction);
685
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, support_simdgroup_reduction);
686
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, support_simdgroup_reduction);
687
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, support_simdgroup_reduction);
688
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, support_simdgroup_reduction);
689
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, support_simdgroup_reduction);
690
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
691
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
692
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
693
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
694
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
695
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
696
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, support_simdgroup_mm);
697
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, support_simdgroup_mm);
698
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, support_simdgroup_mm);
699
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, support_simdgroup_mm);
700
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, support_simdgroup_mm);
701
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, support_simdgroup_mm);
702
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, support_simdgroup_mm);
703
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, support_simdgroup_mm);
704
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, support_simdgroup_mm);
705
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, support_simdgroup_mm);
706
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, support_simdgroup_mm);
707
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, support_simdgroup_mm);
708
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, support_simdgroup_mm);
709
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, support_simdgroup_mm);
710
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, support_simdgroup_mm);
711
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
712
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
713
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
714
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
715
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
716
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
717
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, support_simdgroup_mm);
718
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, support_simdgroup_mm);
719
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, support_simdgroup_mm);
720
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, support_simdgroup_mm);
721
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, support_simdgroup_mm);
722
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, support_simdgroup_mm);
723
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, support_simdgroup_mm);
724
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, support_simdgroup_mm);
725
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, support_simdgroup_mm);
726
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, support_simdgroup_mm);
727
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, support_simdgroup_mm);
728
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, support_simdgroup_mm);
729
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, support_simdgroup_mm);
730
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, support_simdgroup_mm);
731
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, support_simdgroup_mm);
732
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, support_simdgroup_mm);
 
 
 
 
 
 
 
733
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
734
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
735
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
@@ -745,58 +771,61 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
745
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
746
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
747
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
748
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, support_simdgroup_mm);
749
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, support_simdgroup_mm);
750
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
751
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
752
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
753
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
754
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, support_simdgroup_mm);
755
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, support_simdgroup_mm);
756
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, support_simdgroup_mm);
757
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, support_simdgroup_mm);
758
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, support_simdgroup_mm);
759
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, support_simdgroup_mm);
760
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, support_simdgroup_mm);
761
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, support_simdgroup_mm);
762
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, support_simdgroup_mm);
763
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, support_simdgroup_mm);
764
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, support_simdgroup_mm);
765
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, support_simdgroup_mm);
766
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, support_simdgroup_mm);
767
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, support_simdgroup_mm);
768
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, support_simdgroup_mm);
769
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, support_simdgroup_mm);
770
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, support_simdgroup_mm);
771
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, support_simdgroup_mm);
772
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, support_simdgroup_mm);
773
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, support_simdgroup_mm);
774
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, support_simdgroup_mm);
775
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, support_simdgroup_mm);
776
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, support_simdgroup_mm);
777
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, support_simdgroup_mm);
778
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, support_simdgroup_mm);
779
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, support_simdgroup_mm);
780
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, support_simdgroup_mm);
781
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, support_simdgroup_mm);
782
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_mm);
783
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, support_simdgroup_mm);
784
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
785
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
786
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
787
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction);
788
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction);
789
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction);
790
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
791
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction);
792
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction);
793
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
794
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
795
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
796
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
797
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
798
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
 
799
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
 
 
 
800
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
801
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
802
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -886,15 +915,18 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
886
  }
887
 
888
  static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
889
- for (size_t i = 0, n = 3; i < n; ++i) {
890
- if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
891
- return false;
 
 
 
 
 
 
892
  }
893
  }
894
 
895
- const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
896
- const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
897
-
898
  switch (op->op) {
899
  case GGML_OP_UNARY:
900
  switch (ggml_get_unary_op(op)) {
@@ -932,7 +964,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
932
  case GGML_OP_SOFT_MAX:
933
  case GGML_OP_RMS_NORM:
934
  case GGML_OP_GROUP_NORM:
935
- return support_simdgroup_reduction;
936
  case GGML_OP_NORM:
937
  case GGML_OP_ROPE:
938
  return true;
@@ -952,13 +984,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
952
  if (op->src[1]->type != op->src[2]->type) {
953
  return false;
954
  }
955
- return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
956
  case GGML_OP_SSM_CONV:
957
  case GGML_OP_SSM_SCAN:
958
  return true;
959
  case GGML_OP_MUL_MAT:
960
  case GGML_OP_MUL_MAT_ID:
961
- return support_simdgroup_reduction &&
962
  (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
963
  case GGML_OP_CPY:
964
  case GGML_OP_DUP:
@@ -969,6 +1001,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
969
  switch (op->type) {
970
  case GGML_TYPE_F32:
971
  case GGML_TYPE_F16:
 
972
  case GGML_TYPE_Q8_0:
973
  case GGML_TYPE_Q4_0:
974
  case GGML_TYPE_Q4_1:
@@ -981,10 +1014,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
981
  }
982
  case GGML_TYPE_F16:
983
  switch (op->type) {
984
- case GGML_TYPE_F32:
985
- case GGML_TYPE_F16:
986
  return true;
987
- default:
 
 
 
 
 
 
 
 
988
  return false;
989
  }
990
  default:
@@ -1855,6 +1896,7 @@ static void ggml_metal_encode_node(
1855
  switch (src0->type) {
1856
  case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1857
  case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
 
1858
  default: break;
1859
  }
1860
 
@@ -1863,6 +1905,7 @@ static void ggml_metal_encode_node(
1863
  switch (src0->type) {
1864
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1865
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
 
1866
  case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1867
  case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1868
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
@@ -1940,6 +1983,25 @@ static void ggml_metal_encode_node(
1940
  nrows = 4;
1941
  }
1942
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1943
  case GGML_TYPE_Q4_0:
1944
  {
1945
  nth0 = 8;
@@ -2158,12 +2220,12 @@ static void ggml_metal_encode_node(
2158
  if ([device supportsFamily:MTLGPUFamilyApple7] &&
2159
  ne00 % 32 == 0 && ne00 >= 64 &&
2160
  dst_rows > dst_rows_min) {
2161
-
2162
  // some Metal matrix data types require aligned pointers
2163
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
2164
  switch (src0->type) {
2165
- case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
2166
- case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
 
2167
  default: break;
2168
  }
2169
 
@@ -2172,6 +2234,7 @@ static void ggml_metal_encode_node(
2172
  switch (src0->type) {
2173
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
2174
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
 
2175
  case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
2176
  case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
2177
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
@@ -2241,6 +2304,13 @@ static void ggml_metal_encode_node(
2241
  nth1 = 1;
2242
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
2243
  } break;
 
 
 
 
 
 
 
2244
  case GGML_TYPE_Q4_0:
2245
  {
2246
  nth0 = 8;
@@ -2438,6 +2508,7 @@ static void ggml_metal_encode_node(
2438
  switch (src0->type) {
2439
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
2440
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
 
2441
  case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
2442
  case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
2443
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
@@ -3237,6 +3308,7 @@ static void ggml_metal_encode_node(
3237
  switch (dstt) {
3238
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
3239
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
 
3240
  case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
3241
  case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
3242
  case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -3254,6 +3326,14 @@ static void ggml_metal_encode_node(
3254
  default: GGML_ABORT("not implemented");
3255
  };
3256
  } break;
 
 
 
 
 
 
 
 
3257
  default: GGML_ABORT("not implemented");
3258
  }
3259
 
 
36
  id<MTLDevice> mtl_device;
37
  int mtl_device_ref_count;
38
 
39
+ bool has_simdgroup_reduction;
40
+ bool has_simdgroup_mm;
41
+ bool has_bfloat;
42
 
43
  char name[128];
44
  } g_ggml_ctx_dev_main = {
45
+ /*.mtl_device =*/ nil,
46
+ /*.mtl_device_ref_count =*/ 0,
47
+ /*.has_simdgroup_reduction =*/ false,
48
+ /*.has_simdgroup_mm =*/ false,
49
+ /*.has_bfloat =*/ false,
50
+ /*.name =*/ "",
51
  };
52
 
53
  // acquire
 
57
  if (ctx->mtl_device == nil) {
58
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
59
 
60
+ ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
61
+ ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
62
 
63
+ ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
64
+
65
+ ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
66
+ ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
67
 
68
  strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
69
  }
 
125
  GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
126
  GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
127
  GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
128
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
129
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
130
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
131
  GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
 
152
  GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
153
  GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
154
  GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
 
155
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
156
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
157
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
158
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
159
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
160
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
161
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
162
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
163
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
164
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
165
  GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
 
180
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
181
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
182
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
 
183
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
184
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
185
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
186
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
187
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
188
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
189
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
190
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
 
206
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
207
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
208
  GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
209
+ GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
210
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
211
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
212
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
 
228
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
229
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
230
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
231
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
232
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
233
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
234
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
 
313
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
314
  GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
315
  GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
316
+ GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
317
  GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
318
  GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
319
+ GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
320
+ GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
321
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
322
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
323
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
 
496
  // dictionary of preprocessor macros
497
  NSMutableDictionary * prep = [NSMutableDictionary dictionary];
498
 
499
+ MTLCompileOptions * options = [MTLCompileOptions new];
500
  options.preprocessorMacros = prep;
501
 
502
  //[options setFastMathEnabled:false];
 
546
  }
547
  }
548
 
549
+ GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
550
+ GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
551
+ GGML_LOG_INFO("%s: bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
552
+ GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
553
 
554
  ctx->capture_next_compute = false;
555
  ctx->capture_started = false;
 
595
  GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
596
  }
597
 
598
+ const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
599
+ const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
600
+ const bool has_bfloat = ctx_dev->has_bfloat;
601
 
602
  // simd_sum and simd_max requires MTLGPUFamilyApple7
603
 
 
625
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
626
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
627
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
628
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
629
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
630
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
631
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
632
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
633
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
634
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
636
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, has_bfloat);
637
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
638
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
639
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
 
654
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
655
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
656
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
657
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
658
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
659
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
660
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
661
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
662
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
663
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && has_bfloat);
664
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && has_bfloat);
665
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && has_bfloat);
666
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && has_bfloat);
667
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
668
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
669
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
670
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
671
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
672
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
673
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
674
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
675
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
676
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
677
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
678
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
679
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
680
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
681
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
682
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
683
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
684
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
685
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
686
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
687
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
688
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
689
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
690
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
691
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
692
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
693
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
694
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
695
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && has_bfloat);
696
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
697
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
698
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
699
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
700
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
701
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
702
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
703
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
704
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
705
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
706
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
707
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
708
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
709
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
710
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
711
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
712
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
713
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
714
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
715
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
716
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
717
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && has_bfloat);
718
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
719
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
720
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
721
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
722
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
723
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
724
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
725
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
726
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
727
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
728
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
729
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
730
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
731
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
732
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
733
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
734
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
735
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
736
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
737
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
738
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
739
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && has_bfloat);
740
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
741
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
742
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
743
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
744
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
745
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
746
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
747
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
748
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
749
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
750
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
751
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
752
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
753
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
754
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
755
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
756
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
757
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
758
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
759
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
760
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
761
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
 
771
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
772
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
773
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
774
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
775
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
776
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
777
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
778
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
779
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
780
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
781
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
782
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
783
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
784
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
785
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
786
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
787
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
788
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
789
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
790
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
791
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
792
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
793
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
794
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
795
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
796
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
797
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
798
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
799
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
800
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
801
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
802
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
803
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
804
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
805
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
806
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
807
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
808
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
809
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
810
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
811
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
812
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
813
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
814
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
815
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
816
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
817
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
818
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
819
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
820
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
821
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
 
822
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
823
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
824
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, has_bfloat);
825
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
826
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
827
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, has_bfloat);
828
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, has_bfloat);
829
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
830
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
831
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
 
915
  }
916
 
917
  static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
918
+ const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
919
+ const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
920
+ const bool has_bfloat = ctx_dev->has_bfloat;
921
+
922
+ if (!has_bfloat) {
923
+ for (size_t i = 0, n = 3; i < n; ++i) {
924
+ if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
925
+ return false;
926
+ }
927
  }
928
  }
929
 
 
 
 
930
  switch (op->op) {
931
  case GGML_OP_UNARY:
932
  switch (ggml_get_unary_op(op)) {
 
964
  case GGML_OP_SOFT_MAX:
965
  case GGML_OP_RMS_NORM:
966
  case GGML_OP_GROUP_NORM:
967
+ return has_simdgroup_reduction;
968
  case GGML_OP_NORM:
969
  case GGML_OP_ROPE:
970
  return true;
 
984
  if (op->src[1]->type != op->src[2]->type) {
985
  return false;
986
  }
987
+ return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
988
  case GGML_OP_SSM_CONV:
989
  case GGML_OP_SSM_SCAN:
990
  return true;
991
  case GGML_OP_MUL_MAT:
992
  case GGML_OP_MUL_MAT_ID:
993
+ return has_simdgroup_reduction &&
994
  (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
995
  case GGML_OP_CPY:
996
  case GGML_OP_DUP:
 
1001
  switch (op->type) {
1002
  case GGML_TYPE_F32:
1003
  case GGML_TYPE_F16:
1004
+ case GGML_TYPE_BF16:
1005
  case GGML_TYPE_Q8_0:
1006
  case GGML_TYPE_Q4_0:
1007
  case GGML_TYPE_Q4_1:
 
1014
  }
1015
  case GGML_TYPE_F16:
1016
  switch (op->type) {
1017
+ case GGML_TYPE_F32:
1018
+ case GGML_TYPE_F16:
1019
  return true;
1020
+ default:
1021
+ return false;
1022
+ }
1023
+ case GGML_TYPE_BF16:
1024
+ switch (op->type) {
1025
+ case GGML_TYPE_F32:
1026
+ case GGML_TYPE_BF16:
1027
+ return true;
1028
+ default:
1029
  return false;
1030
  }
1031
  default:
 
1896
  switch (src0->type) {
1897
  case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1898
  case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1899
+ case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1900
  default: break;
1901
  }
1902
 
 
1905
  switch (src0->type) {
1906
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1907
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1908
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
1909
  case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1910
  case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1911
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
 
1983
  nrows = 4;
1984
  }
1985
  } break;
1986
+ case GGML_TYPE_BF16:
1987
+ {
1988
+ nth0 = 32;
1989
+ nth1 = 1;
1990
+ if (src1t == GGML_TYPE_F32) {
1991
+ if (ne11 * ne12 < 4) {
1992
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
1993
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1994
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
1995
+ nrows = ne11;
1996
+ } else {
1997
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
1998
+ nrows = 4;
1999
+ }
2000
+ } else {
2001
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
2002
+ nrows = 4;
2003
+ }
2004
+ } break;
2005
  case GGML_TYPE_Q4_0:
2006
  {
2007
  nth0 = 8;
 
2220
  if ([device supportsFamily:MTLGPUFamilyApple7] &&
2221
  ne00 % 32 == 0 && ne00 >= 64 &&
2222
  dst_rows > dst_rows_min) {
 
2223
  // some Metal matrix data types require aligned pointers
2224
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
2225
  switch (src0->type) {
2226
+ case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
2227
+ case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
2228
+ case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
2229
  default: break;
2230
  }
2231
 
 
2234
  switch (src0->type) {
2235
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
2236
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
2237
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
2238
  case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
2239
  case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
2240
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
 
2304
  nth1 = 1;
2305
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
2306
  } break;
2307
+ case GGML_TYPE_BF16:
2308
+ {
2309
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
2310
+ nth0 = 32;
2311
+ nth1 = 1;
2312
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
2313
+ } break;
2314
  case GGML_TYPE_Q4_0:
2315
  {
2316
  nth0 = 8;
 
2508
  switch (src0->type) {
2509
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
2510
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
2511
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
2512
  case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
2513
  case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
2514
  case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
 
3308
  switch (dstt) {
3309
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
3310
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
3311
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
3312
  case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
3313
  case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
3314
  case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
 
3326
  default: GGML_ABORT("not implemented");
3327
  };
3328
  } break;
3329
+ case GGML_TYPE_BF16:
3330
+ {
3331
+ switch (dstt) {
3332
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
3333
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
3334
+ default: GGML_ASSERT(false && "not implemented");
3335
+ };
3336
+ } break;
3337
  default: GGML_ABORT("not implemented");
3338
  }
3339
 
ggml/src/ggml-metal.metal CHANGED
@@ -12,6 +12,20 @@ using namespace metal;
12
 
13
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  constexpr constant static float kvalues_iq4nl_f[16] = {
16
  -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
17
  };
@@ -27,6 +41,13 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
27
  reg = (type4x4)(*src);
28
  }
29
 
 
 
 
 
 
 
 
30
  template <typename type4x4>
31
  void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
32
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
@@ -2041,6 +2062,10 @@ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
2041
  template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
2042
  template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
2043
  template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
 
 
 
 
2044
 
2045
  template<typename T, typename T4>
2046
  kernel void kernel_mul_mv_1row(
@@ -2110,6 +2135,9 @@ kernel void kernel_mul_mv_1row(
2110
  typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
2111
 
2112
  template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
 
 
 
2113
 
2114
  // Assumes row size (ne00) is a multiple of 4
2115
  template<typename T, typename T4>
@@ -2169,6 +2197,9 @@ kernel void kernel_mul_mv_l4(
2169
  typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
2170
 
2171
  template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
 
 
 
2172
 
2173
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
2174
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
@@ -3565,10 +3596,17 @@ kernel void kernel_cpy(
3565
 
3566
  typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
3567
 
3568
- template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
3569
- template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
3570
- template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
3571
- template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
 
 
 
 
 
 
 
3572
 
3573
  kernel void kernel_cpy_f32_q8_0(
3574
  device const float * src0,
@@ -6473,6 +6511,9 @@ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
6473
 
6474
  template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
6475
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
 
 
 
6476
 
6477
  typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
6478
 
@@ -6504,6 +6545,9 @@ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, de
6504
 
6505
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
6506
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
 
 
 
6507
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
6508
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
6509
  template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
@@ -6532,6 +6576,9 @@ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
6532
 
6533
  template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
6534
  template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
 
 
 
6535
  template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
6536
  template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
6537
  template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
@@ -6755,6 +6802,9 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
6755
 
6756
  template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
6757
  template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
 
 
 
6758
  template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
6759
  template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6760
  template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
 
12
 
13
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
14
 
15
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
16
+ //
17
+ // cmd:
18
+ // .../usr/bin/metal -dM -E -c ggml/src/ggml-metal.metal
19
+ // .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal.metal
20
+ //
21
+ #if __METAL_VERSION__ < 310
22
+ #define GGML_METAL_NO_BFLOAT
23
+ #endif
24
+
25
+ #if !defined(GGML_METAL_NO_BFLOAT)
26
+ typedef matrix<bfloat, 4, 4> bfloat4x4;
27
+ #endif
28
+
29
  constexpr constant static float kvalues_iq4nl_f[16] = {
30
  -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
31
  };
 
41
  reg = (type4x4)(*src);
42
  }
43
 
44
+ #if !defined(GGML_METAL_NO_BFLOAT)
45
+ template <typename type4x4>
46
+ void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
47
+ reg = (type4x4)(*src);
48
+ }
49
+ #endif
50
+
51
  template <typename type4x4>
52
  void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
53
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
 
2062
  template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
2063
  template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
2064
  template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
2065
+ #if !defined(GGML_METAL_NO_BFLOAT)
2066
+ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
2067
+ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
2068
+ #endif
2069
 
2070
  template<typename T, typename T4>
2071
  kernel void kernel_mul_mv_1row(
 
2135
  typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
2136
 
2137
  template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
2138
+ #if !defined(GGML_METAL_NO_BFLOAT)
2139
+ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
2140
+ #endif
2141
 
2142
  // Assumes row size (ne00) is a multiple of 4
2143
  template<typename T, typename T4>
 
2197
  typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
2198
 
2199
  template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
2200
+ #if !defined(GGML_METAL_NO_BFLOAT)
2201
+ template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
2202
+ #endif
2203
 
2204
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
2205
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
 
3596
 
3597
  typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
3598
 
3599
+ template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
3600
+ template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
3601
+ #if !defined(GGML_METAL_NO_BFLOAT)
3602
+ template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
3603
+ #endif
3604
+ template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
3605
+ template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
3606
+ #if !defined(GGML_METAL_NO_BFLOAT)
3607
+ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
3608
+ template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
3609
+ #endif
3610
 
3611
  kernel void kernel_cpy_f32_q8_0(
3612
  device const float * src0,
 
6511
 
6512
  template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
6513
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
6514
+ #if !defined(GGML_METAL_NO_BFLOAT)
6515
+ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
6516
+ #endif
6517
 
6518
  typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
6519
 
 
6545
 
6546
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
6547
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
6548
+ #if !defined(GGML_METAL_NO_BFLOAT)
6549
+ template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
6550
+ #endif
6551
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
6552
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
6553
  template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
 
6576
 
6577
  template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
6578
  template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
6579
+ #if !defined(GGML_METAL_NO_BFLOAT)
6580
+ template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
6581
+ #endif
6582
  template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
6583
  template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
6584
  template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
 
6802
 
6803
  template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
6804
  template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
6805
+ #if !defined(GGML_METAL_NO_BFLOAT)
6806
+ template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
6807
+ #endif
6808
  template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
6809
  template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6810
  template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;