Spaces:
Running
Running
metal : add debug capture backend function (ggml/694)
Browse files
Co-authored-by: Georgi Gerganov <[email protected]>
- ggml-metal.h +3 -0
- ggml-metal.m +34 -6
ggml-metal.h
CHANGED
|
@@ -58,6 +58,9 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(voi
|
|
| 58 |
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
| 59 |
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
|
| 60 |
|
|
|
|
|
|
|
|
|
|
| 61 |
#ifdef __cplusplus
|
| 62 |
}
|
| 63 |
#endif
|
|
|
|
| 58 |
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
| 59 |
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
|
| 60 |
|
| 61 |
+
// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
|
| 62 |
+
GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
|
| 63 |
+
|
| 64 |
#ifdef __cplusplus
|
| 65 |
}
|
| 66 |
#endif
|
ggml-metal.m
CHANGED
|
@@ -167,6 +167,8 @@ struct ggml_metal_context {
|
|
| 167 |
|
| 168 |
bool support_simdgroup_reduction;
|
| 169 |
bool support_simdgroup_mm;
|
|
|
|
|
|
|
| 170 |
};
|
| 171 |
|
| 172 |
// MSL code
|
|
@@ -684,6 +686,20 @@ static bool ggml_metal_graph_compute(
|
|
| 684 |
const int n_cb = ctx->n_cb;
|
| 685 |
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
| 686 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
id<MTLCommandBuffer> command_buffer_builder[n_cb];
|
| 688 |
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
| 689 |
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
|
@@ -692,6 +708,7 @@ static bool ggml_metal_graph_compute(
|
|
| 692 |
// enqueue the command buffers in order to specify their execution order
|
| 693 |
[command_buffer enqueue];
|
| 694 |
}
|
|
|
|
| 695 |
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
|
| 696 |
|
| 697 |
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
|
|
@@ -738,9 +755,9 @@ static bool ggml_metal_graph_compute(
|
|
| 738 |
GGML_ASSERT(!"unsupported op");
|
| 739 |
}
|
| 740 |
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
|
| 745 |
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
| 746 |
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
|
@@ -2190,9 +2207,9 @@ static bool ggml_metal_graph_compute(
|
|
| 2190 |
}
|
| 2191 |
}
|
| 2192 |
|
| 2193 |
-
|
| 2194 |
-
|
| 2195 |
-
|
| 2196 |
}
|
| 2197 |
|
| 2198 |
[encoder endEncoding];
|
|
@@ -2214,6 +2231,10 @@ static bool ggml_metal_graph_compute(
|
|
| 2214 |
}
|
| 2215 |
}
|
| 2216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2217 |
return true;
|
| 2218 |
}
|
| 2219 |
|
|
@@ -2575,6 +2596,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
| 2575 |
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
| 2576 |
}
|
| 2577 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2578 |
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
| 2579 |
|
| 2580 |
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
|
|
|
|
| 167 |
|
| 168 |
bool support_simdgroup_reduction;
|
| 169 |
bool support_simdgroup_mm;
|
| 170 |
+
|
| 171 |
+
bool should_capture_next_compute;
|
| 172 |
};
|
| 173 |
|
| 174 |
// MSL code
|
|
|
|
| 686 |
const int n_cb = ctx->n_cb;
|
| 687 |
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
| 688 |
|
| 689 |
+
const bool should_capture = ctx->should_capture_next_compute;
|
| 690 |
+
if (should_capture) {
|
| 691 |
+
ctx->should_capture_next_compute = false;
|
| 692 |
+
|
| 693 |
+
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
| 694 |
+
descriptor.captureObject = ctx->queue;
|
| 695 |
+
|
| 696 |
+
NSError * error = nil;
|
| 697 |
+
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
|
| 698 |
+
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
|
| 699 |
+
GGML_ASSERT(!"capture failed");
|
| 700 |
+
}
|
| 701 |
+
}
|
| 702 |
+
|
| 703 |
id<MTLCommandBuffer> command_buffer_builder[n_cb];
|
| 704 |
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
| 705 |
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
|
|
|
| 708 |
// enqueue the command buffers in order to specify their execution order
|
| 709 |
[command_buffer enqueue];
|
| 710 |
}
|
| 711 |
+
|
| 712 |
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
|
| 713 |
|
| 714 |
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
|
|
|
|
| 755 |
GGML_ASSERT(!"unsupported op");
|
| 756 |
}
|
| 757 |
|
| 758 |
+
if (should_capture) {
|
| 759 |
+
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
|
| 760 |
+
}
|
| 761 |
|
| 762 |
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
| 763 |
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
|
|
|
| 2207 |
}
|
| 2208 |
}
|
| 2209 |
|
| 2210 |
+
if (should_capture) {
|
| 2211 |
+
[encoder popDebugGroup];
|
| 2212 |
+
}
|
| 2213 |
}
|
| 2214 |
|
| 2215 |
[encoder endEncoding];
|
|
|
|
| 2231 |
}
|
| 2232 |
}
|
| 2233 |
|
| 2234 |
+
if (should_capture) {
|
| 2235 |
+
[[MTLCaptureManager sharedCaptureManager] stopCapture];
|
| 2236 |
+
}
|
| 2237 |
+
|
| 2238 |
return true;
|
| 2239 |
}
|
| 2240 |
|
|
|
|
| 2596 |
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
| 2597 |
}
|
| 2598 |
|
| 2599 |
+
// Arms a one-shot Metal debug capture for this backend.
//
// Sets ctx->should_capture_next_compute to true. ggml_metal_graph_compute
// reads that flag, clears it, and then starts an MTLCaptureManager capture
// with ctx->queue as the capture object, so only the next compute call is
// captured.
//
// backend: must be a Metal backend; this is enforced with GGML_ASSERT before
//          the context pointer is cast to struct ggml_metal_context.
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
|
| 2600 |
+
GGML_ASSERT(ggml_backend_is_metal(backend));
|
| 2601 |
+
|
| 2602 |
+
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
| 2603 |
+
ctx->should_capture_next_compute = true;
|
| 2604 |
+
}
|
| 2605 |
+
|
| 2606 |
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
| 2607 |
|
| 2608 |
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
|