ggerganov committed on
Commit
17d67da
·
1 Parent(s): bcbbf47

metal : fix fusion across different encoders (llama/14849)

Browse files

* metal : fix fusion across different encoders

ggml-ci

* cont : add assertion

ggml-ci

Files changed (1) hide show
  1. ggml/src/ggml-metal/ggml-metal.m +10 -3
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -1955,6 +1955,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1955
  static int ggml_metal_encode_node(
1956
  ggml_backend_t backend,
1957
  int idx,
 
1958
  id<MTLComputeCommandEncoder> encoder,
1959
  struct ggml_metal_mem_pool * mem_pool) {
1960
  struct ggml_backend_metal_context * ctx = backend->context;
@@ -2181,7 +2182,9 @@ static int ggml_metal_encode_node(
2181
  size_t offs_fuse;
2182
  id<MTLBuffer> id_fuse;
2183
 
2184
- for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
 
 
2185
  if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
2186
  break;
2187
  }
@@ -4288,7 +4291,7 @@ static int ggml_metal_encode_node(
4288
  ops[1] = GGML_OP_MUL;
4289
  ops[2] = GGML_OP_ADD;
4290
 
4291
- for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
4292
  if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
4293
  break;
4294
  }
@@ -6271,7 +6274,11 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
6271
  [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
6272
  }
6273
 
6274
- const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
 
 
 
6275
 
6276
  if (should_capture) {
6277
  [encoder popDebugGroup];
 
1955
  static int ggml_metal_encode_node(
1956
  ggml_backend_t backend,
1957
  int idx,
1958
+ int idx_end,
1959
  id<MTLComputeCommandEncoder> encoder,
1960
  struct ggml_metal_mem_pool * mem_pool) {
1961
  struct ggml_backend_metal_context * ctx = backend->context;
 
2182
  size_t offs_fuse;
2183
  id<MTLBuffer> id_fuse;
2184
 
2185
+ // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
2186
+ // across splits. idx_end is the exclusive end of the current split (one past its last node)
2187
+ for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
2188
  if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
2189
  break;
2190
  }
 
4291
  ops[1] = GGML_OP_MUL;
4292
  ops[2] = GGML_OP_ADD;
4293
 
4294
+ for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
4295
  if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
4296
  break;
4297
  }
 
6274
  [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
6275
  }
6276
 
6277
+ const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
6278
+ if (idx + res > node_end) {
6279
+ GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
6280
+ "https://github.com/ggml-org/llama.cpp/pull/14849");
6281
+ }
6282
 
6283
  if (should_capture) {
6284
  [encoder popDebugGroup];