ggerganov commited on
Commit
02d7878
·
unverified ·
1 Parent(s): 714ee6b

metal : add F32 support + update bench output

Browse files
Files changed (7) hide show
  1. Makefile +1 -0
  2. extra/bench-all.sh +7 -3
  3. ggml-metal.m +12 -0
  4. ggml-metal.metal +81 -7
  5. ggml.c +8 -0
  6. ggml.h +1 -0
  7. whisper.cpp +1 -0
Makefile CHANGED
@@ -186,6 +186,7 @@ ifndef WHISPER_NO_METAL
186
  ifeq ($(UNAME_S),Darwin)
187
  WHISPER_METAL := 1
188
 
 
189
  CXXFLAGS += -DGGML_USE_METAL
190
  LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
191
  endif
 
186
  ifeq ($(UNAME_S),Darwin)
187
  WHISPER_METAL := 1
188
 
189
+ CFLAGS += -DGGML_USE_METAL
190
  CXXFLAGS += -DGGML_USE_METAL
191
  LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
192
  endif
extra/bench-all.sh CHANGED
@@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
44
  printf "\n"
45
  fi
46
 
47
- printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
48
- printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
49
 
50
  for model in "${models[@]}"; do
51
  # actual run
@@ -83,9 +83,13 @@ for model in "${models[@]}"; do
83
  config="$config COREML"
84
  fi
85
 
 
 
 
 
86
  commit=$(git rev-parse --short HEAD)
87
 
88
  if [ $ret -eq 0 ]; then
89
- printf "| <todo> | <todo> | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
90
  fi
91
  done
 
44
  printf "\n"
45
  fi
46
 
47
+ printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
48
+ printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
49
 
50
  for model in "${models[@]}"; do
51
  # actual run
 
83
  config="$config COREML"
84
  fi
85
 
86
+ if [[ $system_info == *"METAL = 1"* ]]; then
87
+ config="$config METAL"
88
+ fi
89
+
90
  commit=$(git rev-parse --short HEAD)
91
 
92
  if [ $ret -eq 0 ]; then
93
+ printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
94
  fi
95
  done
ggml-metal.m CHANGED
@@ -78,6 +78,7 @@ struct ggml_metal_context {
78
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
79
  GGML_METAL_DECL_KERNEL(rms_norm);
80
  GGML_METAL_DECL_KERNEL(norm);
 
81
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
82
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
83
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
@@ -89,6 +90,7 @@ struct ggml_metal_context {
89
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
90
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
91
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
 
92
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
93
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
94
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
@@ -237,6 +239,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
237
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
238
  GGML_METAL_ADD_KERNEL(rms_norm);
239
  GGML_METAL_ADD_KERNEL(norm);
 
240
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
241
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
242
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
@@ -248,6 +251,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
248
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
249
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
250
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
 
251
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
252
  GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
253
  GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
@@ -309,6 +313,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
309
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
310
  GGML_METAL_DEL_KERNEL(rms_norm);
311
  GGML_METAL_DEL_KERNEL(norm);
 
312
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
313
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
314
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
@@ -320,6 +325,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
320
  GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
321
  GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
322
  GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
 
323
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
324
  GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
325
  GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
@@ -885,6 +891,7 @@ void ggml_metal_graph_compute(
885
  ne00%32 == 0 &&
886
  ne11 > 1) {
887
  switch (src0->type) {
 
888
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
889
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
890
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
@@ -919,6 +926,11 @@ void ggml_metal_graph_compute(
919
 
920
  // use custom matrix x vector kernel
921
  switch (src0t) {
 
 
 
 
 
922
  case GGML_TYPE_F16:
923
  {
924
  nth0 = 32;
 
78
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
79
  GGML_METAL_DECL_KERNEL(rms_norm);
80
  GGML_METAL_DECL_KERNEL(norm);
81
+ GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
82
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
83
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
84
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
 
90
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
91
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
92
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
93
+ GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
94
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
95
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
96
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
 
239
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
240
  GGML_METAL_ADD_KERNEL(rms_norm);
241
  GGML_METAL_ADD_KERNEL(norm);
242
+ GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
243
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
244
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
245
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
 
251
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
252
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
253
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
254
+ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
255
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
256
  GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
257
  GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
 
313
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
314
  GGML_METAL_DEL_KERNEL(rms_norm);
315
  GGML_METAL_DEL_KERNEL(norm);
316
+ GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
317
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
318
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
319
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
 
325
  GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
326
  GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
327
  GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
328
+ GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
329
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
330
  GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
331
  GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
 
891
  ne00%32 == 0 &&
892
  ne11 > 1) {
893
  switch (src0->type) {
894
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
895
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
896
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
897
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
 
926
 
927
  // use custom matrix x vector kernel
928
  switch (src0t) {
929
+ case GGML_TYPE_F32:
930
+ {
931
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
932
+ nrows = 4;
933
+ } break;
934
  case GGML_TYPE_F16:
935
  {
936
  nth0 = 32;
ggml-metal.metal CHANGED
@@ -523,6 +523,79 @@ kernel void kernel_mul_mat_q8_0_f32(
523
  }
524
  }
525
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  kernel void kernel_mul_mat_f16_f32_1row(
527
  device const char * src0,
528
  device const char * src1,
@@ -1399,13 +1472,13 @@ kernel void kernel_mul_mat_q4_K_f32(
1399
  device const float * src1,
1400
  device float * dst,
1401
  constant int64_t & ne00,
1402
- constant int64_t & ne01[[buffer(4)]],
1403
- constant int64_t & ne02[[buffer(5)]],
1404
- constant int64_t & ne10[[buffer(9)]],
1405
- constant int64_t & ne12[[buffer(11)]],
1406
- constant int64_t & ne0[[buffer(15)]],
1407
- constant int64_t & ne1[[buffer(16)]],
1408
- constant uint & gqa[[buffer(17)]],
1409
  uint3 tgpig[[threadgroup_position_in_grid]],
1410
  uint tiisg[[thread_index_in_simdgroup]],
1411
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2268,6 +2341,7 @@ typedef void (mat_mm_t)(
2268
  constant uint & gqa,
2269
  threadgroup uchar *, uint3, uint, uint);
2270
 
 
2271
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2272
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2273
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
 
523
  }
524
  }
525
 
526
+ #define N_F32_F32 4
527
+
528
+ kernel void kernel_mul_mat_f32_f32(
529
+ device const char * src0,
530
+ device const char * src1,
531
+ device float * dst,
532
+ constant int64_t & ne00,
533
+ constant int64_t & ne01,
534
+ constant int64_t & ne02,
535
+ constant uint64_t & nb00,
536
+ constant uint64_t & nb01,
537
+ constant uint64_t & nb02,
538
+ constant int64_t & ne10,
539
+ constant int64_t & ne11,
540
+ constant int64_t & ne12,
541
+ constant uint64_t & nb10,
542
+ constant uint64_t & nb11,
543
+ constant uint64_t & nb12,
544
+ constant int64_t & ne0,
545
+ constant int64_t & ne1,
546
+ uint3 tgpig[[threadgroup_position_in_grid]],
547
+ uint tiisg[[thread_index_in_simdgroup]]) {
548
+
549
+ const int64_t r0 = tgpig.x;
550
+ const int64_t rb = tgpig.y*N_F32_F32;
551
+ const int64_t im = tgpig.z;
552
+
553
+ device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
554
+
555
+ if (ne00 < 128) {
556
+ for (int row = 0; row < N_F32_F32; ++row) {
557
+ int r1 = rb + row;
558
+ if (r1 >= ne11) {
559
+ break;
560
+ }
561
+
562
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
563
+
564
+ float sumf = 0;
565
+ for (int i = tiisg; i < ne00; i += 32) {
566
+ sumf += (float) x[i] * (float) y[i];
567
+ }
568
+
569
+ float all_sum = simd_sum(sumf);
570
+ if (tiisg == 0) {
571
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
572
+ }
573
+ }
574
+ } else {
575
+ device const float4 * x4 = (device const float4 *)x;
576
+ for (int row = 0; row < N_F32_F32; ++row) {
577
+ int r1 = rb + row;
578
+ if (r1 >= ne11) {
579
+ break;
580
+ }
581
+
582
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
583
+ device const float4 * y4 = (device const float4 *) y;
584
+
585
+ float sumf = 0;
586
+ for (int i = tiisg; i < ne00/4; i += 32) {
587
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
588
+ }
589
+
590
+ float all_sum = simd_sum(sumf);
591
+ if (tiisg == 0) {
592
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
593
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
594
+ }
595
+ }
596
+ }
597
+ }
598
+
599
  kernel void kernel_mul_mat_f16_f32_1row(
600
  device const char * src0,
601
  device const char * src1,
 
1472
  device const float * src1,
1473
  device float * dst,
1474
  constant int64_t & ne00,
1475
+ constant int64_t & ne01 [[buffer(4)]],
1476
+ constant int64_t & ne02 [[buffer(5)]],
1477
+ constant int64_t & ne10 [[buffer(9)]],
1478
+ constant int64_t & ne12 [[buffer(11)]],
1479
+ constant int64_t & ne0 [[buffer(15)]],
1480
+ constant int64_t & ne1 [[buffer(16)]],
1481
+ constant uint & gqa [[buffer(17)]],
1482
  uint3 tgpig[[threadgroup_position_in_grid]],
1483
  uint tiisg[[thread_index_in_simdgroup]],
1484
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
2341
  constant uint & gqa,
2342
  threadgroup uchar *, uint3, uint, uint);
2343
 
2344
+ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
2345
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2346
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2347
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
ggml.c CHANGED
@@ -20753,6 +20753,14 @@ int ggml_cpu_has_arm_fma(void) {
20753
  #endif
20754
  }
20755
 
 
 
 
 
 
 
 
 
20756
  int ggml_cpu_has_f16c(void) {
20757
  #if defined(__F16C__)
20758
  return 1;
 
20753
  #endif
20754
  }
20755
 
20756
+ int ggml_cpu_has_metal(void) {
20757
+ #if defined(GGML_USE_METAL)
20758
+ return 1;
20759
+ #else
20760
+ return 0;
20761
+ #endif
20762
+ }
20763
+
20764
  int ggml_cpu_has_f16c(void) {
20765
  #if defined(__F16C__)
20766
  return 1;
ggml.h CHANGED
@@ -1961,6 +1961,7 @@ extern "C" {
1961
  GGML_API int ggml_cpu_has_fma (void);
1962
  GGML_API int ggml_cpu_has_neon (void);
1963
  GGML_API int ggml_cpu_has_arm_fma (void);
 
1964
  GGML_API int ggml_cpu_has_f16c (void);
1965
  GGML_API int ggml_cpu_has_fp16_va (void);
1966
  GGML_API int ggml_cpu_has_wasm_simd (void);
 
1961
  GGML_API int ggml_cpu_has_fma (void);
1962
  GGML_API int ggml_cpu_has_neon (void);
1963
  GGML_API int ggml_cpu_has_arm_fma (void);
1964
+ GGML_API int ggml_cpu_has_metal (void);
1965
  GGML_API int ggml_cpu_has_f16c (void);
1966
  GGML_API int ggml_cpu_has_fp16_va (void);
1967
  GGML_API int ggml_cpu_has_wasm_simd (void);
whisper.cpp CHANGED
@@ -3669,6 +3669,7 @@ const char * whisper_print_system_info(void) {
3669
  s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
3670
  s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
3671
  s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
 
3672
  s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
3673
  s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
3674
  s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
 
3669
  s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
3670
  s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
3671
  s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
3672
+ s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | ";
3673
  s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
3674
  s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
3675
  s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";