ggerganov, slaren committed
Commit b6adf19 · 1 Parent(s): 6e1b44c

ggml : add metal backend registry / device (llama/9713)


* ggml : add metal backend registry / device

ggml-ci

* metal : fix names [no ci]

* metal : global registry and device instances

ggml-ci

* cont : alternative initialization of global objects

ggml-ci

* llama : adapt to backend changes

ggml-ci

* fixes

* metal : fix indent

* metal : fix build when MTLGPUFamilyApple3 is not available

ggml-ci

* fix merge

* metal : avoid unnecessary singleton accesses

ggml-ci

* metal : minor fix [no ci]

* metal : g_state -> g_ggml_ctx_dev_main [no ci]

* metal : avoid reference of device context in the backend context

ggml-ci

* metal : minor [no ci]

* metal : fix maxTransferRate check

* metal : remove transfer rate stuff

---------

Co-authored-by: slaren <[email protected]>
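
The practical effect of this commit is that the Metal backend can be discovered and instantiated through the generic backend registry instead of Metal-specific entry points. A minimal consumer-side sketch (not from this diff), assuming the registry enumeration API of ggml-backend.h from this period (ggml_backend_dev_count / ggml_backend_dev_get / ggml_backend_dev_name / ggml_backend_dev_init):

    // sketch: enumerate all registered devices and create a backend for each
    #include "ggml-backend.h"
    #include <stdio.h>

    int main(void) {
        for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
            printf("device %zu: %s\n", i, ggml_backend_dev_name(dev));

            // initialize a backend instance from the device (no init params here)
            ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);
            if (backend != NULL) {
                ggml_backend_free(backend);
            }
        }
        return 0;
    }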

ggml/include/ggml-backend.h CHANGED
@@ -127,6 +127,8 @@ extern "C" {
         bool async;
         // pinned host buffer
         bool host_buffer;
+        // creating buffers from host ptr
+        bool buffer_from_host_ptr;
         // event synchronization
         bool events;
     };
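
The new buffer_from_host_ptr capability lets callers ask up front whether a device can wrap caller-owned host memory (for example an mmap-ed model file) instead of allocating a fresh buffer and copying. A short sketch using only ggml_backend_dev_get_props, which is part of this diff:

    // sketch: query the new capability before relying on it
    struct ggml_backend_dev_props props;
    ggml_backend_dev_get_props(dev, &props);

    if (props.caps.buffer_from_host_ptr) {
        // the device can create buffers that alias existing host memory
    }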
ggml/include/ggml-metal.h CHANGED
@@ -43,7 +43,9 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);

 GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);

-GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
+GGML_DEPRECATED(
+        GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
+        "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");

 GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);

@@ -57,6 +59,8 @@ GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int fam
 // capture all command buffers committed the next time `ggml_backend_graph_compute` is called
 GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);

+GGML_API ggml_backend_reg_t ggml_backend_metal_reg(void);
+
 #ifdef __cplusplus
 }
 #endif
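
Code that used the now-deprecated call can migrate to the device interface. A hypothetical sketch, assuming the generic ggml_backend_dev_buffer_from_host_ptr entry point of the new API (it is not shown in this diff):

    // before (deprecated, Metal-specific):
    //   ggml_backend_buffer_t buf = ggml_backend_metal_buffer_from_ptr(data, size, max_size);

    // after: go through a device obtained from the registry
    // (index 0 standing in for the Metal device in this sketch)
    ggml_backend_dev_t dev = ggml_backend_dev_get(0);
    ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, data, size, max_size);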
ggml/src/ggml-backend.cpp CHANGED
@@ -463,6 +463,7 @@ enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
 }

 void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) {
+    memset(props, 0, sizeof(*props));
     device->iface.get_props(device, props);
 }

@@ -479,6 +480,10 @@ ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t devic
 }

 ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {
+    if (device->iface.get_host_buffer_type == NULL) {
+        return NULL;
+    }
+
     return device->iface.get_host_buffer_type(device);
 }

@@ -525,6 +530,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
 #include "ggml-cuda.h"
 #endif

+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
 struct ggml_backend_registry {
     std::vector<ggml_backend_reg_t> backends;
     std::vector<ggml_backend_dev_t> devices;

@@ -533,10 +542,13 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_CUDA
         register_backend(ggml_backend_cuda_reg());
 #endif
+#ifdef GGML_USE_METAL
+        register_backend(ggml_backend_metal_reg());
+#endif

         register_backend(ggml_backend_cpu_reg());

-        // TODO: sycl, metal, vulkan, kompute, cann
+        // TODO: sycl, vulkan, kompute, cann
     }

     void register_backend(ggml_backend_reg_t reg) {

@@ -1118,9 +1130,10 @@ static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggm
     props->type = ggml_backend_cpu_device_get_type(dev);
     ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
     props->caps = {
-        /* async */ false,
-        /* host_buffer */ false,
-        /* events */ false,
+        /* .async = */ false,
+        /* .host_buffer = */ false,
+        /* .buffer_from_host_ptr = */ true,
+        /* .events = */ false,
     };
 }
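
Zeroing props before dispatching to the backend is what makes the new caps field safe to add: a backend whose get_props implementation predates buffer_from_host_ptr simply leaves it at false. A sketch of such a backend, with hypothetical names:

    // sketch: an older get_props that does not know about the new field
    static void my_dev_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
        props->name       = "my-dev"; // hypothetical backend
        props->caps.async = true;     // only the fields it knows about
        // props->caps.buffer_from_host_ptr is left untouched - the memset in
        // ggml_backend_dev_get_props guarantees it reads as false
    }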
 
ggml/src/ggml-cuda.cu CHANGED
@@ -2920,9 +2920,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
 #endif

     props->caps = {
-        /* async */ true,
-        /* host_buffer */ host_buffer,
-        /* events */ events,
+        /* .async = */ true,
+        /* .host_buffer = */ host_buffer,
+        /* .buffer_from_host_ptr = */ false,
+        /* .events = */ events,
     };
 }
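
The initializers above are positional, so the comments mirror the declaration order of the caps struct, which after this commit has four fields (the surrounding struct name is assumed to be ggml_backend_dev_caps; the members are taken from the ggml-backend.h hunk above):

    struct ggml_backend_dev_caps {
        bool async;                // asynchronous operations
        bool host_buffer;          // pinned host buffer
        bool buffer_from_host_ptr; // creating buffers from host ptr (new in this commit)
        bool events;               // event synchronization
    };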
 
ggml/src/ggml-metal.m CHANGED
@@ -20,6 +20,69 @@

 #define UNUSED(x) (void)(x)

+// globals
+
+// overload of MTLGPUFamilyMetal3 (not available in some environments)
+static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
+
+// initialized in ggml_backend_metal_reg
+static struct ggml_backend_reg    g_ggml_backend_metal_reg;
+static struct ggml_backend_device g_ggml_backend_metal_device;
+
+// information about a Metal device
+// note: assumes single GPU device - the default one
+// TODO: support multiple GPU devices
+static struct ggml_backend_metal_device_context {
+    id<MTLDevice> mtl_device;
+    int mtl_device_ref_count;
+
+    bool support_simdgroup_reduction;
+    bool support_simdgroup_mm;
+
+    char name[128];
+} g_ggml_ctx_dev_main = {
+    /*.mtl_device                  =*/ nil,
+    /*.mtl_device_ref_count        =*/ 0,
+    /*.support_simdgroup_reduction =*/ false,
+    /*.support_simdgroup_mm        =*/ false,
+    /*.name                        =*/ "",
+};
+
+// acquire
+static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
+    assert(ctx != NULL);
+
+    if (ctx->mtl_device == nil) {
+        ctx->mtl_device = MTLCreateSystemDefaultDevice();
+
+        ctx->support_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+        ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+
+        ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+
+        strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
+    }
+
+    ctx->mtl_device_ref_count++;
+
+    return ctx->mtl_device;
+}
+
+// release
+static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
+    assert(ctx != NULL);
+    assert(ctx->mtl_device_ref_count > 0);
+
+    ctx->mtl_device_ref_count--;
+
+    if (ctx->mtl_device_ref_count == 0) {
+        [ctx->mtl_device release];
+        ctx->mtl_device = nil;
+    }
+}
+
+// kernels
+
 struct ggml_metal_kernel {
     id<MTLComputePipelineState> pipeline;
 };
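
The device context above is a process-global, reference-counted handle to the default MTLDevice: it is created on first use and released only when the last reference is dropped. Sketch of the intended pairing, using the names from the hunk above:

    // on init: take a reference to the shared device
    id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);

    // ... use the device ...

    // on free: drop the reference; the MTLDevice itself is released
    // only when mtl_device_ref_count reaches zero
    ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);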
@@ -214,16 +277,12 @@ enum ggml_metal_kernel_type {
 };

 struct ggml_backend_metal_context {
-    id<MTLDevice> device;
     id<MTLCommandQueue> queue;

     dispatch_queue_t d_queue;

     struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];

-    bool support_simdgroup_reduction;
-    bool support_simdgroup_mm;
-
     // capture state
     bool capture_next_compute;
     bool capture_started;

@@ -280,7 +339,7 @@ static void * ggml_metal_host_malloc(size_t n) {
     return data;
 }

-static struct ggml_backend_metal_context * ggml_metal_init(void) {
+static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
     GGML_LOG_INFO("%s: allocating\n", __func__);

 #if TARGET_OS_OSX && !GGML_METAL_NDEBUG

@@ -292,14 +351,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
     [devices release]; // since it was created by a *Copy* C method
 #endif

-    // Pick and show default Metal device
-    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+    // init context
+    struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
+    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
+
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

-    // Configure context
-    struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
-    ctx->device = device;
-    ctx->queue = [ctx->device newCommandQueue];
+    ctx->queue = [device newCommandQueue];
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

     id<MTLLibrary> metal_library;

@@ -332,7 +391,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
     NSURL * libURL = [NSURL fileURLWithPath:path_lib];
     GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);

-    metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
+    metal_library = [device newLibraryWithURL:libURL error:&error];
     if (error) {
         GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
         return NULL;

@@ -382,7 +441,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {

     //[options setFastMathEnabled:false];

-    metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
+    metal_library = [device newLibraryWithSource:src options:options error:&error];
     if (error) {
         GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
         return NULL;

@@ -392,44 +451,37 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
     }

     // print MTL GPU family:
-    GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
-
-    const NSInteger MTLGPUFamilyMetal3 = 5001;
+    GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[device name] UTF8String]);

     // determine max supported GPU family
     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
     // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
     {
         for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
-            if ([ctx->device supportsFamily:i]) {
+            if ([device supportsFamily:i]) {
                 GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
                 break;
             }
         }

         for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
-            if ([ctx->device supportsFamily:i]) {
+            if ([device supportsFamily:i]) {
                 GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
                 break;
             }
         }

-        for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
-            if ([ctx->device supportsFamily:i]) {
-                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+        for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {
+            if ([device supportsFamily:i]) {
+                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);
                 break;
             }
         }
     }

-    ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
-    ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
-
-    ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
-
-    GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
-    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
-    GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
+    GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");

     ctx->capture_next_compute = false;
     ctx->capture_started = false;

@@ -443,13 +495,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {

 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
-        GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
-    }
-#elif TARGET_OS_OSX
-    if (ctx->device.maxTransferRate != 0) {
-        GGML_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
-    } else {
-        GGML_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+        GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
     }
 #endif

@@ -470,7 +516,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         if (supported) { \
             struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
             id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
-            kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
+            kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
            [metal_function release]; \
             if (error) { \
                 GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \

@@ -481,6 +527,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
             GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
         }

+    const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
+    const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+
     // simd_sum and simd_max requires MTLGPUFamilyApple7

     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);

@@ -507,10 +556,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);

@@ -535,101 +584,101 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
-    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
-    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
-    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction);
+    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction);
+    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
+    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
@@ -643,14 +692,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
-    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
-    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
+    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
+    //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
@@ -684,7 +733,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     Block_release(ctx->encode_async);

     [ctx->queue release];
-    [ctx->device release];

     dispatch_release(ctx->d_queue);

@@ -742,13 +790,16 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
     return nil;
 }

-static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
     for (size_t i = 0, n = 3; i < n; ++i) {
         if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
             return false;
         }
     }

     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {

@@ -786,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
-            return ctx->support_simdgroup_reduction;
         case GGML_OP_NORM:
         case GGML_OP_ROPE:
             return true;

@@ -812,13 +863,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
-            return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
-            return ctx->support_simdgroup_reduction &&
                 (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
         case GGML_OP_CPY:
         case GGML_OP_DUP:

@@ -862,9 +913,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
 }

 static void ggml_metal_encode_node(
-        struct ggml_backend_metal_context * ctx,
         int idx,
         id<MTLComputeCommandEncoder> encoder) {

     struct ggml_cgraph * gf = ctx->gf;

     struct ggml_tensor * node = ggml_graph_node(gf, idx);

@@ -894,7 +948,7 @@ static void ggml_metal_encode_node(
         } break;
     }

-    if (!ggml_metal_supports_op(ctx, dst)) {
         GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
         GGML_ABORT("unsupported op");
     }

@@ -967,6 +1021,8 @@ static void ggml_metal_encode_node(
     //        dst->name);
     //}

     switch (dst->op) {
         case GGML_OP_CONCAT:
             {

@@ -1675,7 +1731,7 @@ static void ggml_metal_encode_node(
                 // the numbers below are measured on M2 Ultra for 7B and 13B models
                 // these numbers do not translate to other devices or model sizes
                 // TODO: need to find a better approach
-                if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
                     switch (src0t) {
                         case GGML_TYPE_F16:  ne11_mm_min = 2; break;
                         case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;

@@ -1695,7 +1751,7 @@ static void ggml_metal_encode_node(

                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                     !ggml_is_transposed(src0) &&
                     !ggml_is_transposed(src1) &&
                     src1t == GGML_TYPE_F32 &&

@@ -1990,7 +2046,7 @@ static void ggml_metal_encode_node(
                 // ne21 = n_rows
                 const int dst_rows = ne20*ne21;
                 const int dst_rows_min = n_as;
-                const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;

                 // max size of the rowids array in the kernel shared buffer
                 GGML_ASSERT(dst_rows <= dst_rows_max);

@@ -2001,7 +2057,7 @@ static void ggml_metal_encode_node(
                 // TODO: for now, always use mat-vec kernels until we figure out how to improve the
                 // indirect matrix multiplication
                 // !!!
-                if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                     ne00 % 32 == 0 && ne00 >= 64 &&
                     dst_rows > dst_rows_min) {

@@ -2840,7 +2896,7 @@ static void ggml_metal_encode_node(

                 while (true) {
                     const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
-                    if (smem > ctx->device.maxThreadgroupMemoryLength) {
                         break;
                     }
                     nsgmax *= 2;

@@ -2852,8 +2908,8 @@ static void ggml_metal_encode_node(

                 const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);

-                //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-                GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);

                 [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

@@ -2878,8 +2934,8 @@ static void ggml_metal_encode_node(

                 const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);

-                //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-                GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
                 [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];

@@ -2954,8 +3010,11 @@ static void ggml_metal_encode_node(
 }

 static enum ggml_status ggml_metal_graph_compute(
-        struct ggml_backend_metal_context * ctx,
-        struct ggml_cgraph * gf) {
     // number of nodes encoded by the main thread (empirically determined)
     const int n_main = 128;

@@ -2983,7 +3042,7 @@ static enum ggml_status ggml_metal_graph_compute(

     if (!ctx->capture_started) {
         // create capture scope
-        ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];

         MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
         descriptor.captureObject = ctx->capture_scope;
@@ -3087,31 +3146,6 @@ static enum ggml_status ggml_metal_graph_compute(

 // backend interface

-// default buffer
-static id<MTLDevice> g_backend_device = nil;
-static int g_backend_device_ref_count = 0;
-
-static id<MTLDevice> ggml_backend_metal_get_device(void) {
-    if (g_backend_device == nil) {
-        g_backend_device = MTLCreateSystemDefaultDevice();
-    }
-
-    g_backend_device_ref_count++;
-
-    return g_backend_device;
-}
-
-static void ggml_backend_metal_free_device(void) {
-    assert(g_backend_device_ref_count > 0);
-
-    g_backend_device_ref_count--;
-
-    if (g_backend_device_ref_count == 0) {
-        [g_backend_device release];
-        g_backend_device = nil;
-    }
-}
-
 static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
     return "Metal";
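
The deleted block above is the old file-local device singleton that only the buffer code used. Its role is taken over by the reference-counted device context introduced at the top of the file; the buffer helpers below now acquire and release the shared device through that context instead (the replacement lines are outside this view). Roughly:

    // old (deleted above)              -> new (device context)
    // ggml_backend_metal_get_device()  -> ggml_backend_metal_device_acq(...)
    // ggml_backend_metal_free_device() -> ggml_backend_metal_device_rel(...)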
 
@@ -3124,7 +3158,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
     for (int i = 0; i < ctx->n_buffers; i++) {
         [ctx->buffers[i].metal release];
     }
-    ggml_backend_metal_free_device();

     if (ctx->owned) {
 #if TARGET_OS_OSX

@@ -3227,7 +3261,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
         size_aligned += (size_page - (size_aligned % size_page));
     }

-    id<MTLDevice> device = ggml_backend_metal_get_device();

     ctx->all_data = ggml_metal_host_malloc(size_aligned);
     ctx->all_size = size_aligned;

@@ -3241,16 +3275,16 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba

         if (size_aligned > 0) {
             ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
-                        length:size_aligned
-                        options:MTLResourceStorageModeShared
-                        deallocator:nil];
         }
     }

     if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
         GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
-        ggml_backend_metal_free_device();
         return NULL;
     }

@@ -3265,9 +3299,9 @@ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_t
 }

 static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    id<MTLDevice> device = ggml_backend_metal_get_device();
-    size_t max_size = device.maxBufferLength;
-    ggml_backend_metal_free_device();

     return max_size;

@@ -3290,15 +3324,14 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
             /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
             /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
         },
-        /* .device = */ NULL,
         /* .context = */ NULL,
     };

     return &ggml_backend_buffer_type_metal;
 }

-// buffer from ptr
-
 ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
     struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));

@@ -3321,7 +3354,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
         size_aligned += (size_page - (size_aligned % size_page));
     }

-    id<MTLDevice> device = ggml_backend_metal_get_device();

     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {

@@ -3386,8 +3419,12 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
 }

 static void ggml_backend_metal_free(ggml_backend_t backend) {
-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;

     ggml_metal_free(ctx);

     free(backend);
 }

@@ -3398,21 +3435,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggm
 }

 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    return ggml_metal_graph_compute(metal_ctx, cgraph);
-}
-
-static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    return ggml_metal_supports_op(metal_ctx, op);
-}
-
-static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
-
-    UNUSED(backend);
 }

 static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {

@@ -3459,7 +3482,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
             [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
         }

-        ggml_metal_encode_node(ctx, idx, encoder);

         if (should_capture) {
             [encoder popDebugGroup];

@@ -3487,8 +3510,8 @@ static struct ggml_backend_i ggml_backend_metal_i = {
     /* .graph_plan_update = */ NULL,
     /* .graph_plan_compute = */ NULL,
     /* .graph_compute = */ ggml_backend_metal_graph_compute,
-    /* .supports_op = */ ggml_backend_metal_supports_op,
-    /* .supports_buft = */ ggml_backend_metal_supports_buft,
     /* .offload_op = */ NULL,
     /* .event_record = */ NULL,
     /* .event_wait = */ NULL,

@@ -3499,8 +3522,11 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
     return &guid;
 }

 ggml_backend_t ggml_backend_metal_init(void) {
-    struct ggml_backend_metal_context * ctx = ggml_metal_init();
     if (ctx == NULL) {
         GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
         return NULL;

@@ -3511,7 +3537,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
     *backend = (struct ggml_backend) {
         /* .guid = */ ggml_backend_metal_guid(),
         /* .interface = */ ggml_backend_metal_i,
-        /* .device = */ NULL,
         /* .context = */ ctx,
     };

@@ -3536,9 +3562,9 @@ void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_ca
 bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
     GGML_ASSERT(ggml_backend_is_metal(backend));

-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;

-    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }

@@ -3548,11 +3574,246 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
     ctx->capture_next_compute = true;
 }

-ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
-
-ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
-    return ggml_backend_metal_init();

     GGML_UNUSED(params);
-    GGML_UNUSED(user_data);
 }
 
696
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, support_simdgroup_mm);
697
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
698
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
699
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
700
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
701
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
702
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
703
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
704
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
705
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
 
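For context: the third argument to GGML_METAL_ADD_KERNEL gates pipeline creation on a per-device capability, so unsupported kernels are skipped at init time instead of failing at dispatch. A minimal sketch of the pattern, with hypothetical helper names (the real macro compiles a Metal compute pipeline state):

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdio.h>

    struct kernel_slot { void * pipeline; };                // stands in for id<MTLComputePipelineState>

    extern void * compile_metal_kernel(const char * name);  // hypothetical compiler hook

    static void add_kernel(struct kernel_slot * slot, const char * name, bool supported) {
        if (!supported) {
            // leave the slot empty; ggml_metal_supports_op() later reports the op as unsupported
            fprintf(stderr, "skipping %s (not supported on this device)\n", name);
            slot->pipeline = NULL;
            return;
        }
        slot->pipeline = compile_metal_kernel(name);
    }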
733
  Block_release(ctx->encode_async);
734
 
735
  [ctx->queue release];
 
736
 
737
  dispatch_release(ctx->d_queue);
738

790
  return nil;
791
  }
792
 
793
+ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
794
  for (size_t i = 0, n = 3; i < n; ++i) {
795
  if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
796
  return false;
797
  }
798
  }
799
 
800
+ const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
801
+ const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
802
+
803
  switch (op->op) {
804
  case GGML_OP_UNARY:
805
  switch (ggml_get_unary_op(op)) {
 
837
  case GGML_OP_SOFT_MAX:
838
  case GGML_OP_RMS_NORM:
839
  case GGML_OP_GROUP_NORM:
840
+ return support_simdgroup_reduction;
841
  case GGML_OP_NORM:
842
  case GGML_OP_ROPE:
843
  return true;
 
863
  if (op->src[0]->ne[0] == 256) {
864
  return false;
865
  }
866
+ return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
867
  case GGML_OP_SSM_CONV:
868
  case GGML_OP_SSM_SCAN:
869
  return true;
870
  case GGML_OP_MUL_MAT:
871
  case GGML_OP_MUL_MAT_ID:
872
+ return support_simdgroup_reduction &&
873
  (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
874
  case GGML_OP_CPY:
875
  case GGML_OP_DUP:
 
913
  }
914
 
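The capability flags read above now live in the per-device context rather than in the backend context. A reduced sketch of the assumed struct, limited to the fields this diff references (member sizes illustrative):

    #include <stdbool.h>

    struct ggml_backend_metal_device_context {
        void * mtl_device;                  // id<MTLDevice> in the real code
        int    mtl_device_ref_count;        // managed by the acq/rel helpers used below
        bool   support_simdgroup_reduction; // gates SOFT_MAX/RMS_NORM/... and the mat-vec kernels
        bool   support_simdgroup_mm;        // gates the mul_mm and flash-attn kernels
        char   name[128];
    };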
915
  static void ggml_metal_encode_node(
916
+ ggml_backend_t backend,
917
  int idx,
918
  id<MTLComputeCommandEncoder> encoder) {
919
+ struct ggml_backend_metal_context * ctx = backend->context;
920
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
921
+
922
  struct ggml_cgraph * gf = ctx->gf;
923
 
924
  struct ggml_tensor * node = ggml_graph_node(gf, idx);
 
948
  } break;
949
  }
950
 
951
+ if (!ggml_metal_supports_op(ctx_dev, dst)) {
952
  GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
953
  GGML_ABORT("unsupported op");
954
  }
 
1021
  // dst->name);
1022
  //}
1023
 
1024
+ id<MTLDevice> device = ctx_dev->mtl_device;
1025
+
1026
  switch (dst->op) {
1027
  case GGML_OP_CONCAT:
1028
  {
 
1731
  // the numbers below are measured on M2 Ultra for 7B and 13B models
1732
  // these numbers do not translate to other devices or model sizes
1733
  // TODO: need to find a better approach
1734
+ if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
1735
  switch (src0t) {
1736
  case GGML_TYPE_F16: ne11_mm_min = 2; break;
1737
  case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
 
1751
 
1752
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1753
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1754
+ if ([device supportsFamily:MTLGPUFamilyApple7] &&
1755
  !ggml_is_transposed(src0) &&
1756
  !ggml_is_transposed(src1) &&
1757
  src1t == GGML_TYPE_F32 &&
 
2046
  // ne21 = n_rows
2047
  const int dst_rows = ne20*ne21;
2048
  const int dst_rows_min = n_as;
2049
+ const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
2050
 
2051
  // max size of the rowids array in the kernel shared buffer
2052
  GGML_ASSERT(dst_rows <= dst_rows_max);
 
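To make the bound concrete, a worked instance of the formula above (32 KiB is an illustrative maxThreadgroupMemoryLength, not a claim about any specific GPU):

    // dst_rows_max = (maxThreadgroupMemoryLength - 32 - 8192)/4
    //              = (32768 - 32 - 8192)/4
    //              = 24544/4
    //              = 6136
    // i.e. at most 6136 entries fit in the rowids array in shared memory,
    // which is exactly what the GGML_ASSERT above enforces.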
2057
  // TODO: for now, always use mat-vec kernels until we figure out how to improve the
2058
  // indirect matrix multiplication
2059
  // !!!
2060
+ if ([device supportsFamily:MTLGPUFamilyApple7] &&
2061
  ne00 % 32 == 0 && ne00 >= 64 &&
2062
  dst_rows > dst_rows_min) {
2063
 
 
2896
 
2897
  while (true) {
2898
  const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
2899
+ if (smem > device.maxThreadgroupMemoryLength) {
2900
  break;
2901
  }
2902
  nsgmax *= 2;
 
2908
 
2909
  const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
2910
 
2911
+ //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
2912
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
2913
 
2914
  [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
2915

2934
 
2935
  const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
2936
 
2937
+ //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
2938
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
2939
  [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
2940
 
2941
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 
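A worked instance of the shared-memory formula used by both flash-attention variants above, with illustrative parameters (half-precision accumulators, hence sizeof(float)/2 = 2 bytes per element):

    // smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2)
    // taking nqptg = 8, ncpsg = 32, ne00 = 128, nsg = 4 for illustration:
    //      = 8*(128 + 2*4*(32 + 8))*2
    //      = 8*(128 + 320)*2
    //      = 7168 bytes
    // comfortably below a 32 KiB maxThreadgroupMemoryLength, so the
    // GGML_ASSERT above holds and the nsgmax doubling loop terminates first.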
3010
  }
3011
 
3012
  static enum ggml_status ggml_metal_graph_compute(
3013
+ ggml_backend_t backend,
3014
+ struct ggml_cgraph * gf) {
3015
+ struct ggml_backend_metal_context * ctx = backend->context;
3016
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
3017
+
3018
  // number of nodes encoded by the main thread (empirically determined)
3019
  const int n_main = 128;
3020

3042
 
3043
  if (!ctx->capture_started) {
3044
  // create capture scope
3045
+ ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
3046
 
3047
  MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
3048
  descriptor.captureObject = ctx->capture_scope;
 
3146
 
3147
  // backend interface
3148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3149
  static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
3150
  return "Metal";
3151

3158
  for (int i = 0; i < ctx->n_buffers; i++) {
3159
  [ctx->buffers[i].metal release];
3160
  }
3161
+ ggml_backend_metal_device_rel(buffer->buft->device->context);
3162
 
3163
  if (ctx->owned) {
3164
  #if TARGET_OS_OSX
 
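The acq/rel calls seen here reference-count the shared MTLDevice held in the device context. A hedged C sketch of the assumed semantics, using the device-context fields sketched earlier (create_device/release_device are hypothetical stand-ins for the Metal calls):

    #include <assert.h>
    #include <stddef.h>

    extern void * create_device(void);         // hypothetical: MTLCreateSystemDefaultDevice()
    extern void   release_device(void * dev);  // hypothetical: [device release]

    static void * metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
        if (ctx->mtl_device_ref_count == 0) {
            ctx->mtl_device = create_device(); // created lazily on first use
        }
        ctx->mtl_device_ref_count++;
        return ctx->mtl_device;
    }

    static void metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
        assert(ctx->mtl_device_ref_count > 0);
        if (--ctx->mtl_device_ref_count == 0) {
            release_device(ctx->mtl_device);   // freed when the last user is done
            ctx->mtl_device = NULL;
        }
    }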
3261
  size_aligned += (size_page - (size_aligned % size_page));
3262
  }
3263
 
3264
+ id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
3265
 
3266
  ctx->all_data = ggml_metal_host_malloc(size_aligned);
3267
  ctx->all_size = size_aligned;
 
3275
 
3276
  if (size_aligned > 0) {
3277
  ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
3278
+ length:size_aligned
3279
+ options:MTLResourceStorageModeShared
3280
+ deallocator:nil];
3281
  }
3282
  }
3283
 
3284
  if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
3285
  GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
3286
  free(ctx);
3287
+ ggml_backend_metal_device_rel(buft->device->context);
3288
  return NULL;
3289
  }
3290
 
 
3299
  }
3300
 
3301
  static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
3302
+ id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
3303
+ const size_t max_size = device.maxBufferLength;
3304
+ ggml_backend_metal_device_rel(buft->device->context);
3305
 
3306
  return max_size;
3307

3324
  /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
3325
  /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
3326
  },
3327
+ /* .device = */ &g_ggml_backend_metal_device,
3328
  /* .context = */ NULL,
3329
  };
3330
 
3331
  return &ggml_backend_buffer_type_metal;
3332
  }
3333
 
3334
+ // TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
 
3335
  ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
3336
  struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
3337

3354
  size_aligned += (size_page - (size_aligned % size_page));
3355
  }
3356
 
3357
+ id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
3358
 
3359
  // the buffer fits into the max buffer size allowed by the device
3360
  if (size_aligned <= device.maxBufferLength) {
 
3419
  }
3420
 
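The TODO above points at the replacement path. A minimal migration sketch, assuming ggml_backend_dev_buffer_from_host_ptr from ggml-backend.h (same three trailing arguments as the deprecated call):

    #include "ggml-backend.h"
    #include "ggml-metal.h"

    static ggml_backend_buffer_t wrap_host_ptr(void * data, size_t size, size_t max_tensor_size) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
        // previously: ggml_backend_metal_buffer_from_ptr(data, size, max_tensor_size)
        return ggml_backend_dev_buffer_from_host_ptr(dev, data, size, max_tensor_size);
    }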
3421
  static void ggml_backend_metal_free(ggml_backend_t backend) {
3422
+ struct ggml_backend_metal_context * ctx = backend->context;
3423
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
3424
+
3425
+ ggml_backend_metal_device_rel(ctx_dev);
3426
  ggml_metal_free(ctx);
3427
+
3428
  free(backend);
3429
  }
3430
 
 
3435
  }
3436
 
3437
  static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
3438
+ return ggml_metal_graph_compute(backend, cgraph);

3439
  }
3440
 
3441
  static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
3482
  [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
3483
  }
3484
 
3485
+ ggml_metal_encode_node(backend, idx, encoder);
3486
 
3487
  if (should_capture) {
3488
  [encoder popDebugGroup];
 
3510
  /* .graph_plan_update = */ NULL,
3511
  /* .graph_plan_compute = */ NULL,
3512
  /* .graph_compute = */ ggml_backend_metal_graph_compute,
3513
+ /* .supports_op = */ NULL,
3514
+ /* .supports_buft = */ NULL,
3515
  /* .offload_op = */ NULL,
3516
  /* .event_record = */ NULL,
3517
  /* .event_wait = */ NULL,
 
3522
  return &guid;
3523
  }
3524
 
3525
+ // TODO: remove in the future
3526
  ggml_backend_t ggml_backend_metal_init(void) {
3527
+ ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
3528
+
3529
+ struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
3530
  if (ctx == NULL) {
3531
  GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
3532
  return NULL;
 
3537
  *backend = (struct ggml_backend) {
3538
  /* .guid = */ ggml_backend_metal_guid(),
3539
  /* .interface = */ ggml_backend_metal_i,
3540
+ /* .device = */ dev,
3541
  /* .context = */ ctx,
3542
  };
3543
 
 
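As the TODO above notes, this entry point is slated for removal; the equivalent through the new registry is a two-step lookup. A sketch (for Metal, no init params are assumed to be needed):

    #include "ggml-backend.h"
    #include "ggml-metal.h"

    static ggml_backend_t metal_backend_via_registry(void) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
        return ggml_backend_dev_init(dev, /*params =*/ NULL);
    }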
3562
  bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
3563
  GGML_ASSERT(ggml_backend_is_metal(backend));
3564
 
3565
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
3566
 
3567
+ return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
3568
  }
3569
 
3570
  void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
 
3574
  ctx->capture_next_compute = true;
3575
  }
3576
 
3577
+ // backend device
3578
+
3579
+ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
3580
+ return "Metal";
3581
 
3582
+ GGML_UNUSED(dev);
3583
+ }
3584
+
3585
+ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
3586
+ // acq/rel just to populate ctx->name in case it hasn't been done yet
3587
+ struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
3588
+ ggml_backend_metal_device_acq(ctx_dev);
3589
+ ggml_backend_metal_device_rel(ctx_dev);
3590
+
3591
+ return ctx_dev->name;
3592
+ }
3593
+
3594
+ static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
3595
+ if (@available(macOS 10.12, iOS 16.0, *)) {
3596
+ struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
3597
+ id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
3598
+
3599
+ *total = device.recommendedMaxWorkingSetSize;
3600
+ *free = *total - device.currentAllocatedSize;
3601
+
3602
+ ggml_backend_metal_device_rel(ctx_dev);
3603
+ } else {
3604
+ *free = 1;
3605
+ *total = 1;
3606
+ }
3607
+ }
3608
+
3609
+ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
3610
+ return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
3611
+
3612
+ GGML_UNUSED(dev);
3613
+ }
3614
+
3615
+ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
3616
+ props->name = ggml_backend_metal_device_get_name(dev);
3617
+ props->description = ggml_backend_metal_device_get_description(dev);
3618
+ props->type = ggml_backend_metal_device_get_type(dev);
3619
+ ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
3620
+ props->caps = (struct ggml_backend_dev_caps) {
3621
+ /* .async = */ false,
3622
+ /* .host_buffer = */ false,
3623
+ /* .buffer_from_host_ptr = */ true,
3624
+ /* .events = */ false,
3625
+ };
3626
+ }
3627
+
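A short consumer-side sketch of the props populated above, read back through the generic accessor (field names match the struct filled in by this function):

    #include <stdio.h>
    #include "ggml-backend.h"

    static void print_metal_props(ggml_backend_dev_t dev) {
        struct ggml_backend_dev_props props;
        ggml_backend_dev_get_props(dev, &props);
        // for Metal, caps.buffer_from_host_ptr is the only capability set to true
        printf("%s (%s): free=%zu total=%zu buffer_from_host_ptr=%d\n",
               props.name, props.description,
               props.memory_free, props.memory_total,
               (int) props.caps.buffer_from_host_ptr);
    }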
3628
+ static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
3629
+ struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
3630
+ if (ctx == NULL) {
3631
+ GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
3632
+ return NULL;
3633
+ }
3634
+
3635
+ ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
3636
+
3637
+ *backend = (struct ggml_backend) {
3638
+ /* .guid = */ ggml_backend_metal_guid(),
3639
+ /* .interface = */ ggml_backend_metal_i,
3640
+ /* .device = */ dev,
3641
+ /* .context = */ ctx,
3642
+ };
3643
+
3644
+ ggml_backend_metal_set_n_cb(backend, 1);
3645
+
3646
+ return backend;
3647
 
3648
  GGML_UNUSED(params);
3649
+ }
3650
+
3651
+ static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
3652
+ return ggml_backend_metal_buffer_type();
3653
+
3654
+ GGML_UNUSED(dev);
3655
+ }
3656
+
3657
+ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
3658
+ struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
3659
+
3660
+ ctx->all_data = ptr;
3661
+ ctx->all_size = size;
3662
+ ctx->owned = false;
3663
+ ctx->n_buffers = 0;
3664
+
3665
+ const size_t size_page = sysconf(_SC_PAGESIZE);
3666
+
3667
+ // page-align the data ptr
3668
+ {
3669
+ const uintptr_t offs = (uintptr_t) ptr % size_page;
3670
+ ptr = (void *) ((char *) ptr - offs);
3671
+ size += offs;
3672
+ }
3673
+
3674
+ size_t size_aligned = size;
3675
+ if ((size_aligned % size_page) != 0) {
3676
+ size_aligned += (size_page - (size_aligned % size_page));
3677
+ }
3678
+
3679
+ struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
3680
+ id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
3681
+
3682
+ // the buffer fits into the max buffer size allowed by the device
3683
+ if (size_aligned <= device.maxBufferLength) {
3684
+ ctx->buffers[ctx->n_buffers].data = ptr;
3685
+ ctx->buffers[ctx->n_buffers].size = size;
3686
+ ctx->buffers[ctx->n_buffers].metal = nil;
3687
+
3688
+ if (size_aligned > 0) {
3689
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
3690
+
3691
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
3692
+ GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
3693
+ return NULL;
3694
+ }
3695
+ }
3696
+
3697
+ ggml_backend_metal_log_allocated_size(device, size_aligned);
3698
+
3699
+ ++ctx->n_buffers;
3700
+ } else {
3701
+ // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
3702
+ // one of the views
3703
+ const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
3704
+ const size_t size_step = device.maxBufferLength - size_ovlp;
3705
+ const size_t size_view = device.maxBufferLength;
3706
+
3707
+ for (size_t i = 0; i < size; i += size_step) {
3708
+ const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
3709
+
3710
+ ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) ptr + i);
3711
+ ctx->buffers[ctx->n_buffers].size = size_step_aligned;
3712
+ ctx->buffers[ctx->n_buffers].metal = nil;
3713
+
3714
+ if (size_step_aligned > 0) {
3715
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
3716
+
3717
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
3718
+ GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
3719
+ return NULL;
3720
+ }
3721
+ }
3722
+
3723
+ ggml_backend_metal_log_allocated_size(device, size_step_aligned);
3724
+
3725
+ if (i + size_step < size) {
3726
+ GGML_LOG_INFO("\n");
3727
+ }
3728
+
3729
+ ++ctx->n_buffers;
3730
+ }
3731
+ }
3732
+
3733
+ return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
3734
+ }
3735
+
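For reference, the page-alignment block above with concrete numbers (16 KiB page chosen for illustration; the real code queries sysconf(_SC_PAGESIZE)):

    #include <stdint.h>
    #include <stddef.h>

    static size_t align_example(void) {
        const size_t size_page = 16384;                    // illustrative 16 KiB page
        uintptr_t    ptr  = (uintptr_t) 0x100001234ull;    // illustrative address, 4660 bytes into a page
        size_t       size = 100000;

        const uintptr_t offs = ptr % size_page;            // 4660
        ptr  -= offs;                                      // round down to the page boundary
        size += offs;                                      // 104660

        size_t size_aligned = size;
        if ((size_aligned % size_page) != 0) {
            size_aligned += size_page - (size_aligned % size_page);
        }
        return size_aligned;                               // 114688 = 7 pages
    }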
3736
+ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
3737
+ struct ggml_backend_metal_device_context * ctx_dev = dev->context;
3738
+
3739
+ return ggml_metal_supports_op(ctx_dev, op);
3740
+ }
3741
+
3742
+ static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
3743
+ return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
3744
+
3745
+ GGML_UNUSED(dev);
3746
+ }
3747
+
3748
+ static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
3749
+ return false;
3750
+
3751
+ GGML_UNUSED(dev);
3752
+ GGML_UNUSED(op);
3753
+ }
3754
+
3755
+ static struct ggml_backend_device_i ggml_backend_metal_device_i = {
3756
+ /* .get_name = */ ggml_backend_metal_device_get_name,
3757
+ /* .get_description = */ ggml_backend_metal_device_get_description,
3758
+ /* .get_memory = */ ggml_backend_metal_device_get_memory,
3759
+ /* .get_type = */ ggml_backend_metal_device_get_type,
3760
+ /* .get_props = */ ggml_backend_metal_device_get_props,
3761
+ /* .init_backend = */ ggml_backend_metal_device_init,
3762
+ /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type,
3763
+ /* .get_host_buffer_type = */ NULL,
3764
+ /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr,
3765
+ /* .supports_op = */ ggml_backend_metal_device_supports_op,
3766
+ /* .supports_buft = */ ggml_backend_metal_device_supports_buft,
3767
+ /* .offload_op = */ ggml_backend_metal_device_offload_op,
3768
+ /* .event_new = */ NULL,
3769
+ /* .event_free = */ NULL,
3770
+ /* .event_synchronize = */ NULL,
3771
+ };
3772
+
3773
+ // backend registry
3774
+
3775
+ static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
3776
+ return "Metal";
3777
+
3778
+ GGML_UNUSED(reg);
3779
+ }
3780
+
3781
+ static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
3782
+ return 1;
3783
+
3784
+ GGML_UNUSED(reg);
3785
+ }
3786
+
3787
+ static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
3788
+ GGML_ASSERT(index == 0);
3789
+
3790
+ return &g_ggml_backend_metal_device;
3791
+
3792
+ GGML_UNUSED(reg);
3793
+ GGML_UNUSED(index);
3794
+ }
3795
+
3796
+ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
3797
+ /* .get_name = */ ggml_backend_metal_reg_get_name,
3798
+ /* .device_count = */ ggml_backend_metal_reg_device_count,
3799
+ /* .device_get = */ ggml_backend_metal_reg_device_get,
3800
+ /* .get_proc_address = */ NULL,
3801
+ };
3802
+
3803
+ ggml_backend_reg_t ggml_backend_metal_reg(void) {
3804
+ // TODO: make this thread-safe somehow?
3805
+ {
3806
+ g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
3807
+ /* .iface = */ ggml_backend_metal_reg_i,
3808
+ /* .context = */ NULL,
3809
+ };
3810
+
3811
+ g_ggml_backend_metal_device = (struct ggml_backend_device) {
3812
+ /* .iface = */ ggml_backend_metal_device_i,
3813
+ /* .reg = */ &g_ggml_backend_metal_reg,
3814
+ /* .context = */ &g_ggml_ctx_dev_main,
3815
+ };
3816
+ }
3817
+
3818
+ return &g_ggml_backend_metal_reg;
3819
  }
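
Putting the pieces together, an end-to-end sketch of the registry added by this commit (a single device today, but written as a loop; error handling omitted):

    #include <stdio.h>
    #include "ggml-backend.h"
    #include "ggml-metal.h"

    int main(void) {
        ggml_backend_reg_t reg = ggml_backend_metal_reg();
        for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) {
            ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
            printf("device %zu: %s - %s\n", i,
                   ggml_backend_dev_name(dev),
                   ggml_backend_dev_description(dev));
        }
        return 0;
    }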