ggerganov committed on
Commit 3c3094f · 1 Parent(s): e0c6dff

metal : template-ify some of the kernels (llama/8447)

Files changed (2)
  1. ggml/src/ggml-metal.m +14 -14
  2. ggml/src/ggml-metal.metal +179 -550
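
The change collapses several families of hand-written, per-type Metal kernels (mul_mv, cpy, get_rows, mul_mm) into C++ templates, re-exported for each type combination through [[host_name(...)]] instantiations, so the Objective-C host keeps looking up the same function names. A minimal sketch of the pattern, using a hypothetical kernel_cpy_sketch (illustrative only; the real templates in the diff below carry the full set of shape/stride arguments):

    #include <metal_stdlib>
    using namespace metal;

    // one templated body replaces several near-identical per-type kernels
    template<typename T0, typename T1>
    kernel void kernel_cpy_sketch(
            device const T0  * src,
            device       T1  * dst,
            constant int64_t & ne00,
            uint tpig [[thread_position_in_grid]]) {
        if ((int64_t) tpig < ne00) {
            dst[tpig] = (T1) src[tpig];   // convert-and-copy, e.g. float -> half
        }
    }

    typedef decltype(kernel_cpy_sketch<float, float>) kernel_cpy_sketch_t;

    // each instantiation exports its own entry-point name for the host to bind a pipeline to
    template [[host_name("kernel_cpy_sketch_f32_f16")]] kernel kernel_cpy_sketch_t kernel_cpy_sketch<float, half>;
    template [[host_name("kernel_cpy_sketch_f16_f32")]] kernel kernel_cpy_sketch_t kernel_cpy_sketch<half, float>;

Adding another source/destination type pair then costs one instantiation line instead of another full kernel.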
ggml/src/ggml-metal.m CHANGED
@@ -193,16 +193,16 @@ enum ggml_metal_kernel_type {
     //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
     //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
-    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
-    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
-    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -651,14 +651,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             switch (op->src[0]->type) {
                 case GGML_TYPE_F32:
                     switch (op->type) {
-                        case GGML_TYPE_F16:
                         case GGML_TYPE_F32:
+                        case GGML_TYPE_F16:
                         case GGML_TYPE_Q8_0:
                         case GGML_TYPE_Q4_0:
                         case GGML_TYPE_Q4_1:
@@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
                     }
                 case GGML_TYPE_F16:
                     switch (op->type) {
-                        case GGML_TYPE_F16:
                         case GGML_TYPE_F32:
+                        case GGML_TYPE_F16:
                             return true;
                         default:
                             return false;
@@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
-                return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
+                return op->ne[3] == 1;
             }
         default:
             return false;
@@ -1580,8 +1580,8 @@ static enum ggml_status ggml_metal_graph_compute(
                 // some Metal matrix data types require aligned pointers
                 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
                 switch (src0->type) {
-                    case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
-                    case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+                    case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+                    case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
                     default: break;
                 }

@@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute(
                 GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);

                 switch (dstt) {
-                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
-                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
                     case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
                     case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
                     case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_TYPE_F16:
                 {
                     switch (dstt) {
-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
                         default: GGML_ASSERT(false && "not implemented");
                     };
                 } break;
ggml/src/ggml-metal.metal CHANGED
@@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32(
1219
  kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1220
  }
1221
 
1222
- #define N_F32_F32 4
1223
 
1224
- void kernel_mul_mv_f32_f32_impl(
 
1225
  device const char * src0,
1226
  device const char * src1,
1227
  device float * dst,
@@ -1239,13 +1240,12 @@ void kernel_mul_mv_f32_f32_impl(
1239
  uint64_t nb12,
1240
  int64_t ne0,
1241
  int64_t ne1,
1242
- uint r2,
1243
- uint r3,
1244
- uint3 tgpig,
1245
- uint tiisg) {
1246
-
1247
  const int64_t r0 = tgpig.x;
1248
- const int64_t rb = tgpig.y*N_F32_F32;
1249
  const int64_t im = tgpig.z;
1250
 
1251
  const uint i12 = im%ne12;
@@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl(
1253
 
1254
  const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1255
 
1256
- device const float * x = (device const float *) (src0 + offset0);
1257
 
1258
  if (ne00 < 128) {
1259
- for (int row = 0; row < N_F32_F32; ++row) {
1260
  int r1 = rb + row;
1261
  if (r1 >= ne11) {
1262
  break;
1263
  }
1264
 
1265
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1266
 
1267
  float sumf = 0;
1268
  for (int i = tiisg; i < ne00; i += 32) {
1269
- sumf += (float) x[i] * (float) y[i];
1270
  }
1271
 
1272
  float all_sum = simd_sum(sumf);
@@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl(
1275
  }
1276
  }
1277
  } else {
1278
- device const float4 * x4 = (device const float4 *)x;
1279
- for (int row = 0; row < N_F32_F32; ++row) {
1280
  int r1 = rb + row;
1281
  if (r1 >= ne11) {
1282
  break;
1283
  }
1284
 
1285
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1286
- device const float4 * y4 = (device const float4 *) y;
1287
 
1288
  float sumf = 0;
1289
  for (int i = tiisg; i < ne00/4; i += 32) {
1290
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
1291
  }
1292
 
1293
  float all_sum = simd_sum(sumf);
1294
  if (tiisg == 0) {
1295
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
1296
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1297
  }
1298
  }
1299
  }
1300
  }
1301
 
1302
- [[host_name("kernel_mul_mv_f32_f32")]]
1303
- kernel void kernel_mul_mv_f32_f32(
1304
  device const char * src0,
1305
  device const char * src1,
1306
  device float * dst,
@@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32(
1322
  constant uint & r3,
1323
  uint3 tgpig[[threadgroup_position_in_grid]],
1324
  uint tiisg[[thread_index_in_simdgroup]]) {
1325
- kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1326
  }
1327
 
1328
- #define N_F16_F16 4
1329
 
1330
- kernel void kernel_mul_mv_f16_f16(
1331
- device const char * src0,
1332
- device const char * src1,
1333
- device float * dst,
1334
- constant int64_t & ne00,
1335
- constant int64_t & ne01,
1336
- constant int64_t & ne02,
1337
- constant uint64_t & nb00,
1338
- constant uint64_t & nb01,
1339
- constant uint64_t & nb02,
1340
- constant int64_t & ne10,
1341
- constant int64_t & ne11,
1342
- constant int64_t & ne12,
1343
- constant uint64_t & nb10,
1344
- constant uint64_t & nb11,
1345
- constant uint64_t & nb12,
1346
- constant int64_t & ne0,
1347
- constant int64_t & ne1,
1348
- constant uint & r2,
1349
- constant uint & r3,
1350
- uint3 tgpig[[threadgroup_position_in_grid]],
1351
- uint tiisg[[thread_index_in_simdgroup]]) {
1352
 
1353
- const int64_t r0 = tgpig.x;
1354
- const int64_t rb = tgpig.y*N_F16_F16;
1355
- const int64_t im = tgpig.z;
1356
-
1357
- const uint i12 = im%ne12;
1358
- const uint i13 = im/ne12;
1359
-
1360
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1361
-
1362
- device const half * x = (device const half *) (src0 + offset0);
1363
-
1364
- if (ne00 < 128) {
1365
- for (int row = 0; row < N_F16_F16; ++row) {
1366
- int r1 = rb + row;
1367
- if (r1 >= ne11) {
1368
- break;
1369
- }
1370
-
1371
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
1372
-
1373
- float sumf = 0;
1374
- for (int i = tiisg; i < ne00; i += 32) {
1375
- sumf += (half) x[i] * (half) y[i];
1376
- }
1377
-
1378
- float all_sum = simd_sum(sumf);
1379
- if (tiisg == 0) {
1380
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1381
- }
1382
- }
1383
- } else {
1384
- device const half4 * x4 = (device const half4 *)x;
1385
- for (int row = 0; row < N_F16_F16; ++row) {
1386
- int r1 = rb + row;
1387
- if (r1 >= ne11) {
1388
- break;
1389
- }
1390
-
1391
- device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
1392
- device const half4 * y4 = (device const half4 *) y;
1393
-
1394
- float sumf = 0;
1395
- for (int i = tiisg; i < ne00/4; i += 32) {
1396
- for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
1397
- }
1398
-
1399
- float all_sum = simd_sum(sumf);
1400
- if (tiisg == 0) {
1401
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
1402
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1403
- }
1404
- }
1405
- }
1406
- }
1407
-
1408
- void kernel_mul_mv_f16_f32_1row_impl(
1409
  device const char * src0,
1410
  device const char * src1,
1411
  device float * dst,
@@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl(
1437
 
1438
  const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1439
 
1440
- device const half * x = (device const half *) (src0 + offset0);
1441
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1442
 
1443
  float sumf = 0;
@@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl(
1450
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1451
  }
1452
  } else {
1453
- device const half4 * x4 = (device const half4 *) x;
1454
  device const float4 * y4 = (device const float4 *) y;
 
1455
  for (int i = tiisg; i < ne00/4; i += 32) {
1456
- for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
1457
  }
 
1458
  float all_sum = simd_sum(sumf);
 
1459
  if (tiisg == 0) {
1460
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
1461
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1462
  }
1463
  }
1464
  }
1465
 
1466
- [[host_name("kernel_mul_mv_f16_f32_1row")]]
1467
- kernel void kernel_mul_mv_f16_f32_1row(
1468
- device const char * src0,
1469
- device const char * src1,
1470
- device float * dst,
1471
- constant int64_t & ne00,
1472
- constant int64_t & ne01,
1473
- constant int64_t & ne02,
1474
- constant uint64_t & nb00,
1475
- constant uint64_t & nb01,
1476
- constant uint64_t & nb02,
1477
- constant int64_t & ne10,
1478
- constant int64_t & ne11,
1479
- constant int64_t & ne12,
1480
- constant uint64_t & nb10,
1481
- constant uint64_t & nb11,
1482
- constant uint64_t & nb12,
1483
- constant int64_t & ne0,
1484
- constant int64_t & ne1,
1485
- constant uint & r2,
1486
- constant uint & r3,
1487
- uint3 tgpig[[threadgroup_position_in_grid]],
1488
- uint tiisg[[thread_index_in_simdgroup]]) {
1489
- kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1490
- }
1491
-
1492
- #define N_F16_F32 4
1493
-
1494
- void kernel_mul_mv_f16_f32_impl(
1495
- device const char * src0,
1496
- device const char * src1,
1497
- device float * dst,
1498
- int64_t ne00,
1499
- int64_t ne01,
1500
- int64_t ne02,
1501
- uint64_t nb00,
1502
- uint64_t nb01,
1503
- uint64_t nb02,
1504
- int64_t ne10,
1505
- int64_t ne11,
1506
- int64_t ne12,
1507
- uint64_t nb10,
1508
- uint64_t nb11,
1509
- uint64_t nb12,
1510
- int64_t ne0,
1511
- int64_t ne1,
1512
- uint r2,
1513
- uint r3,
1514
- uint3 tgpig,
1515
- uint tiisg) {
1516
-
1517
- const int64_t r0 = tgpig.x;
1518
- const int64_t rb = tgpig.y*N_F16_F32;
1519
- const int64_t im = tgpig.z;
1520
-
1521
- const uint i12 = im%ne12;
1522
- const uint i13 = im/ne12;
1523
-
1524
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1525
-
1526
- device const half * x = (device const half *) (src0 + offset0);
1527
-
1528
- if (ne00 < 128) {
1529
- for (int row = 0; row < N_F16_F32; ++row) {
1530
- int r1 = rb + row;
1531
- if (r1 >= ne11) {
1532
- break;
1533
- }
1534
-
1535
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1536
-
1537
- float sumf = 0;
1538
- for (int i = tiisg; i < ne00; i += 32) {
1539
- sumf += (float) x[i] * (float) y[i];
1540
- }
1541
-
1542
- float all_sum = simd_sum(sumf);
1543
- if (tiisg == 0) {
1544
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1545
- }
1546
- }
1547
- } else {
1548
- device const half4 * x4 = (device const half4 *)x;
1549
- for (int row = 0; row < N_F16_F32; ++row) {
1550
- int r1 = rb + row;
1551
- if (r1 >= ne11) {
1552
- break;
1553
- }
1554
-
1555
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1556
- device const float4 * y4 = (device const float4 *) y;
1557
-
1558
- float sumf = 0;
1559
- for (int i = tiisg; i < ne00/4; i += 32) {
1560
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
1561
- }
1562
 
1563
- float all_sum = simd_sum(sumf);
1564
- if (tiisg == 0) {
1565
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
1566
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1567
- }
1568
- }
1569
- }
1570
- }
1571
-
1572
- [[host_name("kernel_mul_mv_f16_f32")]]
1573
- kernel void kernel_mul_mv_f16_f32(
1574
- device const char * src0,
1575
- device const char * src1,
1576
- device float * dst,
1577
- constant int64_t & ne00,
1578
- constant int64_t & ne01,
1579
- constant int64_t & ne02,
1580
- constant uint64_t & nb00,
1581
- constant uint64_t & nb01,
1582
- constant uint64_t & nb02,
1583
- constant int64_t & ne10,
1584
- constant int64_t & ne11,
1585
- constant int64_t & ne12,
1586
- constant uint64_t & nb10,
1587
- constant uint64_t & nb11,
1588
- constant uint64_t & nb12,
1589
- constant int64_t & ne0,
1590
- constant int64_t & ne1,
1591
- constant uint & r2,
1592
- constant uint & r3,
1593
- uint3 tgpig[[threadgroup_position_in_grid]],
1594
- uint tiisg[[thread_index_in_simdgroup]]) {
1595
- kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
1596
- }
1597
 
1598
  // Assumes row size (ne00) is a multiple of 4
1599
- kernel void kernel_mul_mv_f16_f32_l4(
 
1600
  device const char * src0,
1601
  device const char * src1,
1602
  device float * dst,
@@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4(
1628
 
1629
  const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1630
 
1631
- device const half4 * x4 = (device const half4 *) (src0 + offset0);
1632
 
1633
  for (int r1 = 0; r1 < nrows; ++r1) {
1634
  device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
1635
 
1636
  float sumf = 0;
1637
  for (int i = tiisg; i < ne00/4; i += 32) {
1638
- for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
1639
  }
1640
 
1641
  float all_sum = simd_sum(sumf);
@@ -1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4(
1645
  }
1646
  }
1647
 
1648
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
1649
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
1650
  return 1.0f - min(1.0f, max(0.0f, y));
@@ -2765,91 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
2765
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2766
  //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2767
 
2768
- kernel void kernel_cpy_f16_f16(
2769
- device const half * src0,
2770
- device half * dst,
2771
- constant int64_t & ne00,
2772
- constant int64_t & ne01,
2773
- constant int64_t & ne02,
2774
- constant int64_t & ne03,
2775
- constant uint64_t & nb00,
2776
- constant uint64_t & nb01,
2777
- constant uint64_t & nb02,
2778
- constant uint64_t & nb03,
2779
- constant int64_t & ne0,
2780
- constant int64_t & ne1,
2781
- constant int64_t & ne2,
2782
- constant int64_t & ne3,
2783
- constant uint64_t & nb0,
2784
- constant uint64_t & nb1,
2785
- constant uint64_t & nb2,
2786
- constant uint64_t & nb3,
2787
- uint3 tgpig[[threadgroup_position_in_grid]],
2788
- uint3 tpitg[[thread_position_in_threadgroup]],
2789
- uint3 ntg[[threads_per_threadgroup]]) {
2790
- const int64_t i03 = tgpig[2];
2791
- const int64_t i02 = tgpig[1];
2792
- const int64_t i01 = tgpig[0];
2793
-
2794
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2795
-
2796
- const int64_t i3 = n / (ne2*ne1*ne0);
2797
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2798
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2799
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2800
-
2801
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2802
-
2803
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2804
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2805
- dst_data[i00] = src[0];
2806
- }
2807
- }
2808
-
2809
- kernel void kernel_cpy_f16_f32(
2810
- device const half * src0,
2811
- device float * dst,
2812
- constant int64_t & ne00,
2813
- constant int64_t & ne01,
2814
- constant int64_t & ne02,
2815
- constant int64_t & ne03,
2816
- constant uint64_t & nb00,
2817
- constant uint64_t & nb01,
2818
- constant uint64_t & nb02,
2819
- constant uint64_t & nb03,
2820
- constant int64_t & ne0,
2821
- constant int64_t & ne1,
2822
- constant int64_t & ne2,
2823
- constant int64_t & ne3,
2824
- constant uint64_t & nb0,
2825
- constant uint64_t & nb1,
2826
- constant uint64_t & nb2,
2827
- constant uint64_t & nb3,
2828
- uint3 tgpig[[threadgroup_position_in_grid]],
2829
- uint3 tpitg[[thread_position_in_threadgroup]],
2830
- uint3 ntg[[threads_per_threadgroup]]) {
2831
- const int64_t i03 = tgpig[2];
2832
- const int64_t i02 = tgpig[1];
2833
- const int64_t i01 = tgpig[0];
2834
-
2835
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2836
-
2837
- const int64_t i3 = n / (ne2*ne1*ne0);
2838
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2839
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2840
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2841
-
2842
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2843
-
2844
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2845
- device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2846
- dst_data[i00] = src[0];
2847
- }
2848
- }
2849
-
2850
- kernel void kernel_cpy_f32_f16(
2851
- device const float * src0,
2852
- device half * dst,
2853
  constant int64_t & ne00,
2854
  constant int64_t & ne01,
2855
  constant int64_t & ne02,
@@ -2880,56 +2627,20 @@ kernel void kernel_cpy_f32_f16(
2880
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2881
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2882
 
2883
- device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2884
 
2885
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2886
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2887
-
2888
- dst_data[i00] = src[0];
2889
  }
2890
  }
2891
 
2892
- kernel void kernel_cpy_f32_f32(
2893
- device const float * src0,
2894
- device float * dst,
2895
- constant int64_t & ne00,
2896
- constant int64_t & ne01,
2897
- constant int64_t & ne02,
2898
- constant int64_t & ne03,
2899
- constant uint64_t & nb00,
2900
- constant uint64_t & nb01,
2901
- constant uint64_t & nb02,
2902
- constant uint64_t & nb03,
2903
- constant int64_t & ne0,
2904
- constant int64_t & ne1,
2905
- constant int64_t & ne2,
2906
- constant int64_t & ne3,
2907
- constant uint64_t & nb0,
2908
- constant uint64_t & nb1,
2909
- constant uint64_t & nb2,
2910
- constant uint64_t & nb3,
2911
- uint3 tgpig[[threadgroup_position_in_grid]],
2912
- uint3 tpitg[[thread_position_in_threadgroup]],
2913
- uint3 ntg[[threads_per_threadgroup]]) {
2914
- const int64_t i03 = tgpig[2];
2915
- const int64_t i02 = tgpig[1];
2916
- const int64_t i01 = tgpig[0];
2917
-
2918
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2919
 
2920
- const int64_t i3 = n / (ne2*ne1*ne0);
2921
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2922
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2923
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2924
-
2925
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2926
-
2927
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2928
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2929
-
2930
- dst_data[i00] = src[0];
2931
- }
2932
- }
2933
 
2934
  kernel void kernel_cpy_f32_q8_0(
2935
  device const float * src0,
@@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
5730
  }
5731
 
5732
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
5733
- kernel void kernel_get_rows(
5734
  device const void * src0,
5735
- device const char * src1,
5736
  device float * dst,
5737
  constant int64_t & ne00,
5738
  constant uint64_t & nb01,
@@ -5745,55 +5456,24 @@ kernel void kernel_get_rows(
5745
  uint3 tgpig[[threadgroup_position_in_grid]],
5746
  uint tiitg[[thread_index_in_threadgroup]],
5747
  uint3 tptg [[threads_per_threadgroup]]) {
5748
- //const int64_t i = tgpig;
5749
- //const int64_t r = ((device int32_t *) src1)[i];
5750
-
5751
  const int64_t i10 = tgpig.x;
5752
  const int64_t i11 = tgpig.y;
5753
 
5754
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5755
 
5756
  const int64_t i02 = i11;
5757
 
5758
  for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
5759
  float4x4 temp;
5760
- dequantize_func(
5761
- ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
5762
  *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
5763
  }
5764
  }
5765
 
5766
- kernel void kernel_get_rows_f32(
5767
- device const void * src0,
5768
- device const char * src1,
5769
- device float * dst,
5770
- constant int64_t & ne00,
5771
- constant uint64_t & nb01,
5772
- constant uint64_t & nb02,
5773
- constant int64_t & ne10,
5774
- constant uint64_t & nb10,
5775
- constant uint64_t & nb11,
5776
- constant uint64_t & nb1,
5777
- constant uint64_t & nb2,
5778
- uint3 tgpig[[threadgroup_position_in_grid]],
5779
- uint tiitg[[thread_index_in_threadgroup]],
5780
- uint3 tptg [[threads_per_threadgroup]]) {
5781
- const int64_t i10 = tgpig.x;
5782
- const int64_t i11 = tgpig.y;
5783
-
5784
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5785
-
5786
- const int64_t i02 = i11;
5787
-
5788
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5789
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5790
- ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
5791
- }
5792
- }
5793
-
5794
- kernel void kernel_get_rows_f16(
5795
  device const void * src0,
5796
- device const char * src1,
5797
  device float * dst,
5798
  constant int64_t & ne00,
5799
  constant uint64_t & nb01,
@@ -5809,19 +5489,19 @@ kernel void kernel_get_rows_f16(
5809
  const int64_t i10 = tgpig.x;
5810
  const int64_t i11 = tgpig.y;
5811
 
5812
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5813
 
5814
  const int64_t i02 = i11;
5815
 
5816
  for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5817
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5818
- ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
5819
  }
5820
  }
5821
 
5822
  kernel void kernel_get_rows_i32(
5823
  device const void * src0,
5824
- device const char * src1,
5825
  device int32_t * dst,
5826
  constant int64_t & ne00,
5827
  constant uint64_t & nb01,
@@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32(
5837
  const int64_t i10 = tgpig.x;
5838
  const int64_t i11 = tgpig.y;
5839
 
5840
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5841
 
5842
  const int64_t i02 = i11;
5843
 
5844
  for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5845
- ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5846
- ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
5847
  }
5848
  }
5849
 
@@ -5860,28 +5540,28 @@ kernel void kernel_get_rows_i32(
5860
  #define SG_MAT_ROW 8
5861
 
5862
  // each block_q contains 16*nl weights
5863
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
5864
- void kernel_mul_mm_impl(device const uchar * src0,
5865
- device const uchar * src1,
5866
- device float * dst,
5867
- constant int64_t & ne00,
5868
- constant int64_t & ne02,
5869
- constant uint64_t & nb01,
5870
- constant uint64_t & nb02,
5871
- constant int64_t & ne12,
5872
- constant uint64_t & nb10,
5873
- constant uint64_t & nb11,
5874
- constant uint64_t & nb12,
5875
- constant int64_t & ne0,
5876
- constant int64_t & ne1,
5877
- constant uint & r2,
5878
- constant uint & r3,
5879
- threadgroup uchar * shared_memory [[threadgroup(0)]],
5880
- uint3 tgpig[[threadgroup_position_in_grid]],
5881
- uint tiitg[[thread_index_in_threadgroup]],
5882
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5883
 
5884
- threadgroup half * sa = (threadgroup half *)(shared_memory);
5885
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
5886
 
5887
  const uint r0 = tgpig.y;
@@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
5896
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
5897
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
5898
 
5899
- simdgroup_half8x8 ma[4];
5900
  simdgroup_float8x8 mb[2];
5901
  simdgroup_float8x8 c_res[8];
5902
  for (int i = 0; i < 8; i++){
@@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
5919
 
5920
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
5921
  // load data and store to threadgroup memory
5922
- half4x4 temp_a;
5923
  dequantize_func(x, il, temp_a);
5924
  threadgroup_barrier(mem_flags::mem_threadgroup);
5925
 
@@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
5939
  threadgroup_barrier(mem_flags::mem_threadgroup);
5940
 
5941
  // load matrices from threadgroup memory and conduct outer products
5942
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
5943
  threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
5944
 
5945
  #pragma unroll(4)
@@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl(
6115
  }
6116
  }
6117
 
6118
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
6119
- kernel void kernel_mul_mm(device const uchar * src0,
6120
- device const uchar * src1,
6121
- device float * dst,
6122
- constant int64_t & ne00,
6123
- constant int64_t & ne02,
6124
- constant uint64_t & nb01,
6125
- constant uint64_t & nb02,
6126
- constant int64_t & ne12,
6127
- constant uint64_t & nb10,
6128
- constant uint64_t & nb11,
6129
- constant uint64_t & nb12,
6130
- constant int64_t & ne0,
6131
- constant int64_t & ne1,
6132
- constant uint & r2,
6133
- constant uint & r3,
6134
- threadgroup uchar * shared_memory [[threadgroup(0)]],
6135
- uint3 tgpig[[threadgroup_position_in_grid]],
6136
- uint tiitg[[thread_index_in_threadgroup]],
6137
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6138
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
6139
- src0,
6140
- src1,
6141
- dst,
6142
- ne00,
6143
- ne02,
6144
- nb01,
6145
- nb02,
6146
- ne12,
6147
- nb10,
6148
- nb11,
6149
- nb12,
6150
- ne0,
6151
- ne1,
6152
- r2,
6153
- r3,
6154
- shared_memory,
6155
- tgpig,
6156
- tiitg,
6157
- sgitg);
6158
- }
6159
-
6160
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
6161
  kernel void kernel_mul_mm_id(
6162
  device const uchar * src0s,
@@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id(
6237
  // get rows
6238
  //
6239
 
6240
- typedef void (get_rows_t)(
6241
- device const void * src0,
6242
- device const char * src1,
6243
- device float * dst,
6244
- constant int64_t & ne00,
6245
- constant uint64_t & nb01,
6246
- constant uint64_t & nb02,
6247
- constant int64_t & ne10,
6248
- constant uint64_t & nb10,
6249
- constant uint64_t & nb11,
6250
- constant uint64_t & nb1,
6251
- constant uint64_t & nb2,
6252
- uint3, uint, uint3);
6253
-
6254
- //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
6255
- //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
6256
- template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
6257
- template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
6258
- template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
6259
- template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
6260
- template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
6261
- template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
6262
- template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
6263
- template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
6264
- template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
6265
- template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
6266
- template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6267
- template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6268
- template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6269
- template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
6270
- template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
6271
- template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
6272
- template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
6273
- template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6274
- template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6275
 
6276
  //
6277
  // matrix-matrix multiplication
6278
  //
6279
 
6280
- typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
6281
-
6282
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
6283
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
6284
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
6285
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
6286
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
6287
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
6288
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
6289
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
6290
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
6291
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
6292
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
6293
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
6294
- template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6295
- template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6296
- template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6297
- template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
6298
- template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
6299
- template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
6300
- template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
6301
- template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6302
- template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6303
 
6304
  //
6305
  // indirect matrix-matrix multiplication
@@ -6436,7 +6065,7 @@ void mmv_fn(
6436
  impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
6437
  }
6438
 
6439
- typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
6440
 
6441
  template<mul_mv_impl_fn_t impl_fn>
6442
  kernel void kernel_mul_mv_id(
@@ -6514,20 +6143,20 @@ kernel void kernel_mul_mv_id(
6514
  sgitg);
6515
  }
6516
 
6517
- typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
6518
-
6519
- template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
6520
- template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
6521
- template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
6522
- template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6523
- template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6524
- template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6525
- template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6526
- template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
6527
- template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
6528
- template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
6529
- template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
6530
- template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
6531
  template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
6532
  template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
6533
  template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
 
1219
  kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1220
  }
1221
 
1222
+ #define N_MV_T_T 4
1223
 
1224
+ template<typename T0, typename T04, typename T1, typename T14>
1225
+ void kernel_mul_mv_impl(
1226
  device const char * src0,
1227
  device const char * src1,
1228
  device float * dst,
 
1240
  uint64_t nb12,
1241
  int64_t ne0,
1242
  int64_t ne1,
1243
+ uint r2,
1244
+ uint r3,
1245
+ uint3 tgpig,
1246
+ uint tiisg) {
 
1247
  const int64_t r0 = tgpig.x;
1248
+ const int64_t rb = tgpig.y*N_MV_T_T;
1249
  const int64_t im = tgpig.z;
1250
 
1251
  const uint i12 = im%ne12;
 
1253
 
1254
  const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1255
 
1256
+ device const T0 * x = (device const T0 *) (src0 + offset0);
1257
 
1258
  if (ne00 < 128) {
1259
+ for (int row = 0; row < N_MV_T_T; ++row) {
1260
  int r1 = rb + row;
1261
  if (r1 >= ne11) {
1262
  break;
1263
  }
1264
 
1265
+ device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
1266
 
1267
  float sumf = 0;
1268
  for (int i = tiisg; i < ne00; i += 32) {
1269
+ sumf += (T0) x[i] * (T1) y[i];
1270
  }
1271
 
1272
  float all_sum = simd_sum(sumf);
 
1275
  }
1276
  }
1277
  } else {
1278
+ device const T04 * x4 = (device const T04 *) x;
1279
+ for (int row = 0; row < N_MV_T_T; ++row) {
1280
  int r1 = rb + row;
1281
  if (r1 >= ne11) {
1282
  break;
1283
  }
1284
 
1285
+ device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
1286
+ device const T14 * y4 = (device const T14 *) y;
1287
 
1288
  float sumf = 0;
1289
  for (int i = tiisg; i < ne00/4; i += 32) {
1290
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
1291
  }
1292
 
1293
  float all_sum = simd_sum(sumf);
1294
  if (tiisg == 0) {
1295
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
1296
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1297
  }
1298
  }
1299
  }
1300
  }
1301
 
1302
+ template<typename T0, typename T04, typename T1, typename T14>
1303
+ kernel void kernel_mul_mv(
1304
  device const char * src0,
1305
  device const char * src1,
1306
  device float * dst,
 
1322
  constant uint & r3,
1323
  uint3 tgpig[[threadgroup_position_in_grid]],
1324
  uint tiisg[[thread_index_in_simdgroup]]) {
1325
+ kernel_mul_mv_impl<T0, T04, T1, T14>(
1326
+ src0,
1327
+ src1,
1328
+ dst,
1329
+ ne00,
1330
+ ne01,
1331
+ ne02,
1332
+ nb00,
1333
+ nb01,
1334
+ nb02,
1335
+ ne10,
1336
+ ne11,
1337
+ ne12,
1338
+ nb10,
1339
+ nb11,
1340
+ nb12,
1341
+ ne0,
1342
+ ne1,
1343
+ r2,
1344
+ r3,
1345
+ tgpig,
1346
+ tiisg);
1347
  }
1348
 
1349
+ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
1350
 
1351
+ template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
1352
+ template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
1353
+ template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
1354
 
1355
+ template<typename T, typename T4>
1356
+ kernel void kernel_mul_mv_1row(
1357
  device const char * src0,
1358
  device const char * src1,
1359
  device float * dst,
 
1385
 
1386
  const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1387
 
1388
+ device const T * x = (device const T *) (src0 + offset0);
1389
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1390
 
1391
  float sumf = 0;
 
1398
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1399
  }
1400
  } else {
1401
+ device const T4 * x4 = (device const T4 *) x;
1402
  device const float4 * y4 = (device const float4 *) y;
1403
+
1404
  for (int i = tiisg; i < ne00/4; i += 32) {
1405
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
1406
  }
1407
+
1408
  float all_sum = simd_sum(sumf);
1409
+
1410
  if (tiisg == 0) {
1411
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
1412
  dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
1413
  }
1414
  }
1415
  }
1416
 
1417
+ typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
1418
 
1419
+ template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
1420
 
1421
  // Assumes row size (ne00) is a multiple of 4
1422
+ template<typename T, typename T4>
1423
+ kernel void kernel_mul_mv_l4(
1424
  device const char * src0,
1425
  device const char * src1,
1426
  device float * dst,
 
1452
 
1453
  const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1454
 
1455
+ device const T4 * x4 = (device const T4 *) (src0 + offset0);
1456
 
1457
  for (int r1 = 0; r1 < nrows; ++r1) {
1458
  device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
1459
 
1460
  float sumf = 0;
1461
  for (int i = tiisg; i < ne00/4; i += 32) {
1462
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
1463
  }
1464
 
1465
  float all_sum = simd_sum(sumf);
 
1469
  }
1470
  }
1471
 
1472
+ typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
1473
+
1474
+ template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
1475
+
1476
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
1477
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
1478
  return 1.0f - min(1.0f, max(0.0f, y));
 
2593
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2594
  //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2595
 
2596
+ template<typename T0, typename T1>
2597
+ kernel void kernel_cpy(
2598
+ device const void * src0,
2599
+ device void * dst,
2600
  constant int64_t & ne00,
2601
  constant int64_t & ne01,
2602
  constant int64_t & ne02,
 
2627
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2628
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2629
 
2630
+ device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2631
 
2632
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2633
+ device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2634
+ dst_data[i00] = (T1) src[0];
 
2635
  }
2636
  }
2637
 
2638
+ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
2639
 
2640
+ template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
2641
+ template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
2642
+ template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
2643
+ template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
2644
 
2645
  kernel void kernel_cpy_f32_q8_0(
2646
  device const float * src0,
 
5441
  }
5442
 
5443
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
5444
+ kernel void kernel_get_rows_q(
5445
  device const void * src0,
5446
+ device const void * src1,
5447
  device float * dst,
5448
  constant int64_t & ne00,
5449
  constant uint64_t & nb01,
 
5456
  uint3 tgpig[[threadgroup_position_in_grid]],
5457
  uint tiitg[[thread_index_in_threadgroup]],
5458
  uint3 tptg [[threads_per_threadgroup]]) {
 
 
 
5459
  const int64_t i10 = tgpig.x;
5460
  const int64_t i11 = tgpig.y;
5461
 
5462
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
5463
 
5464
  const int64_t i02 = i11;
5465
 
5466
  for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
5467
  float4x4 temp;
5468
+ dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
 
5469
  *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
5470
  }
5471
  }
5472
 
5473
+ template<typename T>
5474
+ kernel void kernel_get_rows_f(
5475
  device const void * src0,
5476
+ device const void * src1,
5477
  device float * dst,
5478
  constant int64_t & ne00,
5479
  constant uint64_t & nb01,
 
5489
  const int64_t i10 = tgpig.x;
5490
  const int64_t i11 = tgpig.y;
5491
 
5492
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
5493
 
5494
  const int64_t i02 = i11;
5495
 
5496
  for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5497
+ (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5498
+ ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
5499
  }
5500
  }
5501
 
5502
  kernel void kernel_get_rows_i32(
5503
  device const void * src0,
5504
+ device const void * src1,
5505
  device int32_t * dst,
5506
  constant int64_t & ne00,
5507
  constant uint64_t & nb01,
 
5517
  const int64_t i10 = tgpig.x;
5518
  const int64_t i11 = tgpig.y;
5519
 
5520
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
5521
 
5522
  const int64_t i02 = i11;
5523
 
5524
  for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5525
+ (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5526
+ ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
5527
  }
5528
  }
5529
 
 
5540
  #define SG_MAT_ROW 8
5541
 
5542
  // each block_q contains 16*nl weights
5543
+ template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
5544
+ kernel void kernel_mul_mm(device const uchar * src0,
5545
+ device const uchar * src1,
5546
+ device float * dst,
5547
+ constant int64_t & ne00,
5548
+ constant int64_t & ne02,
5549
+ constant uint64_t & nb01,
5550
+ constant uint64_t & nb02,
5551
+ constant int64_t & ne12,
5552
+ constant uint64_t & nb10,
5553
+ constant uint64_t & nb11,
5554
+ constant uint64_t & nb12,
5555
+ constant int64_t & ne0,
5556
+ constant int64_t & ne1,
5557
+ constant uint & r2,
5558
+ constant uint & r3,
5559
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
5560
+ uint3 tgpig[[threadgroup_position_in_grid]],
5561
+ uint tiitg[[thread_index_in_threadgroup]],
5562
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5563
 
5564
+ threadgroup T * sa = (threadgroup T *)(shared_memory);
5565
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
5566
 
5567
  const uint r0 = tgpig.y;
 
5576
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
5577
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
5578
 
5579
+ simdgroup_T8x8 ma[4];
5580
  simdgroup_float8x8 mb[2];
5581
  simdgroup_float8x8 c_res[8];
5582
  for (int i = 0; i < 8; i++){
 
5599
 
5600
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
5601
  // load data and store to threadgroup memory
5602
+ T4x4 temp_a;
5603
  dequantize_func(x, il, temp_a);
5604
  threadgroup_barrier(mem_flags::mem_threadgroup);
5605
 
 
5619
  threadgroup_barrier(mem_flags::mem_threadgroup);
5620
 
5621
  // load matrices from threadgroup memory and conduct outer products
5622
+ threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
5623
  threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
5624
 
5625
  #pragma unroll(4)
 
5795
  }
5796
  }
5797
5798
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
5799
  kernel void kernel_mul_mm_id(
5800
  device const uchar * src0s,
 
5875
  // get rows
5876
  //
5877
 
5878
+ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
5879
+
5880
+ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
5881
+ template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
5882
+
5883
+ typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
5884
+
5885
+ template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
5886
+ template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
5887
+ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
5888
+ template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
5889
+ template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
5890
+ template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
5891
+ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
5892
+ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
5893
+ template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
5894
+ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
5895
+ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5896
+ template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5897
+ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5898
+ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
5899
+ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
5900
+ template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
5901
+ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
5902
+ template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
5903
+ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
5904
 
5905
  //
5906
  // matrix-matrix multiplication
5907
  //
5908
 
5909
+ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
5910
+
5911
+ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
5912
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
5913
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
5914
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
5915
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
5916
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
5917
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
5918
+ template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
5919
+ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
5920
+ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
5921
+ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
5922
+ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
5923
+ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5924
+ template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5925
+ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5926
+ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
5927
+ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
5928
+ template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
5929
+ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
5930
+ template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
5931
+ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
5932
 
5933
  //
5934
  // indirect matrix-matrix multiplication
 
6065
  impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
6066
  }
6067
 
6068
+ typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
6069
 
6070
  template<mul_mv_impl_fn_t impl_fn>
6071
  kernel void kernel_mul_mv_id(
 
6143
  sgitg);
6144
  }
6145
 
6146
+ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
6147
+
6148
+ template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
6149
+ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
6150
+ template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
6151
+ template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6152
+ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6153
+ template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6154
+ template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6155
+ template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
6156
+ template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
6157
+ template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
6158
+ template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
6159
+ template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
6160
  template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
6161
  template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
6162
  template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;