metal : fuse add, mul + add tests (llama/14596)
Changed files:
- ggml/src/ggml-alloc.c                   +0 -15
- ggml/src/ggml-backend.cpp               +0 -15
- ggml/src/ggml-impl.h                    +16 -0
- ggml/src/ggml-metal/ggml-metal-impl.h   +12 -3
- ggml/src/ggml-metal/ggml-metal.m        +297 -67
- ggml/src/ggml-metal/ggml-metal.metal    +193 -43
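The change has two parts: ggml_are_same_layout moves out of ggml-alloc.c / ggml-backend.cpp into ggml-impl.h so backends can reuse it, and the Metal backend gains fused kernels — chains of up to 8 consecutive ADD ops, and RMS_NORM + MUL (+ ADD) — selected at encode time and controlled by the GGML_METAL_FUSION_DISABLE / GGML_METAL_FUSION_DEBUG environment variables. Below is a simplified, self-contained C++ sketch of the chain-walk idea behind the ADD fusion; the real check in ggml-metal.m additionally calls ggml_can_fuse, compares operand layouts and Metal buffers. The toy_* names are illustrative stand-ins, not ggml API.

    #include <stdio.h>

    // Hypothetical stand-ins for ggml's graph nodes, just enough to show the walk.
    enum toy_op { TOY_NONE, TOY_ADD, TOY_MUL };

    struct toy_node {
        enum toy_op      op;
        struct toy_node *src0;   // running result of the previous op in the chain
        struct toy_node *src1;   // the "other" operand
    };

    // Count how many consecutive ADD nodes starting at nodes[idx] could be folded
    // into one dispatch: each next node must consume the previous node's result
    // as src0, mirroring the c[i] = add(c[i-1], b[i]) pattern in the commit.
    static int toy_count_fusable_adds(struct toy_node **nodes, int n_nodes, int idx, int max_fuse) {
        int n_fuse = 1;
        while (n_fuse < max_fuse && idx + n_fuse < n_nodes) {
            struct toy_node *cur  = nodes[idx + n_fuse - 1];
            struct toy_node *next = nodes[idx + n_fuse];
            if (next->op != TOY_ADD || next->src0 != cur) {
                break;
            }
            n_fuse++;
        }
        return n_fuse;
    }

    int main() {
        // a -> add(b0) -> add(b1) -> add(b2): three ADDs chained through src0
        struct toy_node a  = { TOY_NONE, 0, 0 };
        struct toy_node b0 = { TOY_NONE, 0, 0 };
        struct toy_node b1 = { TOY_NONE, 0, 0 };
        struct toy_node b2 = { TOY_NONE, 0, 0 };
        struct toy_node n0 = { TOY_ADD, &a,  &b0 };
        struct toy_node n1 = { TOY_ADD, &n0, &b1 };
        struct toy_node n2 = { TOY_ADD, &n1, &b2 };
        struct toy_node *nodes[] = { &n0, &n1, &n2 };

        printf("fusable ADDs starting at node 0: %d\n",
               toy_count_fusable_adds(nodes, 3, 0, 8)); // prints 3
        return 0;
    }

In the actual encoder the resulting count selects one of the kernel_add_fuse_N pipelines and is returned to the graph loop so the fused nodes are skipped.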
ggml/src/ggml-alloc.c
CHANGED

@@ -22,21 +22,6 @@ static bool ggml_is_view(const struct ggml_tensor * t) {
     return t->view_src != NULL;
 }
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 // ops that return true for this function must not use restrict pointers for their backend implementations
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
ggml/src/ggml-backend.cpp
CHANGED

@@ -352,21 +352,6 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
 
 // backend copy
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
 
ggml/src/ggml-impl.h
CHANGED

@@ -73,6 +73,22 @@ static inline int ggml_up(int n, int m) {
     return (n + m - 1) & ~(m - 1);
 }
 
+// TODO: move to ggml.h?
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 //
 // logging
 //
ggml/src/ggml-metal/ggml-metal-impl.h
CHANGED

@@ -126,6 +126,7 @@ typedef struct {
     uint64_t nb2;
     uint64_t nb3;
     uint64_t offs;
+    uint64_t o1[8];
 } ggml_metal_kargs_bin;
 
 typedef struct {
@@ -240,7 +241,7 @@ typedef struct {
     float max_bias;
     float m0;
     float m1;
-
+    int32_t n_head_log2;
     float logit_softcap;
 } ggml_metal_kargs_flash_attn_ext;
 
@@ -377,8 +378,16 @@ typedef struct {
 typedef struct {
     int32_t  ne00;
     int32_t  ne00_4;
-    uint64_t nb01;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    eps;
+    int32_t  nef1[3];
+    int32_t  nef2[3];
+    int32_t  nef3[3];
+    uint64_t nbf1[3];
+    uint64_t nbf2[3];
+    uint64_t nbf3[3];
 } ggml_metal_kargs_rms_norm;
 
 typedef struct {
@@ -484,7 +493,7 @@ typedef struct {
     float max_bias;
    float m0;
     float m1;
-
+    int32_t n_head_log2;
 } ggml_metal_kargs_soft_max;
 
 typedef struct {
ggml/src/ggml-metal/ggml-metal.m
CHANGED

@@ -55,6 +55,12 @@ static struct ggml_backend_metal_device_context {
     bool has_residency_sets;
     bool has_bfloat;
     bool use_bfloat;
+    bool use_fusion;
+
+    int debug_fusion;
+
+    // how many times a given op was fused
+    uint64_t fuse_cnt[GGML_OP_COUNT];
 
     size_t max_size;
 
@@ -69,6 +75,9 @@ static struct ggml_backend_metal_device_context {
     /*.has_residency_sets =*/ false,
     /*.has_bfloat         =*/ false,
     /*.use_bfloat         =*/ false,
+    /*.use_fusion         =*/ true,
+    /*.debug_fusion       =*/ 0,
+    /*.fuse_cnt           =*/ { 0 },
     /*.max_size           =*/ 0,
     /*.name               =*/ "",
 };
@@ -83,16 +92,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
-    }
 
-    if (ctx->mtl_device) {
         ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
         ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
         ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
 
 #if defined(GGML_METAL_HAS_RESIDENCY_SETS)
-        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") ==
+        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
 #endif
 
         ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -103,6 +110,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 #else
         ctx->use_bfloat = false;
 #endif
+        ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+
+        {
+            const char * val = getenv("GGML_METAL_FUSION_DEBUG");
+            ctx->debug_fusion = val ? atoi(val) : 0;
+        }
+
+        memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
 
         ctx->max_size = ctx->mtl_device.maxBufferLength;
 
@@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
+        if (ctx->debug_fusion > 0) {
+            fprintf(stderr, "%s: fusion stats:\n", __func__);
+            for (int i = 0; i < GGML_OP_COUNT; i++) {
+                if (ctx->fuse_cnt[i] == 0) {
+                    continue;
+                }
+
+                // note: cannot use ggml_log here
+                fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
+            }
+        }
+
         if (ctx->mtl_lock) {
             [ctx->mtl_lock release];
             ctx->mtl_lock = nil;
@@ -147,13 +174,27 @@ struct ggml_metal_kernel {
 
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
-    GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
     GGML_METAL_KERNEL_TYPE_SUB,
-    GGML_METAL_KERNEL_TYPE_SUB_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
     GGML_METAL_KERNEL_TYPE_MUL,
-    GGML_METAL_KERNEL_TYPE_MUL_ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
     GGML_METAL_KERNEL_TYPE_DIV,
-    GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
     GGML_METAL_KERNEL_TYPE_REPEAT_F32,
     GGML_METAL_KERNEL_TYPE_REPEAT_F16,
     GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -218,6 +259,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
     GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
     GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
@@ -1135,13 +1178,27 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
@@ -1206,6 +1263,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
@@ -1893,7 +1952,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static bool ggml_metal_encode_node(
+static int ggml_metal_encode_node(
                         ggml_backend_t backend,
                                    int idx,
          id<MTLComputeCommandEncoder> encoder,
@@ -1903,7 +1962,10 @@ static bool ggml_metal_encode_node(
 
     struct ggml_cgraph * gf = ctx->gf;
 
-
+    enum ggml_op ops[8];
+
+    struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
+    struct ggml_tensor *  node  = nodes[0];
 
     //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
 
@@ -1913,7 +1975,7 @@ static bool ggml_metal_encode_node(
     struct ggml_tensor * dst = node;
 
     if (ggml_is_empty(dst)) {
-        return true;
+        return 1;
     }
 
     switch (dst->op) {
@@ -1924,7 +1986,7 @@ static bool ggml_metal_encode_node(
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return true;
+            } return 1;
         default:
             {
             } break;
@@ -1991,6 +2053,8 @@ static bool ggml_metal_encode_node(
     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
+    int n_fuse = 1;
+
 #if 0
     GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
     if (src0) {
@@ -2062,37 +2126,15 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(src0t == GGML_TYPE_F32);
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
 
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
+                GGML_ASSERT(ggml_is_contiguous_rows(src1));
+
                 const size_t offs = 0;
 
                 bool bcast_row = false;
 
                 id<MTLComputePipelineState> pipeline = nil;
 
-                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                    GGML_ASSERT(ggml_is_contiguous(src0));
-
-                    // src1 is a row
-                    GGML_ASSERT(ne11 == 1);
-
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-
-                    bcast_row = true;
-                } else {
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-                }
-
                 ggml_metal_kargs_bin args = {
                     /*.ne00 =*/ ne00,
                     /*.ne01 =*/ ne01,
@@ -2119,12 +2161,117 @@ static bool ggml_metal_encode_node(
                     /*.nb2  =*/ nb2,
                    /*.nb3  =*/ nb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1 },
                 };
 
+                // c[0] = add(a,    b[0])
+                // c[1] = add(c[0], b[1])
+                // c[2] = add(c[1], b[2])
+                // ...
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_ADD;
+                    ops[1] = GGML_OP_ADD;
+                    ops[2] = GGML_OP_ADD;
+                    ops[3] = GGML_OP_ADD;
+                    ops[4] = GGML_OP_ADD;
+                    ops[5] = GGML_OP_ADD;
+                    ops[6] = GGML_OP_ADD;
+                    ops[7] = GGML_OP_ADD;
+
+                    size_t offs_fuse;
+                    id<MTLBuffer> id_fuse;
+
+                    for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        // b[0] === b[1] === ...
+                        if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        // only fuse nodes if src1 is in the same Metal buffer
+                        id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
+                        if (id_fuse != id_src1) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        args.o1[n_fuse + 1] = offs_fuse;
+                    }
+
+                    ++n_fuse;
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
+                    }
+                }
+
+                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    // src1 is a row
+                    GGML_ASSERT(ne11 == 1);
+
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+
+                    bcast_row = true;
+                } else {
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src1 offset:0 atIndex:2];
                 [encoder setBuffer:id_dst  offset:offs_dst atIndex:3];
 
                 if (bcast_row) {
@@ -2132,7 +2279,11 @@ static bool ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } else {
-
+                    int nth = 32;
+
+                    while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                        nth *= 2;
+                    }
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 }
@@ -2257,12 +2408,13 @@ static bool ggml_metal_encode_node(
                     /*.nb2  =*/ pnb2,
                    /*.nb3  =*/ pnb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1},
                 };
 
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src1 offset:0 atIndex:2];
                 [encoder setBuffer:id_dst  offset:offs_dst atIndex:3];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
@@ -2764,7 +2916,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
                 if (!h_src0) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
-                    return false;
+                    return 0;
                 }
 
                 offs_src0 = 0;
@@ -3640,7 +3792,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
                 if (!h_src1) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
-                    return false;
+                    return 0;
                 }
 
                 const int64_t neh0 = ne0;
@@ -3656,7 +3808,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
                 if (!h_dst) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
-                    return false;
+                    return 0;
                 }
 
                 // tokens per expert
@@ -3664,7 +3816,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
                 if (!h_tpe) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
-                    return false;
+                    return 0;
                 }
 
                 // id map
@@ -3673,7 +3825,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
                 if (!h_ids) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
-                    return false;
+                    return 0;
                 }
 
                 {
@@ -4105,12 +4257,95 @@ static bool ggml_metal_encode_node(
         case GGML_OP_RMS_NORM:
            {
                 GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
 
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
 
-
+                ggml_metal_kargs_rms_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb1    =*/ nb1,
+                    /*.nb2    =*/ nb2,
+                    /*.nb3    =*/ nb3,
+                    /*.eps    =*/ eps,
+                    /*.nef1   =*/ { ne01 },
+                    /*.nef2   =*/ { ne02 },
+                    /*.nef3   =*/ { ne03 },
+                    /*.nbf1   =*/ { nb01 },
+                    /*.nbf2   =*/ { nb02 },
+                    /*.nbf3   =*/ { nb03 },
+                };
+
+                size_t offs_fuse[2] = { 0, 0 };
+                id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
+
+                // d[0] = rms_norm(a)
+                // d[1] = mul(d[0], b)
+                // d[2] = add(d[1], c)
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_RMS_NORM;
+                    ops[1] = GGML_OP_MUL;
+                    ops[2] = GGML_OP_ADD;
+
+                    for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
+                            break;
+                        }
+
+                        if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
+
+                        args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
+                        args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
+                        args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
+
+                        args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
+                        args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
+                        args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
+                    }
+
+                    ++n_fuse;
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        if (n_fuse == 2) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
+                        }
+                        if (n_fuse == 3) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
+                        }
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
+                id<MTLComputePipelineState> pipeline;
+
+                switch (n_fuse) {
+                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM        ].pipeline; break;
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL    ].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
+                    default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
+                }
 
                 int nth = 32; // SIMD width
 
@@ -4121,23 +4356,16 @@ static bool ggml_metal_encode_node(
                 nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00/4);
 
-                ggml_metal_kargs_rms_norm args = {
-                    /*.ne00   =*/ ne00,
-                    /*.ne00_4 =*/ ne00/4,
-                    /*.nb01   =*/ nb01,
-                    /*.eps    =*/ eps,
-                };
-
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBytes:&args length:sizeof(args) atIndex:0];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0    offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
+                [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
+                [encoder setBuffer:id_dst     offset:offs_dst     atIndex:4];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-
-
-                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_L2_NORM:
             {
@@ -5532,7 +5760,7 @@ static bool ggml_metal_encode_node(
         }
     }
 
-    return true;
+    return n_fuse;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
@@ -6038,20 +6266,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
                 struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
                 ggml_metal_mem_pool_reset(mem_pool);
 
-                for (int idx = node_start; idx < node_end; ++idx) {
+                for (int idx = node_start; idx < node_end;) {
                     if (should_capture) {
                         [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
                     }
 
-                    const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+                    const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
                     if (should_capture) {
                         [encoder popDebugGroup];
                     }
 
-                    if (!res) {
+                    if (res == 0) {
                         break;
                     }
+
+                    idx += res;
                 }
 
                 [encoder endEncoding];
ggml/src/ggml-metal/ggml-metal.metal
CHANGED

@@ -832,7 +832,8 @@ enum ggml_sort_order {
 // general-purpose kernel for addition, subtraction, multiplication and division of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
-kernel void kernel_add(
+template <int F>
+kernel void kernel_add_fuse_impl(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
@@ -848,16 +849,39 @@ kernel void kernel_add(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+    device       float * dst_ptr  = (device       float *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
+
+    device const float * src1_ptr[F];
+    for (short j = 0; j < F; ++j) {
+        src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+    }
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
+
+        float res = src0_ptr[i0];
+
+#pragma unroll
+        for (short j = 0; j < F; ++j) {
+            res += src1_ptr[j][i10];
+        }
+
+        dst_ptr[i0] = res;
     }
 }
 
+typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
+
+template [[host_name("kernel_add")]]        kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
+template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
+template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
+template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
+template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
+template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
+template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
+template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
+
 kernel void kernel_sub(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
@@ -875,7 +899,7 @@ kernel void kernel_sub(
     const int i11 = i01%args.ne11;
 
     device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
     device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
@@ -900,9 +924,9 @@ kernel void kernel_mul(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
@@ -926,9 +950,9 @@ kernel void kernel_div(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
@@ -970,46 +994,145 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat
 
 // assumption: src1 is a row
 // broadcast src1 into src0
-kernel void kernel_add_row(
+template <short F>
+kernel void kernel_add_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] + src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res += src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_sub_row(
+typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
+
+template [[host_name("kernel_add_row_c4")]]        kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
+template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
+template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
+template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
+template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
+template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
+template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
+template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
+
+template <short F>
+kernel void kernel_sub_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] - src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res -= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_mul_row(
+typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
+
+template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_mul_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] * src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res *= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_div_row(
+typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
+
+template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_div_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] / src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 * dst_row  = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res /= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
+typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
+
+template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
+
 kernel void kernel_scale(
         device const float * src0,
         device       float * dst,
@@ -2116,26 +2239,39 @@ kernel void kernel_norm(
     }
 }
 
-kernel void kernel_rms_norm(
+// F == 1 : rms_norm (no fuse)
+// F == 2 : rms_norm + mul
+// F == 3 : rms_norm + mul + add
+template <short F>
+kernel void kernel_rms_norm_fuse_impl(
         constant ggml_metal_kargs_rms_norm & args,
         device const char * src0,
+        device const char * src1_0,
+        device const char * src1_1,
         device       char * dst,
         threadgroup float * shmem_f32 [[threadgroup(0)]],
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
     if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }
 
-    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+    const int i01 = tgpig.x;
+    const int i02 = tgpig.y;
+    const int i03 = tgpig.z;
+
+    device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+
+    device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+    device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
 
     float sumf = 0.0f;
 
     // parallel sum
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
         sumf += dot(x[i00], x[i00]);
     }
     sumf = simd_sum(sumf);
@@ -2154,12 +2290,26 @@ kernel void kernel_rms_norm(
     const float mean = sumf/args.ne00;
     const float scale = 1.0f/sqrt(mean + args.eps);
 
-    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
-        y[i00] = x[i00] * scale;
+    device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+        if (F == 1) {
+            y[i00] = (x[i00]*scale);
+        }
+        if (F == 2) {
+            y[i00] = (x[i00]*scale)*f0[i00];
+        }
+        if (F == 3) {
+            y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
+        }
    }
 }
 
+typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
+
+template [[host_name("kernel_rms_norm")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
+template [[host_name("kernel_rms_norm_mul")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
+template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
+
 kernel void kernel_l2_norm(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,