Spaces:
Running
Running
metal : single allocation of encode_async block (llama/9747)
Browse files
* Single allocation of encode_async block with non-ARC capture in ggml-metal.m
* Moving Block_release to the deallocation code
* Release encode block when re-setting encoding buffer count if needed
* Update ggml/src/ggml-metal.m
---------
Co-authored-by: Georgi Gerganov <[email protected]>
- ggml/src/ggml-metal.m +47 -47
ggml/src/ggml-metal.m
CHANGED
|
@@ -239,8 +239,6 @@ struct ggml_backend_metal_context {
|
|
| 239 |
struct ggml_cgraph * gf;
|
| 240 |
|
| 241 |
// the callback given to the thread pool
|
| 242 |
-
// TODO: ideally, this should be created once, utilizing the command buffer state above
|
| 243 |
-
// for some reason, doing it like this leads to a crash
|
| 244 |
void (^encode_async)(size_t ith);
|
| 245 |
|
| 246 |
// n_cb command buffers + 1 used by the main thread
|
|
@@ -683,6 +681,8 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
|
|
| 683 |
[ctx->kernels[i].pipeline release];
|
| 684 |
}
|
| 685 |
|
|
|
|
|
|
|
| 686 |
[ctx->queue release];
|
| 687 |
[ctx->device release];
|
| 688 |
|
|
@@ -3000,46 +3000,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 3000 |
}
|
| 3001 |
}
|
| 3002 |
|
| 3003 |
-
// TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes.
|
| 3004 |
-
ctx->encode_async = ^(size_t iter) {
|
| 3005 |
-
const int cb_idx = iter;
|
| 3006 |
-
const int n_cb_l = ctx->n_cb;
|
| 3007 |
-
|
| 3008 |
-
const int n_nodes_0 = ctx->n_nodes_0;
|
| 3009 |
-
const int n_nodes_1 = ctx->n_nodes_1;
|
| 3010 |
-
|
| 3011 |
-
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
| 3012 |
-
|
| 3013 |
-
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
| 3014 |
-
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
|
| 3015 |
-
|
| 3016 |
-
int node_start = 0;
|
| 3017 |
-
int node_end = n_nodes_0;
|
| 3018 |
-
|
| 3019 |
-
if (cb_idx < n_cb_l) {
|
| 3020 |
-
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
|
| 3021 |
-
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
|
| 3022 |
-
}
|
| 3023 |
-
|
| 3024 |
-
for (int idx = node_start; idx < node_end; ++idx) {
|
| 3025 |
-
if (should_capture) {
|
| 3026 |
-
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
|
| 3027 |
-
}
|
| 3028 |
-
|
| 3029 |
-
ggml_metal_encode_node(ctx, idx, encoder);
|
| 3030 |
-
|
| 3031 |
-
if (should_capture) {
|
| 3032 |
-
[encoder popDebugGroup];
|
| 3033 |
-
}
|
| 3034 |
-
}
|
| 3035 |
-
|
| 3036 |
-
[encoder endEncoding];
|
| 3037 |
-
|
| 3038 |
-
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
| 3039 |
-
[command_buffer commit];
|
| 3040 |
-
}
|
| 3041 |
-
};
|
| 3042 |
-
|
| 3043 |
// the main thread commits the first few commands immediately
|
| 3044 |
// command_buffer[n_cb]
|
| 3045 |
{
|
|
@@ -3129,7 +3089,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 3129 |
|
| 3130 |
// default buffer
|
| 3131 |
static id<MTLDevice> g_backend_device = nil;
|
| 3132 |
-
static int g_backend_device_ref_count = 0;
|
| 3133 |
|
| 3134 |
static id<MTLDevice> ggml_backend_metal_get_device(void) {
|
| 3135 |
if (g_backend_device == nil) {
|
|
@@ -3468,10 +3428,50 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
| 3468 |
}
|
| 3469 |
}
|
| 3470 |
|
| 3471 |
-
|
| 3472 |
-
|
| 3473 |
-
|
| 3474 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3475 |
}
|
| 3476 |
|
| 3477 |
static struct ggml_backend_i ggml_backend_metal_i = {
|
|
|
|
| 239 |
struct ggml_cgraph * gf;
|
| 240 |
|
| 241 |
// the callback given to the thread pool
|
|
|
|
|
|
|
| 242 |
void (^encode_async)(size_t ith);
|
| 243 |
|
| 244 |
// n_cb command buffers + 1 used by the main thread
|
|
|
|
| 681 |
[ctx->kernels[i].pipeline release];
|
| 682 |
}
|
| 683 |
|
| 684 |
+
Block_release(ctx->encode_async);
|
| 685 |
+
|
| 686 |
[ctx->queue release];
|
| 687 |
[ctx->device release];
|
| 688 |
|
|
|
|
| 3000 |
}
|
| 3001 |
}
|
| 3002 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3003 |
// the main thread commits the first few commands immediately
|
| 3004 |
// command_buffer[n_cb]
|
| 3005 |
{
|
|
|
|
| 3089 |
|
| 3090 |
// default buffer
|
| 3091 |
static id<MTLDevice> g_backend_device = nil;
|
| 3092 |
+
static int g_backend_device_ref_count = 0;
|
| 3093 |
|
| 3094 |
static id<MTLDevice> ggml_backend_metal_get_device(void) {
|
| 3095 |
if (g_backend_device == nil) {
|
|
|
|
| 3428 |
}
|
| 3429 |
}
|
| 3430 |
|
| 3431 |
+
if (ctx->encode_async) {
|
| 3432 |
+
Block_release(ctx->encode_async);
|
| 3433 |
+
}
|
| 3434 |
+
|
| 3435 |
+
ctx->encode_async = Block_copy(^(size_t iter) {
|
| 3436 |
+
const int cb_idx = iter;
|
| 3437 |
+
const int n_cb_l = ctx->n_cb;
|
| 3438 |
+
|
| 3439 |
+
const int n_nodes_0 = ctx->n_nodes_0;
|
| 3440 |
+
const int n_nodes_1 = ctx->n_nodes_1;
|
| 3441 |
+
|
| 3442 |
+
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
| 3443 |
+
|
| 3444 |
+
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
| 3445 |
+
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
|
| 3446 |
+
|
| 3447 |
+
int node_start = 0;
|
| 3448 |
+
int node_end = n_nodes_0;
|
| 3449 |
+
|
| 3450 |
+
if (cb_idx < n_cb_l) {
|
| 3451 |
+
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
|
| 3452 |
+
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
|
| 3453 |
+
}
|
| 3454 |
+
|
| 3455 |
+
const bool should_capture = ctx->capture_next_compute;
|
| 3456 |
+
|
| 3457 |
+
for (int idx = node_start; idx < node_end; ++idx) {
|
| 3458 |
+
if (should_capture) {
|
| 3459 |
+
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
| 3460 |
+
}
|
| 3461 |
+
|
| 3462 |
+
ggml_metal_encode_node(ctx, idx, encoder);
|
| 3463 |
+
|
| 3464 |
+
if (should_capture) {
|
| 3465 |
+
[encoder popDebugGroup];
|
| 3466 |
+
}
|
| 3467 |
+
}
|
| 3468 |
+
|
| 3469 |
+
[encoder endEncoding];
|
| 3470 |
+
|
| 3471 |
+
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
| 3472 |
+
[command_buffer commit];
|
| 3473 |
+
}
|
| 3474 |
+
});
|
| 3475 |
}
|
| 3476 |
|
| 3477 |
static struct ggml_backend_i ggml_backend_metal_i = {
|