Spaces:
Running
Running
metal : fix fusion across different encoders (llama/14849)
Browse files
* metal : fix fusion across different encoders
ggml-ci
* cont : add assertion
ggml-ci
ggml/src/ggml-metal/ggml-metal.m
CHANGED
|
@@ -1955,6 +1955,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1955 |
static int ggml_metal_encode_node(
|
| 1956 |
ggml_backend_t backend,
|
| 1957 |
int idx,
|
|
|
|
| 1958 |
id<MTLComputeCommandEncoder> encoder,
|
| 1959 |
struct ggml_metal_mem_pool * mem_pool) {
|
| 1960 |
struct ggml_backend_metal_context * ctx = backend->context;
|
|
@@ -2181,7 +2182,9 @@ static int ggml_metal_encode_node(
|
|
| 2181 |
size_t offs_fuse;
|
| 2182 |
id<MTLBuffer> id_fuse;
|
| 2183 |
|
| 2184 |
-
for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
|
|
|
|
|
|
|
| 2185 |
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
|
| 2186 |
break;
|
| 2187 |
}
|
|
@@ -4288,7 +4291,7 @@ static int ggml_metal_encode_node(
|
|
| 4288 |
ops[1] = GGML_OP_MUL;
|
| 4289 |
ops[2] = GGML_OP_ADD;
|
| 4290 |
|
| 4291 |
-
for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
|
| 4292 |
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
|
| 4293 |
break;
|
| 4294 |
}
|
|
@@ -6271,7 +6274,11 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
| 6271 |
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
| 6272 |
}
|
| 6273 |
|
| 6274 |
-
const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6275 |
|
| 6276 |
if (should_capture) {
|
| 6277 |
[encoder popDebugGroup];
|
|
|
|
| 1955 |
static int ggml_metal_encode_node(
|
| 1956 |
ggml_backend_t backend,
|
| 1957 |
int idx,
|
| 1958 |
+
int idx_end,
|
| 1959 |
id<MTLComputeCommandEncoder> encoder,
|
| 1960 |
struct ggml_metal_mem_pool * mem_pool) {
|
| 1961 |
struct ggml_backend_metal_context * ctx = backend->context;
|
|
|
|
| 2182 |
size_t offs_fuse;
|
| 2183 |
id<MTLBuffer> id_fuse;
|
| 2184 |
|
| 2185 |
+
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
|
| 2186 |
+
// across splits. idx_end indicates the last node in the current split
|
| 2187 |
+
for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
| 2188 |
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
|
| 2189 |
break;
|
| 2190 |
}
|
|
|
|
| 4291 |
ops[1] = GGML_OP_MUL;
|
| 4292 |
ops[2] = GGML_OP_ADD;
|
| 4293 |
|
| 4294 |
+
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
| 4295 |
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
|
| 4296 |
break;
|
| 4297 |
}
|
|
|
|
| 6274 |
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
| 6275 |
}
|
| 6276 |
|
| 6277 |
+
const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
|
| 6278 |
+
if (idx + res > node_end) {
|
| 6279 |
+
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
|
| 6280 |
+
"https://github.com/ggml-org/llama.cpp/pull/14849");
|
| 6281 |
+
}
|
| 6282 |
|
| 6283 |
if (should_capture) {
|
| 6284 |
[encoder popDebugGroup];
|