ggml : add metal backend registry / device (llama/9713)
* ggml : add metal backend registry / device
ggml-ci
* metal : fix names [no ci]
* metal : global registry and device instances
ggml-ci
* cont : alternative initialization of global objects
ggml-ci
* llama : adapt to backend changes
ggml-ci
* fixes
* metal : fix indent
* metal : fix build when MTLGPUFamilyApple3 is not available
ggml-ci
* fix merge
* metal : avoid unnecessary singleton accesses
ggml-ci
* metal : minor fix [no ci]
* metal : g_state -> g_ggml_ctx_dev_main [no ci]
* metal : avoid reference of device context in the backend context
ggml-ci
* metal : minor [no ci]
* metal : fix maxTransferRate check
* metal : remove transfer rate stuff
---------
Co-authored-by: slaren <[email protected]>
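
The practical upshot of the registry/device split is that clients can discover and initialize Metal through the generic ggml-backend API instead of hard-coding ggml_backend_metal_init(). Below is a minimal sketch of that flow — it assumes the generic registry accessors (ggml_backend_reg_dev_count / ggml_backend_reg_dev_get) and the ggml_backend_dev_props layout from ggml-backend.h, and is illustrative rather than code from this commit:

    // sketch: enumerate the devices exposed by the Metal backend registry
    #include <stdio.h>
    #include "ggml-backend.h"
    #include "ggml-metal.h"

    int main(void) {
        ggml_backend_reg_t reg = ggml_backend_metal_reg();

        for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) {
            ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);

            struct ggml_backend_dev_props props;
            ggml_backend_dev_get_props(dev, &props); // zeroed first, then filled by the device

            printf("device: %s - %s\n", props.name, props.description);
        }

        return 0;
    }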
- ggml/include/ggml-backend.h +2 -0
- ggml/include/ggml-metal.h +5 -1
- ggml/src/ggml-backend.cpp +17 -4
- ggml/src/ggml-cuda.cu +4 -3
- ggml/src/ggml-metal.m +486 -225
ggml/include/ggml-backend.h
@@ -127,6 +127,8 @@ extern "C" {
         bool async;
         // pinned host buffer
         bool host_buffer;
+        // creating buffers from host ptr
+        bool buffer_from_host_ptr;
         // event synchronization
         bool events;
     };
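
Given the new field, a caller can cheaply test whether a device supports wrapping existing host memory before taking that path. A small sketch (the helper name is illustrative, not part of this commit):

    // sketch: query the new capability through the generic props API
    static bool device_supports_host_ptr(ggml_backend_dev_t dev) {
        struct ggml_backend_dev_props props;
        ggml_backend_dev_get_props(dev, &props);

        return props.caps.buffer_from_host_ptr;
    }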
ggml/include/ggml-metal.h
@@ -43,7 +43,9 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);

 GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);

-GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
+GGML_DEPRECATED(
+        GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
+        "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");

 GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);

@@ -57,6 +59,8 @@ GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
 // capture all command buffers committed the next time `ggml_backend_graph_compute` is called
 GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);

+GGML_API ggml_backend_reg_t ggml_backend_metal_reg(void);
+
 #ifdef __cplusplus
 }
 #endif
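
GGML_DEPRECATED keeps the old entry point compiling while steering callers to the device interface at build time. For reference, ggml.h defines the macro along these lines (paraphrased from memory, not part of this diff):

    #ifdef __GNUC__
    #    define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
    #elif defined(_MSC_VER)
    #    define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
    #else
    #    define GGML_DEPRECATED(func, hint) func
    #endif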
ggml/src/ggml-backend.cpp
@@ -463,6 +463,7 @@ enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
 }

 void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) {
+    memset(props, 0, sizeof(*props));
     device->iface.get_props(device, props);
 }

@@ -479,6 +480,10 @@ ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
 }

 ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {
+    if (device->iface.get_host_buffer_type == NULL) {
+        return NULL;
+    }
+
     return device->iface.get_host_buffer_type(device);
 }

@@ -525,6 +530,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
 #include "ggml-cuda.h"
 #endif

+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
 struct ggml_backend_registry {
     std::vector<ggml_backend_reg_t> backends;
     std::vector<ggml_backend_dev_t> devices;

@@ -533,10 +542,13 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_CUDA
         register_backend(ggml_backend_cuda_reg());
 #endif
+#ifdef GGML_USE_METAL
+        register_backend(ggml_backend_metal_reg());
+#endif

         register_backend(ggml_backend_cpu_reg());

-        // TODO: sycl, metal, vulkan, kompute, cann
+        // TODO: sycl, vulkan, kompute, cann
     }

     void register_backend(ggml_backend_reg_t reg) {

@@ -1118,9 +1130,10 @@ static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
     props->type = ggml_backend_cpu_device_get_type(dev);
     ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
     props->caps = {
-        /* async       */ false,
-        /* host_buffer */ false,
-        /* events      */ false,
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
     };
 }
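
With ggml_backend_dev_get_props zeroing the struct and ggml_backend_dev_host_buffer_type tolerating a missing iface.get_host_buffer_type, device implementations can leave optional entries NULL and callers get a uniform fallback. A sketch of the resulting caller-side pattern (pick_buft is a hypothetical helper, not from this commit):

    // sketch: prefer the pinned host buffer type when the device provides one
    static ggml_backend_buffer_type_t pick_buft(ggml_backend_dev_t dev) {
        ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev);
        if (buft == NULL) {
            // device left iface.get_host_buffer_type unset - use the default type
            buft = ggml_backend_dev_buffer_type(dev);
        }
        return buft;
    }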
ggml/src/ggml-cuda.cu
@@ -2920,9 +2920,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
 #endif

     props->caps = {
-        /* async       */ true,
-        /* host_buffer */ host_buffer,
-        /* events      */ events,
+        /* .async                 = */ true,
+        /* .host_buffer           = */ host_buffer,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ events,
     };
 }
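
The /* .field = */ comments mimic designated initializers while staying valid C++, so the values must track the declaration order of the caps struct exactly — which is why every backend inserts the new entry in the same position. A sketch of the convention (assuming the caps struct shown in the ggml-backend.h hunk above is named ggml_backend_dev_caps):

    // sketch: positional init must mirror the struct layout in ggml-backend.h
    #include "ggml-backend.h"

    static const struct ggml_backend_dev_caps cpu_like_caps = {
        /* .async                 = */ false,
        /* .host_buffer           = */ false,
        /* .buffer_from_host_ptr  = */ true,  // new field - omitting it would shift .events
        /* .events                = */ false,
    };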
ggml/src/ggml-metal.m
@@ -20,6 +20,69 @@

 #define UNUSED(x) (void)(x)

+// globals
+
+// overload of MTLGPUFamilyMetal3 (not available in some environments)
+static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
+
+// initialized in ggml_backend_metal_reg
+static struct ggml_backend_reg    g_ggml_backend_metal_reg;
+static struct ggml_backend_device g_ggml_backend_metal_device;
+
+// information about a Metal device
+// note: assumes single GPU device - the default one
+// TODO: support multiple GPU devices
+static struct ggml_backend_metal_device_context {
+    id<MTLDevice> mtl_device;
+    int           mtl_device_ref_count;
+
+    bool support_simdgroup_reduction;
+    bool support_simdgroup_mm;
+
+    char name[128];
+} g_ggml_ctx_dev_main = {
+    /*.mtl_device                  =*/ nil,
+    /*.mtl_device_ref_count        =*/ 0,
+    /*.support_simdgroup_reduction =*/ false,
+    /*.support_simdgroup_mm        =*/ false,
+    /*.name                        =*/ "",
+};
+
+// acquire
+static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
+    assert(ctx != NULL);
+
+    if (ctx->mtl_device == nil) {
+        ctx->mtl_device = MTLCreateSystemDefaultDevice();
+
+        ctx->support_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+        ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+
+        ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+
+        strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
+    }
+
+    ctx->mtl_device_ref_count++;
+
+    return ctx->mtl_device;
+}
+
+// release
+static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
+    assert(ctx != NULL);
+    assert(ctx->mtl_device_ref_count > 0);
+
+    ctx->mtl_device_ref_count--;
+
+    if (ctx->mtl_device_ref_count == 0) {
+        [ctx->mtl_device release];
+        ctx->mtl_device = nil;
+    }
+}
+
+// kernels
+
 struct ggml_metal_kernel {
     id<MTLComputePipelineState> pipeline;
 };

@@ -214,16 +277,12 @@ enum ggml_metal_kernel_type {
 };

 struct ggml_backend_metal_context {
-    id<MTLDevice>       device;
     id<MTLCommandQueue> queue;

     dispatch_queue_t d_queue;

     struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];

-    bool support_simdgroup_reduction;
-    bool support_simdgroup_mm;
-
     // capture state
     bool capture_next_compute;
     bool capture_started;

@@ -280,7 +339,7 @@ static void * ggml_metal_host_malloc(size_t n) {
     return data;
 }

-static struct ggml_backend_metal_context * ggml_metal_init(void) {
+static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
     GGML_LOG_INFO("%s: allocating\n", __func__);

 #if TARGET_OS_OSX && !GGML_METAL_NDEBUG

@@ -292,14 +351,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
     [devices release]; // since it was created by a *Copy* C method
 #endif

-    // Pick and show default Metal device
-    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+    // init context
+    struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
+    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
+
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

-    // Configure context
-    struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
-    ctx->device = device;
-    ctx->queue = [ctx->device newCommandQueue];
+    ctx->queue = [device newCommandQueue];
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

     id<MTLLibrary> metal_library;

@@ -332,7 +391,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         NSURL * libURL = [NSURL fileURLWithPath:path_lib];
         GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);

-        metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
+        metal_library = [device newLibraryWithURL:libURL error:&error];
         if (error) {
             GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;

@@ -382,7 +441,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {

         //[options setFastMathEnabled:false];

-        metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
+        metal_library = [device newLibraryWithSource:src options:options error:&error];
         if (error) {
             GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;

@@ -392,44 +451,37 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
     }

     // print MTL GPU family:
-    GGML_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
-
-    const NSInteger MTLGPUFamilyMetal3 = 5001;
+    GGML_LOG_INFO("%s: GPU name:   %s\n", __func__, [[device name] UTF8String]);

     // determine max supported GPU family
     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
     // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
     {
         for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
-            if ([ctx->device supportsFamily:i]) {
+            if ([device supportsFamily:i]) {
                 GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d  (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
                 break;
             }
         }

         for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
-            if ([ctx->device supportsFamily:i]) {
+            if ([device supportsFamily:i]) {
                 GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
                 break;
             }
         }

-        for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
-            if ([ctx->device supportsFamily:i]) {
-                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d  (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+        for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {
+            if ([device supportsFamily:i]) {
+                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d  (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);
                 break;
             }
         }
     }

-    ctx->support_simdgroup_reduction  = [ctx->device supportsFamily:MTLGPUFamilyApple7];
-    ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
-
-    ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
-
-    GGML_LOG_INFO("%s: simdgroup reduction support   = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
-    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
-    GGML_LOG_INFO("%s: hasUnifiedMemory              = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup reduction support   = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
+    GGML_LOG_INFO("%s: hasUnifiedMemory              = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");

     ctx->capture_next_compute = false;
     ctx->capture_started = false;

@@ -443,13 +495,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {

 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
-        GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
-    }
-#elif TARGET_OS_OSX
-    if (ctx->device.maxTransferRate != 0) {
-        GGML_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
-    } else {
-        GGML_LOG_INFO("%s: maxTransferRate               = built-in GPU\n", __func__);
+        GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
     }
 #endif

@@ -470,7 +516,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         if (supported) { \
             struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
             id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
-            kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
+            kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
             [metal_function release]; \
             if (error) { \
                 GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \

@@ -481,6 +527,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
             GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
         }

+        const bool support_simdgroup_mm        = ctx_dev->support_simdgroup_mm;
+        const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+
         // simd_sum and simd_max requires MTLGPUFamilyApple7

         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);

@@ -507,10 +556,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);

@@ -535,101 +584,101 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);

@@ -643,14 +692,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);

@@ -684,7 +733,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     Block_release(ctx->encode_async);

     [ctx->queue release];
-    [ctx->device release];

     dispatch_release(ctx->d_queue);

@@ -742,13 +790,16 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
     return nil;
 }

-static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
+static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
     for (size_t i = 0, n = 3; i < n; ++i) {
         if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
             return false;
         }
     }

+    const bool support_simdgroup_mm        = ctx_dev->support_simdgroup_mm;
+    const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {

@@ -786,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
-            return ctx->support_simdgroup_reduction;
+            return support_simdgroup_reduction;
         case GGML_OP_NORM:
         case GGML_OP_ROPE:
             return true;

@@ -812,13 +863,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
-            return ctx->support_simdgroup_mm;
+            return support_simdgroup_mm;
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
             return true;
         case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
-            return ctx->support_simdgroup_reduction &&
+            return support_simdgroup_reduction &&
                 (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
         case GGML_OP_CPY:
         case GGML_OP_DUP:

@@ -862,9 +913,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
 }

 static void ggml_metal_encode_node(
-        struct ggml_backend_metal_context * ctx,
+                        ggml_backend_t   backend,
                                    int   idx,
          id<MTLComputeCommandEncoder>   encoder) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
     struct ggml_cgraph * gf = ctx->gf;

     struct ggml_tensor * node = ggml_graph_node(gf, idx);

@@ -894,7 +948,7 @@ static void ggml_metal_encode_node(
             } break;
     }

-    if (!ggml_metal_supports_op(ctx, dst)) {
+    if (!ggml_metal_supports_op(ctx_dev, dst)) {
         GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
         GGML_ABORT("unsupported op");
     }

@@ -967,6 +1021,8 @@ static void ggml_metal_encode_node(
     //         dst->name);
     //}

+    id<MTLDevice> device = ctx_dev->mtl_device;
+
     switch (dst->op) {
         case GGML_OP_CONCAT:
             {

@@ -1675,7 +1731,7 @@ static void ggml_metal_encode_node(
                 // the numbers below are measured on M2 Ultra for 7B and 13B models
                 // these numbers do not translate to other devices or model sizes
                 // TODO: need to find a better approach
-                if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+                if ([device supportsFamily:MTLGPUFamilyApple7]) {
                     switch (src0t) {
                         case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
                         case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;

@@ -1695,7 +1751,7 @@ static void ggml_metal_encode_node(

             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+            if ([device supportsFamily:MTLGPUFamilyApple7] &&
                 !ggml_is_transposed(src0) &&
                 !ggml_is_transposed(src1) &&
                 src1t == GGML_TYPE_F32 &&

@@ -1990,7 +2046,7 @@ static void ggml_metal_encode_node(
                 // ne21 = n_rows
                 const int dst_rows = ne20*ne21;
                 const int dst_rows_min = n_as;
-                const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
+                const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;

                 // max size of the rowids array in the kernel shared buffer
                 GGML_ASSERT(dst_rows <= dst_rows_max);

@@ -2001,7 +2057,7 @@ static void ggml_metal_encode_node(
             // TODO: for now, always use mat-vec kernels until we figure out how to improve the
             //       indirect matrix multiplication
             // !!!
-            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+            if ([device supportsFamily:MTLGPUFamilyApple7] &&
                 ne00 % 32 == 0 && ne00 >= 64 &&
                 dst_rows > dst_rows_min) {

@@ -2840,7 +2896,7 @@ static void ggml_metal_encode_node(

                 while (true) {
                     const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
-                    if (smem > ctx->device.maxThreadgroupMemoryLength) {
+                    if (smem > device.maxThreadgroupMemoryLength) {
                         break;
                     }
                     nsgmax *= 2;

@@ -2852,8 +2908,8 @@ static void ggml_metal_encode_node(

                 const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);

-                //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-                GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+                //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

                 [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

@@ -2878,8 +2934,8 @@ static void ggml_metal_encode_node(

                 const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);

-                //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-                GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+                //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
                 [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];

@@ -2954,8 +3010,11 @@ static void ggml_metal_encode_node(
 }

 static enum ggml_status ggml_metal_graph_compute(
-        struct ggml_backend_metal_context * ctx,
-                       struct ggml_cgraph * gf) {
+            ggml_backend_t   backend,
+        struct ggml_cgraph * gf) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
     // number of nodes encoded by the main thread (empirically determined)
     const int n_main = 128;

@@ -2983,7 +3042,7 @@ static enum ggml_status ggml_metal_graph_compute(

         if (!ctx->capture_started) {
             // create capture scope
-            ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
+            ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];

             MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
             descriptor.captureObject = ctx->capture_scope;

@@ -3087,31 +3146,6 @@ static enum ggml_status ggml_metal_graph_compute(

 // backend interface

-// default buffer
-static id<MTLDevice> g_backend_device = nil;
-static int g_backend_device_ref_count = 0;
-
-static id<MTLDevice> ggml_backend_metal_get_device(void) {
-    if (g_backend_device == nil) {
-        g_backend_device = MTLCreateSystemDefaultDevice();
-    }
-
-    g_backend_device_ref_count++;
-
-    return g_backend_device;
-}
-
-static void ggml_backend_metal_free_device(void) {
-    assert(g_backend_device_ref_count > 0);
-
-    g_backend_device_ref_count--;
-
-    if (g_backend_device_ref_count == 0) {
-        [g_backend_device release];
-        g_backend_device = nil;
-    }
-}
-
 static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
     return "Metal";

@@ -3124,7 +3158,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     for (int i = 0; i < ctx->n_buffers; i++) {
         [ctx->buffers[i].metal release];
     }
-    ggml_backend_metal_free_device();
+    ggml_backend_metal_device_rel(buffer->buft->device->context);

     if (ctx->owned) {
 #if TARGET_OS_OSX

@@ -3227,7 +3261,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
         size_aligned += (size_page - (size_aligned % size_page));
     }

-    id<MTLDevice> device = ggml_backend_metal_get_device();
+    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);

     ctx->all_data = ggml_metal_host_malloc(size_aligned);
     ctx->all_size = size_aligned;

@@ -3241,16 +3275,16 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {

     if (size_aligned > 0) {
         ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
-                        length:size_aligned
-                        options:MTLResourceStorageModeShared
-                        deallocator:nil];
+                                                          length:size_aligned
+                                                         options:MTLResourceStorageModeShared
+                                                     deallocator:nil];
     }
 }

 if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
     GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
     free(ctx);
-    ggml_backend_metal_free_device();
+    ggml_backend_metal_device_rel(buft->device->context);
     return NULL;
 }

@@ -3265,9 +3299,9 @@ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
 }

 static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    id<MTLDevice> device = ggml_backend_metal_get_device();
-    size_t max_size = device.maxBufferLength;
-    ggml_backend_metal_free_device();
+    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
+    const size_t max_size = device.maxBufferLength;
+    ggml_backend_metal_device_rel(buft->device->context);

     return max_size;

@@ -3290,15 +3324,14 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
         /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
         /* .is_host        = */ ggml_backend_metal_buffer_type_is_host,
         },
-        /* .device  = */ NULL,
+        /* .device  = */ &g_ggml_backend_metal_device,
         /* .context = */ NULL,
     };

     return &ggml_backend_buffer_type_metal;
 }

-// …
-
 ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
     struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));

@@ -3321,7 +3354,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
         size_aligned += (size_page - (size_aligned % size_page));
     }

-    id<MTLDevice> device = ggml_backend_metal_get_device();
+    id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);

     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {

@@ -3386,8 +3419,12 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
 }

 static void ggml_backend_metal_free(ggml_backend_t backend) {
-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    ggml_backend_metal_device_rel(ctx_dev);
     ggml_metal_free(ctx);
+
     free(backend);
 }

@@ -3398,21 +3435,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
 }

 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    return ggml_metal_graph_compute(metal_ctx, cgraph);
-}
-
-static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    return ggml_metal_supports_op(metal_ctx, op);
-}
-
-static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
-
-    UNUSED(backend);
+    return ggml_metal_graph_compute(backend, cgraph);
 }

 static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {

@@ -3459,7 +3482,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
             [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
         }

-        ggml_metal_encode_node(ctx, idx, encoder);
+        ggml_metal_encode_node(backend, idx, encoder);

         if (should_capture) {
             [encoder popDebugGroup];

@@ -3487,8 +3510,8 @@ static struct ggml_backend_i ggml_backend_metal_i = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_metal_graph_compute,
-    /* .supports_op             = */ ggml_backend_metal_supports_op,
-    /* .supports_buft           = */ ggml_backend_metal_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
     /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,

@@ -3499,8 +3522,11 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
     return &guid;
 }

 ggml_backend_t ggml_backend_metal_init(void) {
-    struct ggml_backend_metal_context * ctx = ggml_metal_init();
+    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
+
+    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
+
     if (ctx == NULL) {
         GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
         return NULL;

@@ -3511,7 +3537,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
     *backend = (struct ggml_backend) {
         /* .guid      = */ ggml_backend_metal_guid(),
         /* .interface = */ ggml_backend_metal_i,
-        /* .device    = */ NULL,
+        /* .device    = */ dev,
         /* .context   = */ ctx,
     };

@@ -3536,9 +3562,9 @@ void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
 bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
     GGML_ASSERT(ggml_backend_is_metal(backend));

-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;

-    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+    return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }

 void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {

@@ -3548,11 +3574,246 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
     ctx->capture_next_compute = true;
 }

-…
+// … (+246 lines: the new ggml_backend_device / ggml_backend_reg interface
+// implementation backing ggml_backend_metal_reg; truncated in this capture)

     GGML_UNUSED(params);
 }
|
| 681 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, support_simdgroup_mm);
|
| 682 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
| 683 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
| 684 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
|
|
|
| 692 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
| 693 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
| 694 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
| 695 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, support_simdgroup_mm);
|
| 696 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, support_simdgroup_mm);
|
| 697 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
|
| 698 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
|
| 699 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
|
| 700 |
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
|
| 701 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
|
| 702 |
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
|
| 703 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 704 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
| 705 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
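All of these registrations funnel through one macro whose third argument gates pipeline creation on a device capability; unsupported kernels are simply skipped and the corresponding ops are later rejected by supports_op. A rough sketch of the pattern (simplified; metal_load_pipeline is a hypothetical helper, and the real GGML_METAL_ADD_KERNEL also fetches the MTLFunction from the compiled library and logs failures):

// Sketch of the gated-registration pattern used above (not the verbatim macro).
#define GGML_METAL_ADD_KERNEL(e, name, supported) \
    if (supported) { \
        ctx->kernels[e].pipeline = metal_load_pipeline(device, library, "kernel_" #name); \
    } else { \
        ctx->kernels[e].pipeline = nil; /* op will be rejected by ggml_metal_supports_op */ \
    }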
@@ ... @@
         Block_release(ctx->encode_async);

         [ctx->queue release];

         dispatch_release(ctx->d_queue);
@@ ... @@
     return nil;
 }

+static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
     for (size_t i = 0, n = 3; i < n; ++i) {
         if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
             return false;
         }
     }

+    const bool support_simdgroup_mm        = ctx_dev->support_simdgroup_mm;
+    const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ ... @@
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
+            return support_simdgroup_reduction;
         case GGML_OP_NORM:
         case GGML_OP_ROPE:
             return true;
@@ ... @@
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
+            return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
+            return support_simdgroup_reduction &&
                 (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
         case GGML_OP_CPY:
         case GGML_OP_DUP:
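This predicate is what the scheduler ultimately consults through the new device interface. A minimal usage sketch from client code (the wrapper name is illustrative; ggml_backend_dev_supports_op is the generic ggml-backend entry point):

// Minimal sketch: ask the Metal device whether it can run a given op.
#include "ggml-backend.h"

static bool can_run_on_metal(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    // forwards to ggml_backend_metal_device_supports_op() via the device vtable
    return ggml_backend_dev_supports_op(dev, op);
}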
@@ ... @@
 }

 static void ggml_metal_encode_node(
+        ggml_backend_t backend,
         int idx,
         id<MTLComputeCommandEncoder> encoder) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
     struct ggml_cgraph * gf = ctx->gf;

     struct ggml_tensor * node = ggml_graph_node(gf, idx);
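A recurring pattern in this diff: per-backend state hangs off backend->context, while state shared by every backend created from the same device hangs off backend->device->context. In sketch form (roles inferred from this diff, not a complete type listing):

// Ownership chain introduced by this refactor:
//   backend->context          // struct ggml_backend_metal_context: command queue, cgraph, capture state
//   backend->device           // the ggml_backend_dev_t the backend was created from
//   backend->device->context  // struct ggml_backend_metal_device_context: MTLDevice, support_* flags, name
struct ggml_backend_metal_context        * ctx     = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;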
@@ ... @@
                 } break;
     }

+    if (!ggml_metal_supports_op(ctx_dev, dst)) {
         GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
         GGML_ABORT("unsupported op");
     }
@@ ... @@
     //          dst->name);
     //}

+    id<MTLDevice> device = ctx_dev->mtl_device;
+
     switch (dst->op) {
         case GGML_OP_CONCAT:
             {
@@ ... @@
                 // the numbers below are measured on M2 Ultra for 7B and 13B models
                 // these numbers do not translate to other devices or model sizes
                 // TODO: need to find a better approach
+                if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
                     switch (src0t) {
                         case GGML_TYPE_F16:  ne11_mm_min = 2; break;
                         case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
@@ ... @@

                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+                if ([device supportsFamily:MTLGPUFamilyApple7] &&
                     !ggml_is_transposed(src0) &&
                     !ggml_is_transposed(src1) &&
                     src1t == GGML_TYPE_F32 &&
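Hardware gating now reads the cached MTLDevice handle instead of flags copied into the backend. The same check in isolation, as a hedged sketch (supportsFamily: is standard Metal API on macOS 10.15+/iOS 13+, and MTLGPUFamilyApple7 corresponds to A14/M1-class GPUs):

// Standalone version of the GPU-family gate used above.
#import <Metal/Metal.h>

static bool gpu_has_simdgroup_mm(id<MTLDevice> device) {
    // A14+/M1+ GPUs expose simdgroup matrix multiply, which the mul_mm kernels require
    return [device supportsFamily:MTLGPUFamilyApple7];
}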
@@ ... @@
                 // ne21 = n_rows
                 const int dst_rows     = ne20*ne21;
                 const int dst_rows_min = n_as;
+                const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;

                 // max size of the rowids array in the kernel shared buffer
                 GGML_ASSERT(dst_rows <= dst_rows_max);
@@ ... @@
                 // TODO: for now, always use mat-vec kernels until we figure out how to improve the
                 //       indirect matrix multiplication
                 // !!!
+                if ([device supportsFamily:MTLGPUFamilyApple7] &&
                     ne00 % 32 == 0 && ne00 >= 64 &&
                     dst_rows > dst_rows_min) {
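To make the dst_rows_max bound concrete (the 32 KiB figure below is an assumed, typical value for Apple GPUs, not a guarantee):

// Worked example for the bound above (values illustrative only):
//   maxThreadgroupMemoryLength = 32768 bytes (32 KiB)
//   reserved                   = 32 + 8192 bytes (alignment pad + mm scratch)
//   bytes per row id           = 4 (inferred from the /4)
// => dst_rows_max = (32768 - 32 - 8192)/4 = 6136 row ids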
@@ ... @@

             while (true) {
                 const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+                if (smem > device.maxThreadgroupMemoryLength) {
                     break;
                 }
                 nsgmax *= 2;
@@ ... @@

             const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);

+            //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+            GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

             [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
@@ ... @@

             const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);

+            //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+            GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
             [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

             [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
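For a sense of scale of the shared-memory formula (every parameter value below is assumed for illustration, not taken from the code):

// Worked example of the flash-attention shared-memory sizing:
//   ne00 = 128 (head size), nqptg = 8, ncpsg = 32, nsg = 4
//   smem = 8*(128 + 2*4*(32 + 8))*(sizeof(float)/2)
//        = 8*(128 + 320)*2
//        = 7168 bytes  -- already 16-byte aligned, so GGML_PAD(smem, 16) == smem
// comfortably under the 32 KiB threadgroup-memory limit typical of Apple GPUs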
@@ ... @@
 }

 static enum ggml_status ggml_metal_graph_compute(
+        ggml_backend_t backend,
+        struct ggml_cgraph * gf) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
     // number of nodes encoded by the main thread (empirically determined)
     const int n_main = 128;
@@ ... @@

     if (!ctx->capture_started) {
         // create capture scope
+        ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];

         MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
         descriptor.captureObject = ctx->capture_scope;
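A hedged usage sketch of the capture path from client code (gf stands in for a previously built compute graph; the capture hook is declared in ggml-metal.h and the compute call is the generic ggml-backend API):

// Arm a one-shot GPU capture, then run a graph; the committed command
// buffers land in a capture that Xcode's Metal debugger can open.
ggml_backend_t backend = ggml_backend_metal_init();
ggml_backend_metal_capture_next_compute(backend);   // arms the capture
ggml_backend_graph_compute(backend, gf);            // gf: previously built ggml_cgraph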
@@ ... @@

 // backend interface

 static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
     return "Metal";
@@ ... @@
     for (int i = 0; i < ctx->n_buffers; i++) {
         [ctx->buffers[i].metal release];
     }
+    ggml_backend_metal_device_rel(buffer->buft->device->context);

     if (ctx->owned) {
 #if TARGET_OS_OSX
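ggml_backend_metal_device_acq/_rel implement a small reference count around the shared MTLDevice. A sketch of the idea (field names assumed to mirror the device context in this diff; the real code likely adds assertions):

// Sketch of the acquire/release pattern (not the verbatim implementation):
static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
    if (ctx->mtl_device == nil) {
        ctx->mtl_device = MTLCreateSystemDefaultDevice();   // created on first use
    }
    ctx->mtl_device_ref_count++;
    return ctx->mtl_device;
}

static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
    ctx->mtl_device_ref_count--;
    if (ctx->mtl_device_ref_count == 0) {
        [ctx->mtl_device release];                          // last user frees the device
        ctx->mtl_device = nil;
    }
}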
@@ ... @@
         size_aligned += (size_page - (size_aligned % size_page));
     }

+    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);

     ctx->all_data = ggml_metal_host_malloc(size_aligned);
     ctx->all_size = size_aligned;
@@ ... @@

         if (size_aligned > 0) {
             ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
+                                            length:size_aligned
+                                            options:MTLResourceStorageModeShared
+                                            deallocator:nil];
         }
     }

     if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
         GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
+        ggml_backend_metal_device_rel(buft->device->context);
         return NULL;
     }
@@ ... @@
 }

 static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
+    const size_t max_size = device.maxBufferLength;
+    ggml_backend_metal_device_rel(buft->device->context);

     return max_size;
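Callers reach this through the generic buffer-type API rather than Metal-specific code. A quick usage sketch:

// Largest single allocation the Metal buffer type will hand out:
ggml_backend_buffer_type_t buft = ggml_backend_metal_buffer_type();
const size_t max_size = ggml_backend_buft_get_max_size(buft);   // forwards to the getter above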
@@ ... @@
         /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
         /* .is_host        = */ ggml_backend_metal_buffer_type_is_host,
     },
+    /* .device  = */ &g_ggml_backend_metal_device,
     /* .context = */ NULL,
 };

 return &ggml_backend_buffer_type_metal;
 }

+// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
 ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
     struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
@@ ... @@
         size_aligned += (size_page - (size_aligned % size_page));
     }

+    id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);

     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
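Per the deprecation note added to ggml-metal.h, new code should take the device route instead of this helper. A hedged migration sketch (ggml_backend_dev_buffer_from_host_ptr is the generic entry point paired with the new buffer_from_host_ptr capability flag):

// Old (now deprecated):
//   buf = ggml_backend_metal_buffer_from_ptr(data, size, max_tensor_size);
// New, through the device interface:
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, data, size, max_tensor_size);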
@@ ... @@
 }

 static void ggml_backend_metal_free(ggml_backend_t backend) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    ggml_backend_metal_device_rel(ctx_dev);
     ggml_metal_free(ctx);
+
     free(backend);
 }
@@ ... @@
 }

 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    return ggml_metal_graph_compute(backend, cgraph);
 }

 static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
@@ ... @@
             [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
         }

+        ggml_metal_encode_node(backend, idx, encoder);

         if (should_capture) {
             [encoder popDebugGroup];
@@ ... @@
     /* .graph_plan_update    = */ NULL,
     /* .graph_plan_compute   = */ NULL,
     /* .graph_compute        = */ ggml_backend_metal_graph_compute,
+    /* .supports_op          = */ NULL,
+    /* .supports_buft        = */ NULL,
     /* .offload_op           = */ NULL,
     /* .event_record         = */ NULL,
     /* .event_wait           = */ NULL,
@@ ... @@
     return &guid;
 }

+// TODO: remove in the future
 ggml_backend_t ggml_backend_metal_init(void) {
+    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
+
+    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
     if (ctx == NULL) {
         GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
         return NULL;
@@ ... @@
     *backend = (struct ggml_backend) {
         /* .guid      = */ ggml_backend_metal_guid(),
         /* .interface = */ ggml_backend_metal_i,
+        /* .device    = */ dev,
         /* .context   = */ ctx,
     };
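With the registry in place, the generic path can replace the Metal-specific constructor entirely. A usage sketch:

// Preferred initialization path after this change:
ggml_backend_reg_t reg = ggml_backend_metal_reg();
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0);   // Metal exposes exactly one device
ggml_backend_t backend = ggml_backend_dev_init(dev, /*params =*/ NULL);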
@@ ... @@
 bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
     GGML_ASSERT(ggml_backend_is_metal(backend));

+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;

+    return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }

 void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
@@ ... @@
     ctx->capture_next_compute = true;
 }

+// backend device
+
+static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
+    return "Metal";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
+    // acq/rel just to populate ctx->name in case it hasn't been done yet
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+    ggml_backend_metal_device_acq(ctx_dev);
+    ggml_backend_metal_device_rel(ctx_dev);
+
+    return ctx_dev->name;
+}
+
+static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    if (@available(macOS 10.12, iOS 16.0, *)) {
+        struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+        id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+        *total = device.recommendedMaxWorkingSetSize;
+        *free  = *total - device.currentAllocatedSize;
+
+        ggml_backend_metal_device_rel(ctx_dev);
+    } else {
+        *free  = 1;
+        *total = 1;
+    }
+}
+
+static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_metal_device_get_name(dev);
+    props->description = ggml_backend_metal_device_get_description(dev);
+    props->type        = ggml_backend_metal_device_get_type(dev);
+    ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = (struct ggml_backend_dev_caps) {
+        /* .async                = */ false,
+        /* .host_buffer          = */ false,
+        /* .buffer_from_host_ptr = */ true,
+        /* .events               = */ false,
+    };
+}
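Consumers read all of the above through the generic getter. A sketch (fragment; assumes <stdio.h> and a build with the Metal backend linked in):

// Inspect the Metal device through the generic device API:
struct ggml_backend_dev_props props;
ggml_backend_dev_get_props(ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0), &props);

printf("%s (%s): %zu free / %zu total bytes\n",
       props.name, props.description, props.memory_free, props.memory_total);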
+
+static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
+    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
+    if (ctx == NULL) {
+        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
+        return NULL;
+    }
+
+    ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
+
+    *backend = (struct ggml_backend) {
+        /* .guid      = */ ggml_backend_metal_guid(),
+        /* .interface = */ ggml_backend_metal_i,
+        /* .device    = */ dev,
+        /* .context   = */ ctx,
+    };
+
+    ggml_backend_metal_set_n_cb(backend, 1);
+
+    return backend;

     GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_metal_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
+
+    ctx->all_data  = ptr;
+    ctx->all_size  = size;
+    ctx->owned     = false;
+    ctx->n_buffers = 0;
+
+    const size_t size_page = sysconf(_SC_PAGESIZE);
+
+    // page-align the data ptr
+    {
+        const uintptr_t offs = (uintptr_t) ptr % size_page;
+        ptr  = (void *) ((char *) ptr - offs);
+        size += offs;
+    }
+
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }
+
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+    // the buffer fits into the max buffer size allowed by the device
+    if (size_aligned <= device.maxBufferLength) {
+        ctx->buffers[ctx->n_buffers].data  = ptr;
+        ctx->buffers[ctx->n_buffers].size  = size;
+        ctx->buffers[ctx->n_buffers].metal = nil;
+
+        if (size_aligned > 0) {
+            ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+            if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+                return false;
+            }
+        }
+
+        ggml_backend_metal_log_allocated_size(device, size_aligned);
+
+        ++ctx->n_buffers;
+    } else {
+        // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+        // one of the views
+        const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+        const size_t size_step = device.maxBufferLength - size_ovlp;
+        const size_t size_view = device.maxBufferLength;
+
+        for (size_t i = 0; i < size; i += size_step) {
+            const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+            ctx->buffers[ctx->n_buffers].data  = (void *) ((uint8_t *) ptr + i);
+            ctx->buffers[ctx->n_buffers].size  = size_step_aligned;
+            ctx->buffers[ctx->n_buffers].metal = nil;
+
+            if (size_step_aligned > 0) {
+                ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+                if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                    GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
+                    return false;
+                }
+            }
+
+            ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
+            if (i + size_step < size) {
+                GGML_LOG_INFO("\n");
+            }
+
+            ++ctx->n_buffers;
+        }
+    }
+
+    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
+}
+
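The overlapping-views arithmetic deserves unpacking; with small, purely hypothetical numbers:

// Hypothetical numbers to illustrate the view splitting above:
//   size_page       = 4096
//   maxBufferLength = 64 * 4096       (views of at most 64 pages)
//   max_tensor_size = 10 * 4096 - 1   (just under 10 pages)
//
//   size_ovlp = ((max_tensor_size + 4095)/4096 + 1) * 4096 = 11 * 4096
//   size_step = 64*4096 - 11*4096     = 53 * 4096
//   size_view = 64*4096
//
// Views start every 53 pages but span 64 pages, so consecutive views overlap
// by 11 pages -- more than any single tensor -- hence every tensor lies
// entirely inside at least one view.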
+static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
+
+    return ggml_metal_supports_op(ctx_dev, op);
+}
+
+static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
+
+    UNUSED(dev);
+}
+
+static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    return false;
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(op);
+}
+
+static struct ggml_backend_device_i ggml_backend_metal_device_i = {
+    /* .get_name             = */ ggml_backend_metal_device_get_name,
+    /* .get_description      = */ ggml_backend_metal_device_get_description,
+    /* .get_memory           = */ ggml_backend_metal_device_get_memory,
+    /* .get_type             = */ ggml_backend_metal_device_get_type,
+    /* .get_props            = */ ggml_backend_metal_device_get_props,
+    /* .init_backend         = */ ggml_backend_metal_device_init,
+    /* .get_buffer_type      = */ ggml_backend_metal_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_metal_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_metal_device_supports_buft,
+    /* .offload_op           = */ ggml_backend_metal_device_offload_op,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend registry
+
+static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
+    return "Metal";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    return &g_ggml_backend_metal_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
+    /* .get_name         = */ ggml_backend_metal_reg_get_name,
+    /* .device_count     = */ ggml_backend_metal_reg_device_count,
+    /* .device_get       = */ ggml_backend_metal_reg_device_get,
+    /* .get_proc_address = */ NULL,
+};
+
+ggml_backend_reg_t ggml_backend_metal_reg(void) {
+    // TODO: make this thread-safe somehow?
+    {
+        g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
+            /* .iface   = */ ggml_backend_metal_reg_i,
+            /* .context = */ NULL,
+        };
+
+        g_ggml_backend_metal_device = (struct ggml_backend_device) {
+            /* .iface   = */ ggml_backend_metal_device_i,
+            /* .reg     = */ &g_ggml_backend_metal_reg,
+            /* .context = */ &g_ggml_ctx_dev_main,
+        };
+    }
+
+    return &g_ggml_backend_metal_reg;
 }
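On the thread-safety TODO: one conventional remedy on Apple platforms is to guard the one-time initialization with dispatch_once. This is an illustrative sketch only, not what the commit does (the hypothetical _threadsafe name marks it as such):

// One possible way to address the TODO above (illustrative):
#include <dispatch/dispatch.h>

ggml_backend_reg_t ggml_backend_metal_reg_threadsafe(void) {   // hypothetical variant
    static dispatch_once_t once;
    dispatch_once(&once, ^{
        g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
            /* .iface   = */ ggml_backend_metal_reg_i,
            /* .context = */ NULL,
        };
        g_ggml_backend_metal_device = (struct ggml_backend_device) {
            /* .iface   = */ ggml_backend_metal_device_i,
            /* .reg     = */ &g_ggml_backend_metal_reg,
            /* .context = */ &g_ggml_ctx_dev_main,
        };
    });
    return &g_ggml_backend_metal_reg;
}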
|