Commit 66ae493 by ggerganov · Parent: bb523fb

metal : fuse add, mul + add tests (llama/14596)

ggml/src/ggml-alloc.c CHANGED
@@ -22,21 +22,6 @@ static bool ggml_is_view(const struct ggml_tensor * t) {
     return t->view_src != NULL;
 }
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 // ops that return true for this function must not use restrict pointers for their backend implementations
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
ggml/src/ggml-backend.cpp CHANGED
@@ -352,21 +352,6 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
 
 // backend copy
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
 
ggml/src/ggml-impl.h CHANGED
@@ -73,6 +73,22 @@ static inline int ggml_up(int n, int m) {
     return (n + m - 1) & ~(m - 1);
 }
 
+// TODO: move to ggml.h?
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 //
 // logging
 //
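
Note: this hunk deduplicates the layout-equality helper. The identical static copies in ggml-alloc.c and ggml-backend.cpp above are removed, and the single definition in ggml-impl.h is now also reachable from the Metal backend, which uses it to verify that fused src1 operands share a layout. As a quick illustration (a hypothetical standalone C snippet, not part of this commit), two tensors with equal shapes but different strides fail the check:

#include "ggml.h"
#include "ggml-impl.h" // internal header that now provides ggml_are_same_layout

static void layout_check_demo(struct ggml_context * ctx) {
    // two freshly allocated 4x8 tensors: same type, same ne[], same nb[]
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8);
    GGML_ASSERT(ggml_are_same_layout(a, b));

    // a transposed view also has ne = {4, 8}, but its strides (nb) are
    // swapped, so it does not share a layout with a
    struct ggml_tensor * t = ggml_transpose(ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4));
    GGML_ASSERT(!ggml_are_same_layout(a, t));
}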
ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -126,6 +126,7 @@ typedef struct {
     uint64_t nb2;
     uint64_t nb3;
     uint64_t offs;
+    uint64_t o1[8];
 } ggml_metal_kargs_bin;
 
 typedef struct {
@@ -240,7 +241,7 @@ typedef struct {
     float    max_bias;
     float    m0;
     float    m1;
-    uint16_t n_head_log2;
+    int32_t  n_head_log2;
     float    logit_softcap;
 } ggml_metal_kargs_flash_attn_ext;
 
@@ -377,8 +378,16 @@ typedef struct {
 typedef struct {
     int32_t  ne00;
     int32_t  ne00_4;
-    uint64_t nb01;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    eps;
+    int32_t  nef1[3];
+    int32_t  nef2[3];
+    int32_t  nef3[3];
+    uint64_t nbf1[3];
+    uint64_t nbf2[3];
+    uint64_t nbf3[3];
 } ggml_metal_kargs_rms_norm;
 
 typedef struct {
@@ -484,7 +493,7 @@ typedef struct {
     float    max_bias;
     float    m0;
     float    m1;
-    uint32_t n_head_log2;
+    int32_t  n_head_log2;
 } ggml_metal_kargs_soft_max;
 
 typedef struct {
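
Note: the widened ggml_metal_kargs_bin is what makes the fused ADD kernels possible. Instead of a single src1 offset baked into the buffer binding, the kernel receives up to 8 byte offsets in o1, one per fused operand, all relative to the same src1 buffer. A CPU reference of the per-element math (a sketch assuming contiguous f32 rows, not backend code):

#include <stdint.h>

// mirrors kernel_add_fuse_impl<F>: dst[i] = src0[i] + sum_j src1_j[i],
// where the j-th fused operand lives at byte offset o1[j] inside the
// shared src1 buffer
static void add_fuse_ref(const float * src0, const char * src1,
                         const uint64_t o1[], int n_fuse,
                         float * dst, int n) {
    for (int i = 0; i < n; ++i) {
        float res = src0[i];
        for (int j = 0; j < n_fuse; ++j) {
            res += ((const float *) (src1 + o1[j]))[i];
        }
        dst[i] = res;
    }
}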
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -55,6 +55,12 @@ static struct ggml_backend_metal_device_context {
     bool has_residency_sets;
     bool has_bfloat;
     bool use_bfloat;
+    bool use_fusion;
+
+    int debug_fusion;
+
+    // how many times a given op was fused
+    uint64_t fuse_cnt[GGML_OP_COUNT];
 
     size_t max_size;
 
@@ -69,6 +75,9 @@ static struct ggml_backend_metal_device_context {
     /*.has_residency_sets =*/ false,
     /*.has_bfloat         =*/ false,
     /*.use_bfloat         =*/ false,
+    /*.use_fusion         =*/ true,
+    /*.debug_fusion       =*/ 0,
+    /*.fuse_cnt           =*/ { 0 },
     /*.max_size           =*/ 0,
     /*.name               =*/ "",
 };
@@ -83,16 +92,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
-    }
 
-    if (ctx->mtl_device) {
         ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
         ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
        ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
 
 #if defined(GGML_METAL_HAS_RESIDENCY_SETS)
-        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
+        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
 #endif
 
         ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -103,6 +110,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 #else
         ctx->use_bfloat = false;
 #endif
+        ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+
+        {
+            const char * val = getenv("GGML_METAL_FUSION_DEBUG");
+            ctx->debug_fusion = val ? atoi(val) : 0;
+        }
+
+        memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
 
         ctx->max_size = ctx->mtl_device.maxBufferLength;
 
@@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
+        if (ctx->debug_fusion > 0) {
+            fprintf(stderr, "%s: fusion stats:\n", __func__);
+            for (int i = 0; i < GGML_OP_COUNT; i++) {
+                if (ctx->fuse_cnt[i] == 0) {
+                    continue;
+                }
+
+                // note: cannot use ggml_log here
+                fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
+            }
+        }
+
         if (ctx->mtl_lock) {
             [ctx->mtl_lock release];
             ctx->mtl_lock = nil;
@@ -147,13 +174,27 @@ struct ggml_metal_kernel {
 
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
-    GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
     GGML_METAL_KERNEL_TYPE_SUB,
-    GGML_METAL_KERNEL_TYPE_SUB_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
     GGML_METAL_KERNEL_TYPE_MUL,
-    GGML_METAL_KERNEL_TYPE_MUL_ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
     GGML_METAL_KERNEL_TYPE_DIV,
-    GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
     GGML_METAL_KERNEL_TYPE_REPEAT_F32,
     GGML_METAL_KERNEL_TYPE_REPEAT_F16,
     GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -218,6 +259,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
     GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
     GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
@@ -1135,13 +1178,27 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
@@ -1206,6 +1263,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
@@ -1893,7 +1952,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static bool ggml_metal_encode_node(
+static int ggml_metal_encode_node(
         ggml_backend_t backend,
         int idx,
         id<MTLComputeCommandEncoder> encoder,
@@ -1903,7 +1962,10 @@ static bool ggml_metal_encode_node(
 
     struct ggml_cgraph * gf = ctx->gf;
 
-    struct ggml_tensor * node = ggml_graph_node(gf, idx);
+    enum ggml_op ops[8];
+
+    struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
+    struct ggml_tensor *  node  = nodes[0];
 
     //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
 
@@ -1913,7 +1975,7 @@ static bool ggml_metal_encode_node(
     struct ggml_tensor * dst = node;
 
     if (ggml_is_empty(dst)) {
-        return true;
+        return 1;
     }
 
     switch (dst->op) {
@@ -1924,7 +1986,7 @@ static bool ggml_metal_encode_node(
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return true;
+            } return 1;
         default:
             {
             } break;
@@ -1991,6 +2053,8 @@ static bool ggml_metal_encode_node(
     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
    id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
+    int n_fuse = 1;
+
 #if 0
     GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
     if (src0) {
@@ -2062,37 +2126,15 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(src0t == GGML_TYPE_F32);
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
 
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
+                GGML_ASSERT(ggml_is_contiguous_rows(src1));
+
                 const size_t offs = 0;
 
                 bool bcast_row = false;
 
                 id<MTLComputePipelineState> pipeline = nil;
 
-                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                    GGML_ASSERT(ggml_is_contiguous(src0));
-
-                    // src1 is a row
-                    GGML_ASSERT(ne11 == 1);
-
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-
-                    bcast_row = true;
-                } else {
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-                }
-
                 ggml_metal_kargs_bin args = {
                     /*.ne00 =*/ ne00,
                     /*.ne01 =*/ ne01,
@@ -2119,12 +2161,117 @@ static bool ggml_metal_encode_node(
                     /*.nb2  =*/ nb2,
                     /*.nb3  =*/ nb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1 },
                 };
 
+                // c[0] = add(a,    b[0])
+                // c[1] = add(c[0], b[1])
+                // c[2] = add(c[1], b[2])
+                // ...
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_ADD;
+                    ops[1] = GGML_OP_ADD;
+                    ops[2] = GGML_OP_ADD;
+                    ops[3] = GGML_OP_ADD;
+                    ops[4] = GGML_OP_ADD;
+                    ops[5] = GGML_OP_ADD;
+                    ops[6] = GGML_OP_ADD;
+                    ops[7] = GGML_OP_ADD;
+
+                    size_t offs_fuse;
+                    id<MTLBuffer> id_fuse;
+
+                    for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        // b[0] === b[1] === ...
+                        if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        // only fuse nodes if src1 is in the same Metal buffer
+                        id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
+                        if (id_fuse != id_src1) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        args.o1[n_fuse + 1] = offs_fuse;
+                    }
+
+                    ++n_fuse;
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
+                    }
+                }
+
+                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    // src1 is a row
+                    GGML_ASSERT(ne11 == 1);
+
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+
+                    bcast_row = true;
+                } else {
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src1 offset:0 atIndex:2];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
                 if (bcast_row) {
@@ -2132,7 +2279,11 @@ static bool ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } else {
-                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+                    int nth = 32;
+
+                    while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                        nth *= 2;
+                    }
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 }
@@ -2257,12 +2408,13 @@ static bool ggml_metal_encode_node(
                     /*.nb2  =*/ pnb2,
                     /*.nb3  =*/ pnb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1 },
                 };
 
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src1 offset:0 atIndex:2];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
@@ -2764,7 +2916,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
                 if (!h_src0) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
-                    return false;
+                    return 0;
                 }
 
                 offs_src0 = 0;
@@ -3640,7 +3792,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
                 if (!h_src1) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
-                    return false;
+                    return 0;
                 }
 
                 const int64_t neh0 = ne0;
@@ -3656,7 +3808,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
                 if (!h_dst) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
-                    return false;
+                    return 0;
                 }
 
                 // tokens per expert
@@ -3664,7 +3816,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
                 if (!h_tpe) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
-                    return false;
+                    return 0;
                 }
 
                 // id map
@@ -3673,7 +3825,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
                 if (!h_ids) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
-                    return false;
+                    return 0;
                 }
 
                 {
@@ -4105,12 +4257,95 @@ static bool ggml_metal_encode_node(
         case GGML_OP_RMS_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(ggml_is_contiguous_1(src0));
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
 
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+                ggml_metal_kargs_rms_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb1    =*/ nb1,
+                    /*.nb2    =*/ nb2,
+                    /*.nb3    =*/ nb3,
+                    /*.eps    =*/ eps,
+                    /*.nef1   =*/ { ne01 },
+                    /*.nef2   =*/ { ne02 },
+                    /*.nef3   =*/ { ne03 },
+                    /*.nbf1   =*/ { nb01 },
+                    /*.nbf2   =*/ { nb02 },
+                    /*.nbf3   =*/ { nb03 },
+                };
+
+                size_t offs_fuse[2] = { 0, 0 };
+                id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
+
+                // d[0] = rms_norm(a)
+                // d[1] = mul(d[0], b)
+                // d[2] = add(d[1], c)
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_RMS_NORM;
+                    ops[1] = GGML_OP_MUL;
+                    ops[2] = GGML_OP_ADD;
+
+                    for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
+                            break;
+                        }
+
+                        if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
+
+                        args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
+                        args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
+                        args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
+
+                        args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
+                        args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
+                        args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
+                    }
+
+                    ++n_fuse;
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        if (n_fuse == 2) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
+                        }
+                        if (n_fuse == 3) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
+                        }
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
+                id<MTLComputePipelineState> pipeline;
+
+                switch (n_fuse) {
+                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM        ].pipeline; break;
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL    ].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
+                    default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
+                }
 
                 int nth = 32; // SIMD width
 
@@ -4121,23 +4356,16 @@ static bool ggml_metal_encode_node(
                 nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00/4);
 
-                ggml_metal_kargs_rms_norm args = {
-                    /*.ne00   =*/ ne00,
-                    /*.ne00_4 =*/ ne00/4,
-                    /*.nb01   =*/ nb01,
-                    /*.eps    =*/ eps,
-                };
-
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBytes:&args length:sizeof(args) atIndex:0];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0    offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
+                [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
+                [encoder setBuffer:id_dst     offset:offs_dst     atIndex:4];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                const int64_t nrows = ggml_nrows(src0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_L2_NORM:
             {
@@ -5532,7 +5760,7 @@ static bool ggml_metal_encode_node(
         }
     }
 
-    return true;
+    return n_fuse;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
@@ -6038,20 +6266,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
         ggml_metal_mem_pool_reset(mem_pool);
 
-        for (int idx = node_start; idx < node_end; ++idx) {
+        for (int idx = node_start; idx < node_end;) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+            const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
             if (should_capture) {
                 [encoder popDebugGroup];
             }
 
-            if (!res) {
+            if (res == 0) {
                 break;
             }
+
+            idx += res;
         }
 
         [encoder endEncoding];
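
Note: ggml_metal_encode_node now returns the number of graph nodes it consumed (0 on failure), and the dispatch loop advances by that count, so a fused chain is skipped in one step. Fusion can be disabled via the GGML_METAL_FUSION_DISABLE environment variable, and per-op fusion counters are printed on device release when GGML_METAL_FUSION_DEBUG is set. The scan itself is a simple look-ahead over consecutive nodes; a simplified C sketch of the ADD case (ggml_can_fuse and ggml_are_same_layout are the real helpers used above, the driver around them is illustrative and omits the same-Metal-buffer check):

// counts how many consecutive ADD nodes starting at idx can be folded into
// one kernel launch; the real code additionally requires all fused src1
// operands to live in the same Metal buffer
static int count_fused_adds(struct ggml_cgraph * gf, int idx) {
    struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;

    enum ggml_op ops[2] = { GGML_OP_ADD, GGML_OP_ADD };

    int n_fuse = 0;
    while (n_fuse <= 6) {
        if (!ggml_can_fuse(gf, idx + n_fuse, ops, 2)) {
            break; // next node is not a fusable ADD
        }
        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
            break; // chain is broken: c[k+1] must consume c[k]
        }
        if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
            break; // all b[k] operands must share one layout
        }
        ++n_fuse;
    }

    return n_fuse + 1; // graph nodes covered by a single dispatch
}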
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -832,7 +832,8 @@ enum ggml_sort_order {
 // general-purpose kernel for addition, subtraction, multiplication and division of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
-kernel void kernel_add(
+template <int F>
+kernel void kernel_add_fuse_impl(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
@@ -848,16 +849,39 @@ kernel void kernel_add(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+    device       float * dst_ptr  = (device       float *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
+
+    device const float * src1_ptr[F];
+    for (short j = 0; j < F; ++j) {
+        src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+    }
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
+
+        float res = src0_ptr[i0];
+
+#pragma unroll
+        for (short j = 0; j < F; ++j) {
+            res += src1_ptr[j][i10];
+        }
+
+        dst_ptr[i0] = res;
     }
 }
 
+typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
+
+template [[host_name("kernel_add")]]        kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
+template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
+template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
+template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
+template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
+template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
+template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
+template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
+
 kernel void kernel_sub(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
@@ -875,7 +899,7 @@ kernel void kernel_sub(
     const int i11 = i01%args.ne11;
 
     device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
     device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
@@ -900,9 +924,9 @@ kernel void kernel_mul(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
@@ -926,9 +950,9 @@ kernel void kernel_div(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
@@ -970,46 +994,145 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat
 
 // assumption: src1 is a row
 // broadcast src1 into src0
-kernel void kernel_add_row(
+template <short F>
+kernel void kernel_add_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] + src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res += src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_sub_row(
+typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
+
+template [[host_name("kernel_add_row_c4")]]        kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
+template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
+template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
+template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
+template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
+template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
+template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
+template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
+
+template <short F>
+kernel void kernel_sub_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] - src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res -= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_mul_row(
+typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
+
+template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_mul_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
        uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] * src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res *= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_div_row(
+typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
+
+template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_div_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] / src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res /= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
+typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
+
+template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
+
 kernel void kernel_scale(
         device const float * src0,
         device       float * dst,
@@ -2116,26 +2239,39 @@ kernel void kernel_norm(
     }
 }
 
-kernel void kernel_rms_norm(
+// F == 1 : rms_norm (no fuse)
+// F == 2 : rms_norm + mul
+// F == 3 : rms_norm + mul + add
+template <short F>
+kernel void kernel_rms_norm_fuse_impl(
         constant ggml_metal_kargs_rms_norm & args,
         device const char * src0,
+        device const char * src1_0,
+        device const char * src1_1,
         device       char * dst,
         threadgroup float * shmem_f32 [[threadgroup(0)]],
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
     if (sgitg == 0) {
         shmem_f32[tiisg] = 0.0f;
     }
 
-    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+    const int i01 = tgpig.x;
+    const int i02 = tgpig.y;
+    const int i03 = tgpig.z;
+
+    device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+
+    device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+    device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
 
     float sumf = 0.0f;
 
     // parallel sum
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
         sumf += dot(x[i00], x[i00]);
     }
     sumf = simd_sum(sumf);
@@ -2154,12 +2290,26 @@ kernel void kernel_rms_norm(
     const float mean  = sumf/args.ne00;
     const float scale = 1.0f/sqrt(mean + args.eps);
 
-    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
-        y[i00] = x[i00] * scale;
+    device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+        if (F == 1) {
+            y[i00] = (x[i00]*scale);
+        }
+        if (F == 2) {
+            y[i00] = (x[i00]*scale)*f0[i00];
+        }
+        if (F == 3) {
+            y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
+        }
    }
 }
 
+typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
+
+template [[host_name("kernel_rms_norm")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
+template [[host_name("kernel_rms_norm_mul")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
+template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
+
 kernel void kernel_l2_norm(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,
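
Note: on the shader side every fused variant is a single template parameterized by the fuse count F and instantiated under the old entry-point names via [[host_name(...)]], so kernel_add is simply kernel_add_fuse_impl<1>. The rms_norm template folds the optional MUL and ADD into the normalization write-out. A CPU reference of that per-row math (a sketch, assuming one contiguous f32 row):

#include <math.h>

// mirrors kernel_rms_norm_fuse_impl<F> for a single row:
// F == 1: y = rms_norm(x)
// F == 2: y = rms_norm(x)*b
// F == 3: y = rms_norm(x)*b + c
static void rms_norm_fuse_ref(const float * x, const float * b, const float * c,
                              float * y, int ne00, float eps, int F) {
    float sumf = 0.0f;
    for (int i = 0; i < ne00; ++i) {
        sumf += x[i]*x[i];
    }

    const float scale = 1.0f/sqrtf(sumf/ne00 + eps);

    for (int i = 0; i < ne00; ++i) {
        float res = x[i]*scale;
        if (F >= 2) { res *= b[i]; }
        if (F >= 3) { res += c[i]; }
        y[i] = res;
    }
}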