Spaces:
Running
Running
metal : template-ify some of the kernels (llama/8447)
Browse files- ggml/src/ggml-metal.m +14 -14
- ggml/src/ggml-metal.metal +179 -550
ggml/src/ggml-metal.m
CHANGED
|
@@ -193,16 +193,16 @@ enum ggml_metal_kernel_type {
|
|
| 193 |
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
| 194 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
| 195 |
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
| 196 |
-
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
| 197 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
|
|
|
|
|
|
|
|
|
| 198 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
| 199 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
| 200 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
| 201 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
| 202 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
| 203 |
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
| 204 |
-
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
| 205 |
-
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
| 206 |
GGML_METAL_KERNEL_TYPE_CONCAT,
|
| 207 |
GGML_METAL_KERNEL_TYPE_SQR,
|
| 208 |
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
@@ -651,14 +651,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 651 |
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
| 652 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 653 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
|
|
|
|
|
|
| 654 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
| 655 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
| 656 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
| 657 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
| 658 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
| 659 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
| 660 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
| 661 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
| 662 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
| 663 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
| 664 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
@@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
| 810 |
switch (op->src[0]->type) {
|
| 811 |
case GGML_TYPE_F32:
|
| 812 |
switch (op->type) {
|
| 813 |
-
case GGML_TYPE_F16:
|
| 814 |
case GGML_TYPE_F32:
|
|
|
|
| 815 |
case GGML_TYPE_Q8_0:
|
| 816 |
case GGML_TYPE_Q4_0:
|
| 817 |
case GGML_TYPE_Q4_1:
|
|
@@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
| 824 |
}
|
| 825 |
case GGML_TYPE_F16:
|
| 826 |
switch (op->type) {
|
| 827 |
-
case GGML_TYPE_F16:
|
| 828 |
case GGML_TYPE_F32:
|
|
|
|
| 829 |
return true;
|
| 830 |
default:
|
| 831 |
return false;
|
|
@@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
| 837 |
case GGML_OP_DIAG_MASK_INF:
|
| 838 |
case GGML_OP_GET_ROWS:
|
| 839 |
{
|
| 840 |
-
return op->
|
| 841 |
}
|
| 842 |
default:
|
| 843 |
return false;
|
|
@@ -1580,8 +1580,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 1580 |
// some Metal matrix data types require aligned pointers
|
| 1581 |
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
| 1582 |
switch (src0->type) {
|
| 1583 |
-
case GGML_TYPE_F32:
|
| 1584 |
-
case GGML_TYPE_F16:
|
| 1585 |
default: break;
|
| 1586 |
}
|
| 1587 |
|
|
@@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 2775 |
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
| 2776 |
|
| 2777 |
switch (dstt) {
|
| 2778 |
-
case
|
| 2779 |
-
case
|
| 2780 |
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
| 2781 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
| 2782 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
|
@@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 2789 |
case GGML_TYPE_F16:
|
| 2790 |
{
|
| 2791 |
switch (dstt) {
|
| 2792 |
-
case
|
| 2793 |
-
case
|
| 2794 |
default: GGML_ASSERT(false && "not implemented");
|
| 2795 |
};
|
| 2796 |
} break;
|
|
|
|
| 193 |
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
| 194 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
| 195 |
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
|
|
|
| 196 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
| 197 |
+
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
| 198 |
+
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
| 199 |
+
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
| 200 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
| 201 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
| 202 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
| 203 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
| 204 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
| 205 |
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
|
|
|
|
|
|
| 206 |
GGML_METAL_KERNEL_TYPE_CONCAT,
|
| 207 |
GGML_METAL_KERNEL_TYPE_SQR,
|
| 208 |
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
|
|
| 651 |
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
| 652 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 653 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
| 654 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
| 655 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
| 656 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
| 657 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
| 658 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
| 659 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
| 660 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
| 661 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
|
|
|
|
|
|
| 662 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
| 663 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
| 664 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
|
|
| 810 |
switch (op->src[0]->type) {
|
| 811 |
case GGML_TYPE_F32:
|
| 812 |
switch (op->type) {
|
|
|
|
| 813 |
case GGML_TYPE_F32:
|
| 814 |
+
case GGML_TYPE_F16:
|
| 815 |
case GGML_TYPE_Q8_0:
|
| 816 |
case GGML_TYPE_Q4_0:
|
| 817 |
case GGML_TYPE_Q4_1:
|
|
|
|
| 824 |
}
|
| 825 |
case GGML_TYPE_F16:
|
| 826 |
switch (op->type) {
|
|
|
|
| 827 |
case GGML_TYPE_F32:
|
| 828 |
+
case GGML_TYPE_F16:
|
| 829 |
return true;
|
| 830 |
default:
|
| 831 |
return false;
|
|
|
|
| 837 |
case GGML_OP_DIAG_MASK_INF:
|
| 838 |
case GGML_OP_GET_ROWS:
|
| 839 |
{
|
| 840 |
+
return op->ne[3] == 1;
|
| 841 |
}
|
| 842 |
default:
|
| 843 |
return false;
|
|
|
|
| 1580 |
// some Metal matrix data types require aligned pointers
|
| 1581 |
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
| 1582 |
switch (src0->type) {
|
| 1583 |
+
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
| 1584 |
+
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
| 1585 |
default: break;
|
| 1586 |
}
|
| 1587 |
|
|
|
|
| 2775 |
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
| 2776 |
|
| 2777 |
switch (dstt) {
|
| 2778 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
| 2779 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
| 2780 |
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
| 2781 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
| 2782 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
|
|
|
| 2789 |
case GGML_TYPE_F16:
|
| 2790 |
{
|
| 2791 |
switch (dstt) {
|
| 2792 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
|
| 2793 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
|
| 2794 |
default: GGML_ASSERT(false && "not implemented");
|
| 2795 |
};
|
| 2796 |
} break;
|
ggml/src/ggml-metal.metal
CHANGED
|
@@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
| 1219 |
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1220 |
}
|
| 1221 |
|
| 1222 |
-
#define
|
| 1223 |
|
| 1224 |
-
|
|
|
|
| 1225 |
device const char * src0,
|
| 1226 |
device const char * src1,
|
| 1227 |
device float * dst,
|
|
@@ -1239,13 +1240,12 @@ void kernel_mul_mv_f32_f32_impl(
|
|
| 1239 |
uint64_t nb12,
|
| 1240 |
int64_t ne0,
|
| 1241 |
int64_t ne1,
|
| 1242 |
-
|
| 1243 |
-
|
| 1244 |
-
|
| 1245 |
-
|
| 1246 |
-
|
| 1247 |
const int64_t r0 = tgpig.x;
|
| 1248 |
-
const int64_t rb = tgpig.y*
|
| 1249 |
const int64_t im = tgpig.z;
|
| 1250 |
|
| 1251 |
const uint i12 = im%ne12;
|
|
@@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl(
|
|
| 1253 |
|
| 1254 |
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
| 1255 |
|
| 1256 |
-
device const
|
| 1257 |
|
| 1258 |
if (ne00 < 128) {
|
| 1259 |
-
for (int row = 0; row <
|
| 1260 |
int r1 = rb + row;
|
| 1261 |
if (r1 >= ne11) {
|
| 1262 |
break;
|
| 1263 |
}
|
| 1264 |
|
| 1265 |
-
device const
|
| 1266 |
|
| 1267 |
float sumf = 0;
|
| 1268 |
for (int i = tiisg; i < ne00; i += 32) {
|
| 1269 |
-
sumf += (
|
| 1270 |
}
|
| 1271 |
|
| 1272 |
float all_sum = simd_sum(sumf);
|
|
@@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl(
|
|
| 1275 |
}
|
| 1276 |
}
|
| 1277 |
} else {
|
| 1278 |
-
device const
|
| 1279 |
-
for (int row = 0; row <
|
| 1280 |
int r1 = rb + row;
|
| 1281 |
if (r1 >= ne11) {
|
| 1282 |
break;
|
| 1283 |
}
|
| 1284 |
|
| 1285 |
-
device const
|
| 1286 |
-
device const
|
| 1287 |
|
| 1288 |
float sumf = 0;
|
| 1289 |
for (int i = tiisg; i < ne00/4; i += 32) {
|
| 1290 |
-
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
| 1291 |
}
|
| 1292 |
|
| 1293 |
float all_sum = simd_sum(sumf);
|
| 1294 |
if (tiisg == 0) {
|
| 1295 |
-
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
| 1296 |
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
| 1297 |
}
|
| 1298 |
}
|
| 1299 |
}
|
| 1300 |
}
|
| 1301 |
|
| 1302 |
-
|
| 1303 |
-
kernel void
|
| 1304 |
device const char * src0,
|
| 1305 |
device const char * src1,
|
| 1306 |
device float * dst,
|
|
@@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32(
|
|
| 1322 |
constant uint & r3,
|
| 1323 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1324 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1326 |
}
|
| 1327 |
|
| 1328 |
-
|
| 1329 |
|
| 1330 |
-
kernel
|
| 1331 |
-
|
| 1332 |
-
|
| 1333 |
-
device float * dst,
|
| 1334 |
-
constant int64_t & ne00,
|
| 1335 |
-
constant int64_t & ne01,
|
| 1336 |
-
constant int64_t & ne02,
|
| 1337 |
-
constant uint64_t & nb00,
|
| 1338 |
-
constant uint64_t & nb01,
|
| 1339 |
-
constant uint64_t & nb02,
|
| 1340 |
-
constant int64_t & ne10,
|
| 1341 |
-
constant int64_t & ne11,
|
| 1342 |
-
constant int64_t & ne12,
|
| 1343 |
-
constant uint64_t & nb10,
|
| 1344 |
-
constant uint64_t & nb11,
|
| 1345 |
-
constant uint64_t & nb12,
|
| 1346 |
-
constant int64_t & ne0,
|
| 1347 |
-
constant int64_t & ne1,
|
| 1348 |
-
constant uint & r2,
|
| 1349 |
-
constant uint & r3,
|
| 1350 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1351 |
-
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1352 |
|
| 1353 |
-
|
| 1354 |
-
|
| 1355 |
-
const int64_t im = tgpig.z;
|
| 1356 |
-
|
| 1357 |
-
const uint i12 = im%ne12;
|
| 1358 |
-
const uint i13 = im/ne12;
|
| 1359 |
-
|
| 1360 |
-
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
| 1361 |
-
|
| 1362 |
-
device const half * x = (device const half *) (src0 + offset0);
|
| 1363 |
-
|
| 1364 |
-
if (ne00 < 128) {
|
| 1365 |
-
for (int row = 0; row < N_F16_F16; ++row) {
|
| 1366 |
-
int r1 = rb + row;
|
| 1367 |
-
if (r1 >= ne11) {
|
| 1368 |
-
break;
|
| 1369 |
-
}
|
| 1370 |
-
|
| 1371 |
-
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
| 1372 |
-
|
| 1373 |
-
float sumf = 0;
|
| 1374 |
-
for (int i = tiisg; i < ne00; i += 32) {
|
| 1375 |
-
sumf += (half) x[i] * (half) y[i];
|
| 1376 |
-
}
|
| 1377 |
-
|
| 1378 |
-
float all_sum = simd_sum(sumf);
|
| 1379 |
-
if (tiisg == 0) {
|
| 1380 |
-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
| 1381 |
-
}
|
| 1382 |
-
}
|
| 1383 |
-
} else {
|
| 1384 |
-
device const half4 * x4 = (device const half4 *)x;
|
| 1385 |
-
for (int row = 0; row < N_F16_F16; ++row) {
|
| 1386 |
-
int r1 = rb + row;
|
| 1387 |
-
if (r1 >= ne11) {
|
| 1388 |
-
break;
|
| 1389 |
-
}
|
| 1390 |
-
|
| 1391 |
-
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
| 1392 |
-
device const half4 * y4 = (device const half4 *) y;
|
| 1393 |
-
|
| 1394 |
-
float sumf = 0;
|
| 1395 |
-
for (int i = tiisg; i < ne00/4; i += 32) {
|
| 1396 |
-
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
|
| 1397 |
-
}
|
| 1398 |
-
|
| 1399 |
-
float all_sum = simd_sum(sumf);
|
| 1400 |
-
if (tiisg == 0) {
|
| 1401 |
-
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
|
| 1402 |
-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
| 1403 |
-
}
|
| 1404 |
-
}
|
| 1405 |
-
}
|
| 1406 |
-
}
|
| 1407 |
-
|
| 1408 |
-
void kernel_mul_mv_f16_f32_1row_impl(
|
| 1409 |
device const char * src0,
|
| 1410 |
device const char * src1,
|
| 1411 |
device float * dst,
|
|
@@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl(
|
|
| 1437 |
|
| 1438 |
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
| 1439 |
|
| 1440 |
-
device const
|
| 1441 |
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
| 1442 |
|
| 1443 |
float sumf = 0;
|
|
@@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl(
|
|
| 1450 |
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
| 1451 |
}
|
| 1452 |
} else {
|
| 1453 |
-
device const
|
| 1454 |
device const float4 * y4 = (device const float4 *) y;
|
|
|
|
| 1455 |
for (int i = tiisg; i < ne00/4; i += 32) {
|
| 1456 |
-
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
|
| 1457 |
}
|
|
|
|
| 1458 |
float all_sum = simd_sum(sumf);
|
|
|
|
| 1459 |
if (tiisg == 0) {
|
| 1460 |
-
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
| 1461 |
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
| 1462 |
}
|
| 1463 |
}
|
| 1464 |
}
|
| 1465 |
|
| 1466 |
-
|
| 1467 |
-
kernel void kernel_mul_mv_f16_f32_1row(
|
| 1468 |
-
device const char * src0,
|
| 1469 |
-
device const char * src1,
|
| 1470 |
-
device float * dst,
|
| 1471 |
-
constant int64_t & ne00,
|
| 1472 |
-
constant int64_t & ne01,
|
| 1473 |
-
constant int64_t & ne02,
|
| 1474 |
-
constant uint64_t & nb00,
|
| 1475 |
-
constant uint64_t & nb01,
|
| 1476 |
-
constant uint64_t & nb02,
|
| 1477 |
-
constant int64_t & ne10,
|
| 1478 |
-
constant int64_t & ne11,
|
| 1479 |
-
constant int64_t & ne12,
|
| 1480 |
-
constant uint64_t & nb10,
|
| 1481 |
-
constant uint64_t & nb11,
|
| 1482 |
-
constant uint64_t & nb12,
|
| 1483 |
-
constant int64_t & ne0,
|
| 1484 |
-
constant int64_t & ne1,
|
| 1485 |
-
constant uint & r2,
|
| 1486 |
-
constant uint & r3,
|
| 1487 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1488 |
-
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1489 |
-
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
| 1490 |
-
}
|
| 1491 |
-
|
| 1492 |
-
#define N_F16_F32 4
|
| 1493 |
-
|
| 1494 |
-
void kernel_mul_mv_f16_f32_impl(
|
| 1495 |
-
device const char * src0,
|
| 1496 |
-
device const char * src1,
|
| 1497 |
-
device float * dst,
|
| 1498 |
-
int64_t ne00,
|
| 1499 |
-
int64_t ne01,
|
| 1500 |
-
int64_t ne02,
|
| 1501 |
-
uint64_t nb00,
|
| 1502 |
-
uint64_t nb01,
|
| 1503 |
-
uint64_t nb02,
|
| 1504 |
-
int64_t ne10,
|
| 1505 |
-
int64_t ne11,
|
| 1506 |
-
int64_t ne12,
|
| 1507 |
-
uint64_t nb10,
|
| 1508 |
-
uint64_t nb11,
|
| 1509 |
-
uint64_t nb12,
|
| 1510 |
-
int64_t ne0,
|
| 1511 |
-
int64_t ne1,
|
| 1512 |
-
uint r2,
|
| 1513 |
-
uint r3,
|
| 1514 |
-
uint3 tgpig,
|
| 1515 |
-
uint tiisg) {
|
| 1516 |
-
|
| 1517 |
-
const int64_t r0 = tgpig.x;
|
| 1518 |
-
const int64_t rb = tgpig.y*N_F16_F32;
|
| 1519 |
-
const int64_t im = tgpig.z;
|
| 1520 |
-
|
| 1521 |
-
const uint i12 = im%ne12;
|
| 1522 |
-
const uint i13 = im/ne12;
|
| 1523 |
-
|
| 1524 |
-
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
| 1525 |
-
|
| 1526 |
-
device const half * x = (device const half *) (src0 + offset0);
|
| 1527 |
-
|
| 1528 |
-
if (ne00 < 128) {
|
| 1529 |
-
for (int row = 0; row < N_F16_F32; ++row) {
|
| 1530 |
-
int r1 = rb + row;
|
| 1531 |
-
if (r1 >= ne11) {
|
| 1532 |
-
break;
|
| 1533 |
-
}
|
| 1534 |
-
|
| 1535 |
-
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
| 1536 |
-
|
| 1537 |
-
float sumf = 0;
|
| 1538 |
-
for (int i = tiisg; i < ne00; i += 32) {
|
| 1539 |
-
sumf += (float) x[i] * (float) y[i];
|
| 1540 |
-
}
|
| 1541 |
-
|
| 1542 |
-
float all_sum = simd_sum(sumf);
|
| 1543 |
-
if (tiisg == 0) {
|
| 1544 |
-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
| 1545 |
-
}
|
| 1546 |
-
}
|
| 1547 |
-
} else {
|
| 1548 |
-
device const half4 * x4 = (device const half4 *)x;
|
| 1549 |
-
for (int row = 0; row < N_F16_F32; ++row) {
|
| 1550 |
-
int r1 = rb + row;
|
| 1551 |
-
if (r1 >= ne11) {
|
| 1552 |
-
break;
|
| 1553 |
-
}
|
| 1554 |
-
|
| 1555 |
-
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
| 1556 |
-
device const float4 * y4 = (device const float4 *) y;
|
| 1557 |
-
|
| 1558 |
-
float sumf = 0;
|
| 1559 |
-
for (int i = tiisg; i < ne00/4; i += 32) {
|
| 1560 |
-
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
| 1561 |
-
}
|
| 1562 |
|
| 1563 |
-
|
| 1564 |
-
if (tiisg == 0) {
|
| 1565 |
-
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
| 1566 |
-
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
| 1567 |
-
}
|
| 1568 |
-
}
|
| 1569 |
-
}
|
| 1570 |
-
}
|
| 1571 |
-
|
| 1572 |
-
[[host_name("kernel_mul_mv_f16_f32")]]
|
| 1573 |
-
kernel void kernel_mul_mv_f16_f32(
|
| 1574 |
-
device const char * src0,
|
| 1575 |
-
device const char * src1,
|
| 1576 |
-
device float * dst,
|
| 1577 |
-
constant int64_t & ne00,
|
| 1578 |
-
constant int64_t & ne01,
|
| 1579 |
-
constant int64_t & ne02,
|
| 1580 |
-
constant uint64_t & nb00,
|
| 1581 |
-
constant uint64_t & nb01,
|
| 1582 |
-
constant uint64_t & nb02,
|
| 1583 |
-
constant int64_t & ne10,
|
| 1584 |
-
constant int64_t & ne11,
|
| 1585 |
-
constant int64_t & ne12,
|
| 1586 |
-
constant uint64_t & nb10,
|
| 1587 |
-
constant uint64_t & nb11,
|
| 1588 |
-
constant uint64_t & nb12,
|
| 1589 |
-
constant int64_t & ne0,
|
| 1590 |
-
constant int64_t & ne1,
|
| 1591 |
-
constant uint & r2,
|
| 1592 |
-
constant uint & r3,
|
| 1593 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1594 |
-
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1595 |
-
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
| 1596 |
-
}
|
| 1597 |
|
| 1598 |
// Assumes row size (ne00) is a multiple of 4
|
| 1599 |
-
|
|
|
|
| 1600 |
device const char * src0,
|
| 1601 |
device const char * src1,
|
| 1602 |
device float * dst,
|
|
@@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
| 1628 |
|
| 1629 |
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
| 1630 |
|
| 1631 |
-
device const
|
| 1632 |
|
| 1633 |
for (int r1 = 0; r1 < nrows; ++r1) {
|
| 1634 |
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
| 1635 |
|
| 1636 |
float sumf = 0;
|
| 1637 |
for (int i = tiisg; i < ne00/4; i += 32) {
|
| 1638 |
-
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
| 1639 |
}
|
| 1640 |
|
| 1641 |
float all_sum = simd_sum(sumf);
|
|
@@ -1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
| 1645 |
}
|
| 1646 |
}
|
| 1647 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1648 |
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
| 1649 |
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
| 1650 |
return 1.0f - min(1.0f, max(0.0f, y));
|
|
@@ -2765,91 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 2765 |
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
| 2766 |
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
| 2767 |
|
| 2768 |
-
|
| 2769 |
-
|
| 2770 |
-
device
|
| 2771 |
-
|
| 2772 |
-
constant int64_t & ne01,
|
| 2773 |
-
constant int64_t & ne02,
|
| 2774 |
-
constant int64_t & ne03,
|
| 2775 |
-
constant uint64_t & nb00,
|
| 2776 |
-
constant uint64_t & nb01,
|
| 2777 |
-
constant uint64_t & nb02,
|
| 2778 |
-
constant uint64_t & nb03,
|
| 2779 |
-
constant int64_t & ne0,
|
| 2780 |
-
constant int64_t & ne1,
|
| 2781 |
-
constant int64_t & ne2,
|
| 2782 |
-
constant int64_t & ne3,
|
| 2783 |
-
constant uint64_t & nb0,
|
| 2784 |
-
constant uint64_t & nb1,
|
| 2785 |
-
constant uint64_t & nb2,
|
| 2786 |
-
constant uint64_t & nb3,
|
| 2787 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2788 |
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2789 |
-
uint3 ntg[[threads_per_threadgroup]]) {
|
| 2790 |
-
const int64_t i03 = tgpig[2];
|
| 2791 |
-
const int64_t i02 = tgpig[1];
|
| 2792 |
-
const int64_t i01 = tgpig[0];
|
| 2793 |
-
|
| 2794 |
-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
| 2795 |
-
|
| 2796 |
-
const int64_t i3 = n / (ne2*ne1*ne0);
|
| 2797 |
-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
| 2798 |
-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
| 2799 |
-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
| 2800 |
-
|
| 2801 |
-
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 2802 |
-
|
| 2803 |
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
| 2804 |
-
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
| 2805 |
-
dst_data[i00] = src[0];
|
| 2806 |
-
}
|
| 2807 |
-
}
|
| 2808 |
-
|
| 2809 |
-
kernel void kernel_cpy_f16_f32(
|
| 2810 |
-
device const half * src0,
|
| 2811 |
-
device float * dst,
|
| 2812 |
-
constant int64_t & ne00,
|
| 2813 |
-
constant int64_t & ne01,
|
| 2814 |
-
constant int64_t & ne02,
|
| 2815 |
-
constant int64_t & ne03,
|
| 2816 |
-
constant uint64_t & nb00,
|
| 2817 |
-
constant uint64_t & nb01,
|
| 2818 |
-
constant uint64_t & nb02,
|
| 2819 |
-
constant uint64_t & nb03,
|
| 2820 |
-
constant int64_t & ne0,
|
| 2821 |
-
constant int64_t & ne1,
|
| 2822 |
-
constant int64_t & ne2,
|
| 2823 |
-
constant int64_t & ne3,
|
| 2824 |
-
constant uint64_t & nb0,
|
| 2825 |
-
constant uint64_t & nb1,
|
| 2826 |
-
constant uint64_t & nb2,
|
| 2827 |
-
constant uint64_t & nb3,
|
| 2828 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2829 |
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2830 |
-
uint3 ntg[[threads_per_threadgroup]]) {
|
| 2831 |
-
const int64_t i03 = tgpig[2];
|
| 2832 |
-
const int64_t i02 = tgpig[1];
|
| 2833 |
-
const int64_t i01 = tgpig[0];
|
| 2834 |
-
|
| 2835 |
-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
| 2836 |
-
|
| 2837 |
-
const int64_t i3 = n / (ne2*ne1*ne0);
|
| 2838 |
-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
| 2839 |
-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
| 2840 |
-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
| 2841 |
-
|
| 2842 |
-
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 2843 |
-
|
| 2844 |
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
| 2845 |
-
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
| 2846 |
-
dst_data[i00] = src[0];
|
| 2847 |
-
}
|
| 2848 |
-
}
|
| 2849 |
-
|
| 2850 |
-
kernel void kernel_cpy_f32_f16(
|
| 2851 |
-
device const float * src0,
|
| 2852 |
-
device half * dst,
|
| 2853 |
constant int64_t & ne00,
|
| 2854 |
constant int64_t & ne01,
|
| 2855 |
constant int64_t & ne02,
|
|
@@ -2880,56 +2627,20 @@ kernel void kernel_cpy_f32_f16(
|
|
| 2880 |
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
| 2881 |
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
| 2882 |
|
| 2883 |
-
device
|
| 2884 |
|
| 2885 |
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
| 2886 |
-
device const
|
| 2887 |
-
|
| 2888 |
-
dst_data[i00] = src[0];
|
| 2889 |
}
|
| 2890 |
}
|
| 2891 |
|
| 2892 |
-
|
| 2893 |
-
device const float * src0,
|
| 2894 |
-
device float * dst,
|
| 2895 |
-
constant int64_t & ne00,
|
| 2896 |
-
constant int64_t & ne01,
|
| 2897 |
-
constant int64_t & ne02,
|
| 2898 |
-
constant int64_t & ne03,
|
| 2899 |
-
constant uint64_t & nb00,
|
| 2900 |
-
constant uint64_t & nb01,
|
| 2901 |
-
constant uint64_t & nb02,
|
| 2902 |
-
constant uint64_t & nb03,
|
| 2903 |
-
constant int64_t & ne0,
|
| 2904 |
-
constant int64_t & ne1,
|
| 2905 |
-
constant int64_t & ne2,
|
| 2906 |
-
constant int64_t & ne3,
|
| 2907 |
-
constant uint64_t & nb0,
|
| 2908 |
-
constant uint64_t & nb1,
|
| 2909 |
-
constant uint64_t & nb2,
|
| 2910 |
-
constant uint64_t & nb3,
|
| 2911 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2912 |
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2913 |
-
uint3 ntg[[threads_per_threadgroup]]) {
|
| 2914 |
-
const int64_t i03 = tgpig[2];
|
| 2915 |
-
const int64_t i02 = tgpig[1];
|
| 2916 |
-
const int64_t i01 = tgpig[0];
|
| 2917 |
-
|
| 2918 |
-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
| 2919 |
|
| 2920 |
-
|
| 2921 |
-
|
| 2922 |
-
|
| 2923 |
-
|
| 2924 |
-
|
| 2925 |
-
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 2926 |
-
|
| 2927 |
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
| 2928 |
-
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
| 2929 |
-
|
| 2930 |
-
dst_data[i00] = src[0];
|
| 2931 |
-
}
|
| 2932 |
-
}
|
| 2933 |
|
| 2934 |
kernel void kernel_cpy_f32_q8_0(
|
| 2935 |
device const float * src0,
|
|
@@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
|
| 5730 |
}
|
| 5731 |
|
| 5732 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
| 5733 |
-
kernel void
|
| 5734 |
device const void * src0,
|
| 5735 |
-
device const
|
| 5736 |
device float * dst,
|
| 5737 |
constant int64_t & ne00,
|
| 5738 |
constant uint64_t & nb01,
|
|
@@ -5745,55 +5456,24 @@ kernel void kernel_get_rows(
|
|
| 5745 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5746 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 5747 |
uint3 tptg [[threads_per_threadgroup]]) {
|
| 5748 |
-
//const int64_t i = tgpig;
|
| 5749 |
-
//const int64_t r = ((device int32_t *) src1)[i];
|
| 5750 |
-
|
| 5751 |
const int64_t i10 = tgpig.x;
|
| 5752 |
const int64_t i11 = tgpig.y;
|
| 5753 |
|
| 5754 |
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
| 5755 |
|
| 5756 |
const int64_t i02 = i11;
|
| 5757 |
|
| 5758 |
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
| 5759 |
float4x4 temp;
|
| 5760 |
-
dequantize_func(
|
| 5761 |
-
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
| 5762 |
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
| 5763 |
}
|
| 5764 |
}
|
| 5765 |
|
| 5766 |
-
|
| 5767 |
-
|
| 5768 |
-
device const char * src1,
|
| 5769 |
-
device float * dst,
|
| 5770 |
-
constant int64_t & ne00,
|
| 5771 |
-
constant uint64_t & nb01,
|
| 5772 |
-
constant uint64_t & nb02,
|
| 5773 |
-
constant int64_t & ne10,
|
| 5774 |
-
constant uint64_t & nb10,
|
| 5775 |
-
constant uint64_t & nb11,
|
| 5776 |
-
constant uint64_t & nb1,
|
| 5777 |
-
constant uint64_t & nb2,
|
| 5778 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5779 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 5780 |
-
uint3 tptg [[threads_per_threadgroup]]) {
|
| 5781 |
-
const int64_t i10 = tgpig.x;
|
| 5782 |
-
const int64_t i11 = tgpig.y;
|
| 5783 |
-
|
| 5784 |
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
| 5785 |
-
|
| 5786 |
-
const int64_t i02 = i11;
|
| 5787 |
-
|
| 5788 |
-
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
| 5789 |
-
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
| 5790 |
-
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
| 5791 |
-
}
|
| 5792 |
-
}
|
| 5793 |
-
|
| 5794 |
-
kernel void kernel_get_rows_f16(
|
| 5795 |
device const void * src0,
|
| 5796 |
-
device const
|
| 5797 |
device float * dst,
|
| 5798 |
constant int64_t & ne00,
|
| 5799 |
constant uint64_t & nb01,
|
|
@@ -5809,19 +5489,19 @@ kernel void kernel_get_rows_f16(
|
|
| 5809 |
const int64_t i10 = tgpig.x;
|
| 5810 |
const int64_t i11 = tgpig.y;
|
| 5811 |
|
| 5812 |
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
| 5813 |
|
| 5814 |
const int64_t i02 = i11;
|
| 5815 |
|
| 5816 |
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
| 5817 |
-
((device float *) ((device char *)
|
| 5818 |
-
|
| 5819 |
}
|
| 5820 |
}
|
| 5821 |
|
| 5822 |
kernel void kernel_get_rows_i32(
|
| 5823 |
device const void * src0,
|
| 5824 |
-
device const
|
| 5825 |
device int32_t * dst,
|
| 5826 |
constant int64_t & ne00,
|
| 5827 |
constant uint64_t & nb01,
|
|
@@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32(
|
|
| 5837 |
const int64_t i10 = tgpig.x;
|
| 5838 |
const int64_t i11 = tgpig.y;
|
| 5839 |
|
| 5840 |
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
| 5841 |
|
| 5842 |
const int64_t i02 = i11;
|
| 5843 |
|
| 5844 |
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
| 5845 |
-
((device int32_t *) ((device char *) dst
|
| 5846 |
-
|
| 5847 |
}
|
| 5848 |
}
|
| 5849 |
|
|
@@ -5860,28 +5540,28 @@ kernel void kernel_get_rows_i32(
|
|
| 5860 |
#define SG_MAT_ROW 8
|
| 5861 |
|
| 5862 |
// each block_q contains 16*nl weights
|
| 5863 |
-
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread
|
| 5864 |
-
void
|
| 5865 |
-
|
| 5866 |
-
|
| 5867 |
-
|
| 5868 |
-
|
| 5869 |
-
|
| 5870 |
-
|
| 5871 |
-
|
| 5872 |
-
|
| 5873 |
-
|
| 5874 |
-
|
| 5875 |
-
|
| 5876 |
-
|
| 5877 |
-
|
| 5878 |
-
|
| 5879 |
-
|
| 5880 |
-
|
| 5881 |
-
|
| 5882 |
-
|
| 5883 |
|
| 5884 |
-
threadgroup
|
| 5885 |
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
| 5886 |
|
| 5887 |
const uint r0 = tgpig.y;
|
|
@@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
| 5896 |
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
| 5897 |
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
| 5898 |
|
| 5899 |
-
|
| 5900 |
simdgroup_float8x8 mb[2];
|
| 5901 |
simdgroup_float8x8 c_res[8];
|
| 5902 |
for (int i = 0; i < 8; i++){
|
|
@@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
| 5919 |
|
| 5920 |
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
| 5921 |
// load data and store to threadgroup memory
|
| 5922 |
-
|
| 5923 |
dequantize_func(x, il, temp_a);
|
| 5924 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5925 |
|
|
@@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
| 5939 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5940 |
|
| 5941 |
// load matrices from threadgroup memory and conduct outer products
|
| 5942 |
-
threadgroup
|
| 5943 |
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
| 5944 |
|
| 5945 |
#pragma unroll(4)
|
|
@@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl(
|
|
| 6115 |
}
|
| 6116 |
}
|
| 6117 |
|
| 6118 |
-
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
| 6119 |
-
kernel void kernel_mul_mm(device const uchar * src0,
|
| 6120 |
-
device const uchar * src1,
|
| 6121 |
-
device float * dst,
|
| 6122 |
-
constant int64_t & ne00,
|
| 6123 |
-
constant int64_t & ne02,
|
| 6124 |
-
constant uint64_t & nb01,
|
| 6125 |
-
constant uint64_t & nb02,
|
| 6126 |
-
constant int64_t & ne12,
|
| 6127 |
-
constant uint64_t & nb10,
|
| 6128 |
-
constant uint64_t & nb11,
|
| 6129 |
-
constant uint64_t & nb12,
|
| 6130 |
-
constant int64_t & ne0,
|
| 6131 |
-
constant int64_t & ne1,
|
| 6132 |
-
constant uint & r2,
|
| 6133 |
-
constant uint & r3,
|
| 6134 |
-
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
| 6135 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6136 |
-
uint tiitg[[thread_index_in_threadgroup]],
|
| 6137 |
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6138 |
-
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
| 6139 |
-
src0,
|
| 6140 |
-
src1,
|
| 6141 |
-
dst,
|
| 6142 |
-
ne00,
|
| 6143 |
-
ne02,
|
| 6144 |
-
nb01,
|
| 6145 |
-
nb02,
|
| 6146 |
-
ne12,
|
| 6147 |
-
nb10,
|
| 6148 |
-
nb11,
|
| 6149 |
-
nb12,
|
| 6150 |
-
ne0,
|
| 6151 |
-
ne1,
|
| 6152 |
-
r2,
|
| 6153 |
-
r3,
|
| 6154 |
-
shared_memory,
|
| 6155 |
-
tgpig,
|
| 6156 |
-
tiitg,
|
| 6157 |
-
sgitg);
|
| 6158 |
-
}
|
| 6159 |
-
|
| 6160 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
| 6161 |
kernel void kernel_mul_mm_id(
|
| 6162 |
device const uchar * src0s,
|
|
@@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id(
|
|
| 6237 |
// get rows
|
| 6238 |
//
|
| 6239 |
|
| 6240 |
-
typedef
|
| 6241 |
-
|
| 6242 |
-
|
| 6243 |
-
|
| 6244 |
-
|
| 6245 |
-
|
| 6246 |
-
|
| 6247 |
-
|
| 6248 |
-
|
| 6249 |
-
|
| 6250 |
-
|
| 6251 |
-
|
| 6252 |
-
|
| 6253 |
-
|
| 6254 |
-
|
| 6255 |
-
|
| 6256 |
-
template [[host_name("
|
| 6257 |
-
template [[host_name("
|
| 6258 |
-
template [[host_name("
|
| 6259 |
-
template [[host_name("
|
| 6260 |
-
template [[host_name("
|
| 6261 |
-
template [[host_name("
|
| 6262 |
-
template [[host_name("
|
| 6263 |
-
template [[host_name("
|
| 6264 |
-
template [[host_name("
|
| 6265 |
-
template [[host_name("
|
| 6266 |
-
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 6267 |
-
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 6268 |
-
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 6269 |
-
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
| 6270 |
-
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
| 6271 |
-
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 6272 |
-
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
| 6273 |
-
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 6274 |
-
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
| 6275 |
|
| 6276 |
//
|
| 6277 |
// matrix-matrix multiplication
|
| 6278 |
//
|
| 6279 |
|
| 6280 |
-
typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
|
| 6281 |
-
|
| 6282 |
-
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
| 6283 |
-
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
| 6284 |
-
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
| 6285 |
-
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
| 6286 |
-
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
| 6287 |
-
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
| 6288 |
-
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
| 6289 |
-
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
| 6290 |
-
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
| 6291 |
-
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
| 6292 |
-
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 6293 |
-
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 6294 |
-
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 6295 |
-
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 6296 |
-
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 6297 |
-
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
| 6298 |
-
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
| 6299 |
-
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 6300 |
-
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
| 6301 |
-
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 6302 |
-
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
| 6303 |
|
| 6304 |
//
|
| 6305 |
// indirect matrix-matrix multiplication
|
|
@@ -6436,7 +6065,7 @@ void mmv_fn(
|
|
| 6436 |
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
|
| 6437 |
}
|
| 6438 |
|
| 6439 |
-
typedef decltype(mmv_fn<
|
| 6440 |
|
| 6441 |
template<mul_mv_impl_fn_t impl_fn>
|
| 6442 |
kernel void kernel_mul_mv_id(
|
|
@@ -6514,20 +6143,20 @@ kernel void kernel_mul_mv_id(
|
|
| 6514 |
sgitg);
|
| 6515 |
}
|
| 6516 |
|
| 6517 |
-
typedef decltype(kernel_mul_mv_id<mmv_fn<
|
| 6518 |
-
|
| 6519 |
-
template [[host_name("kernel_mul_mv_id_f32_f32")]]
|
| 6520 |
-
template [[host_name("kernel_mul_mv_id_f16_f32")]]
|
| 6521 |
-
template [[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
| 6522 |
-
template [[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
| 6523 |
-
template [[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
| 6524 |
-
template [[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
| 6525 |
-
template [[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
| 6526 |
-
template [[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
| 6527 |
-
template [[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
| 6528 |
-
template [[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
| 6529 |
-
template [[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
| 6530 |
-
template [[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
| 6531 |
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
|
| 6532 |
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
|
| 6533 |
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
|
|
|
|
| 1219 |
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1220 |
}
|
| 1221 |
|
| 1222 |
+
#define N_MV_T_T 4
|
| 1223 |
|
| 1224 |
+
template<typename T0, typename T04, typename T1, typename T14>
|
| 1225 |
+
void kernel_mul_mv_impl(
|
| 1226 |
device const char * src0,
|
| 1227 |
device const char * src1,
|
| 1228 |
device float * dst,
|
|
|
|
| 1240 |
uint64_t nb12,
|
| 1241 |
int64_t ne0,
|
| 1242 |
int64_t ne1,
|
| 1243 |
+
uint r2,
|
| 1244 |
+
uint r3,
|
| 1245 |
+
uint3 tgpig,
|
| 1246 |
+
uint tiisg) {
|
|
|
|
| 1247 |
const int64_t r0 = tgpig.x;
|
| 1248 |
+
const int64_t rb = tgpig.y*N_MV_T_T;
|
| 1249 |
const int64_t im = tgpig.z;
|
| 1250 |
|
| 1251 |
const uint i12 = im%ne12;
|
|
|
|
| 1253 |
|
| 1254 |
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
| 1255 |
|
| 1256 |
+
device const T0 * x = (device const T0 *) (src0 + offset0);
|
| 1257 |
|
| 1258 |
if (ne00 < 128) {
|
| 1259 |
+
for (int row = 0; row < N_MV_T_T; ++row) {
|
| 1260 |
int r1 = rb + row;
|
| 1261 |
if (r1 >= ne11) {
|
| 1262 |
break;
|
| 1263 |
}
|
| 1264 |
|
| 1265 |
+
device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
|
| 1266 |
|
| 1267 |
float sumf = 0;
|
| 1268 |
for (int i = tiisg; i < ne00; i += 32) {
|
| 1269 |
+
sumf += (T0) x[i] * (T1) y[i];
|
| 1270 |
}
|
| 1271 |
|
| 1272 |
float all_sum = simd_sum(sumf);
|
|
|
|
| 1275 |
}
|
| 1276 |
}
|
| 1277 |
} else {
|
| 1278 |
+
device const T04 * x4 = (device const T04 *) x;
|
| 1279 |
+
for (int row = 0; row < N_MV_T_T; ++row) {
|
| 1280 |
int r1 = rb + row;
|
| 1281 |
if (r1 >= ne11) {
|
| 1282 |
break;
|
| 1283 |
}
|
| 1284 |
|
| 1285 |
+
device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
|
| 1286 |
+
device const T14 * y4 = (device const T14 *) y;
|
| 1287 |
|
| 1288 |
float sumf = 0;
|
| 1289 |
for (int i = tiisg; i < ne00/4; i += 32) {
|
| 1290 |
+
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
| 1291 |
}
|
| 1292 |
|
| 1293 |
float all_sum = simd_sum(sumf);
|
| 1294 |
if (tiisg == 0) {
|
| 1295 |
+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
| 1296 |
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
| 1297 |
}
|
| 1298 |
}
|
| 1299 |
}
|
| 1300 |
}
|
| 1301 |
|
| 1302 |
+
template<typename T0, typename T04, typename T1, typename T14>
|
| 1303 |
+
kernel void kernel_mul_mv(
|
| 1304 |
device const char * src0,
|
| 1305 |
device const char * src1,
|
| 1306 |
device float * dst,
|
|
|
|
| 1322 |
constant uint & r3,
|
| 1323 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1324 |
uint tiisg[[thread_index_in_simdgroup]]) {
|
| 1325 |
+
kernel_mul_mv_impl<T0, T04, T1, T14>(
|
| 1326 |
+
src0,
|
| 1327 |
+
src1,
|
| 1328 |
+
dst,
|
| 1329 |
+
ne00,
|
| 1330 |
+
ne01,
|
| 1331 |
+
ne02,
|
| 1332 |
+
nb00,
|
| 1333 |
+
nb01,
|
| 1334 |
+
nb02,
|
| 1335 |
+
ne10,
|
| 1336 |
+
ne11,
|
| 1337 |
+
ne12,
|
| 1338 |
+
nb10,
|
| 1339 |
+
nb11,
|
| 1340 |
+
nb12,
|
| 1341 |
+
ne0,
|
| 1342 |
+
ne1,
|
| 1343 |
+
r2,
|
| 1344 |
+
r3,
|
| 1345 |
+
tgpig,
|
| 1346 |
+
tiisg);
|
| 1347 |
}
|
| 1348 |
|
| 1349 |
+
typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
|
| 1350 |
|
| 1351 |
+
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
|
| 1352 |
+
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
|
| 1353 |
+
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1354 |
|
| 1355 |
+
template<typename T, typename T4>
|
| 1356 |
+
kernel void kernel_mul_mv_1row(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1357 |
device const char * src0,
|
| 1358 |
device const char * src1,
|
| 1359 |
device float * dst,
|
|
|
|
| 1385 |
|
| 1386 |
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
| 1387 |
|
| 1388 |
+
device const T * x = (device const T *) (src0 + offset0);
|
| 1389 |
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
| 1390 |
|
| 1391 |
float sumf = 0;
|
|
|
|
| 1398 |
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
| 1399 |
}
|
| 1400 |
} else {
|
| 1401 |
+
device const T4 * x4 = (device const T4 *) x;
|
| 1402 |
device const float4 * y4 = (device const float4 *) y;
|
| 1403 |
+
|
| 1404 |
for (int i = tiisg; i < ne00/4; i += 32) {
|
| 1405 |
+
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
| 1406 |
}
|
| 1407 |
+
|
| 1408 |
float all_sum = simd_sum(sumf);
|
| 1409 |
+
|
| 1410 |
if (tiisg == 0) {
|
| 1411 |
+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
| 1412 |
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
| 1413 |
}
|
| 1414 |
}
|
| 1415 |
}
|
| 1416 |
|
| 1417 |
+
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1418 |
|
| 1419 |
+
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1420 |
|
| 1421 |
// Assumes row size (ne00) is a multiple of 4
|
| 1422 |
+
template<typename T, typename T4>
|
| 1423 |
+
kernel void kernel_mul_mv_l4(
|
| 1424 |
device const char * src0,
|
| 1425 |
device const char * src1,
|
| 1426 |
device float * dst,
|
|
|
|
| 1452 |
|
| 1453 |
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
| 1454 |
|
| 1455 |
+
device const T4 * x4 = (device const T4 *) (src0 + offset0);
|
| 1456 |
|
| 1457 |
for (int r1 = 0; r1 < nrows; ++r1) {
|
| 1458 |
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
| 1459 |
|
| 1460 |
float sumf = 0;
|
| 1461 |
for (int i = tiisg; i < ne00/4; i += 32) {
|
| 1462 |
+
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
| 1463 |
}
|
| 1464 |
|
| 1465 |
float all_sum = simd_sum(sumf);
|
|
|
|
| 1469 |
}
|
| 1470 |
}
|
| 1471 |
|
| 1472 |
+
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
|
| 1473 |
+
|
| 1474 |
+
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
|
| 1475 |
+
|
| 1476 |
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
| 1477 |
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
| 1478 |
return 1.0f - min(1.0f, max(0.0f, y));
|
|
|
|
| 2593 |
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
| 2594 |
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
| 2595 |
|
| 2596 |
+
template<typename T0, typename T1>
|
| 2597 |
+
kernel void kernel_cpy(
|
| 2598 |
+
device const void * src0,
|
| 2599 |
+
device void * dst,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2600 |
constant int64_t & ne00,
|
| 2601 |
constant int64_t & ne01,
|
| 2602 |
constant int64_t & ne02,
|
|
|
|
| 2627 |
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
| 2628 |
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
| 2629 |
|
| 2630 |
+
device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 2631 |
|
| 2632 |
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
| 2633 |
+
device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
| 2634 |
+
dst_data[i00] = (T1) src[0];
|
|
|
|
| 2635 |
}
|
| 2636 |
}
|
| 2637 |
|
| 2638 |
+
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2639 |
|
| 2640 |
+
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
| 2641 |
+
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
| 2642 |
+
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
|
| 2643 |
+
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2644 |
|
| 2645 |
kernel void kernel_cpy_f32_q8_0(
|
| 2646 |
device const float * src0,
|
|
|
|
| 5441 |
}
|
| 5442 |
|
| 5443 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
| 5444 |
+
kernel void kernel_get_rows_q(
|
| 5445 |
device const void * src0,
|
| 5446 |
+
device const void * src1,
|
| 5447 |
device float * dst,
|
| 5448 |
constant int64_t & ne00,
|
| 5449 |
constant uint64_t & nb01,
|
|
|
|
| 5456 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5457 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 5458 |
uint3 tptg [[threads_per_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
| 5459 |
const int64_t i10 = tgpig.x;
|
| 5460 |
const int64_t i11 = tgpig.y;
|
| 5461 |
|
| 5462 |
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
| 5463 |
|
| 5464 |
const int64_t i02 = i11;
|
| 5465 |
|
| 5466 |
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
| 5467 |
float4x4 temp;
|
| 5468 |
+
dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
|
|
|
| 5469 |
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
| 5470 |
}
|
| 5471 |
}
|
| 5472 |
|
| 5473 |
+
template<typename T>
|
| 5474 |
+
kernel void kernel_get_rows_f(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5475 |
device const void * src0,
|
| 5476 |
+
device const void * src1,
|
| 5477 |
device float * dst,
|
| 5478 |
constant int64_t & ne00,
|
| 5479 |
constant uint64_t & nb01,
|
|
|
|
| 5489 |
const int64_t i10 = tgpig.x;
|
| 5490 |
const int64_t i11 = tgpig.y;
|
| 5491 |
|
| 5492 |
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
| 5493 |
|
| 5494 |
const int64_t i02 = i11;
|
| 5495 |
|
| 5496 |
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
| 5497 |
+
(( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
| 5498 |
+
((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
| 5499 |
}
|
| 5500 |
}
|
| 5501 |
|
| 5502 |
kernel void kernel_get_rows_i32(
|
| 5503 |
device const void * src0,
|
| 5504 |
+
device const void * src1,
|
| 5505 |
device int32_t * dst,
|
| 5506 |
constant int64_t & ne00,
|
| 5507 |
constant uint64_t & nb01,
|
|
|
|
| 5517 |
const int64_t i10 = tgpig.x;
|
| 5518 |
const int64_t i11 = tgpig.y;
|
| 5519 |
|
| 5520 |
+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
| 5521 |
|
| 5522 |
const int64_t i02 = i11;
|
| 5523 |
|
| 5524 |
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
| 5525 |
+
(( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
| 5526 |
+
((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
| 5527 |
}
|
| 5528 |
}
|
| 5529 |
|
|
|
|
| 5540 |
#define SG_MAT_ROW 8
|
| 5541 |
|
| 5542 |
// each block_q contains 16*nl weights
|
| 5543 |
+
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
| 5544 |
+
kernel void kernel_mul_mm(device const uchar * src0,
|
| 5545 |
+
device const uchar * src1,
|
| 5546 |
+
device float * dst,
|
| 5547 |
+
constant int64_t & ne00,
|
| 5548 |
+
constant int64_t & ne02,
|
| 5549 |
+
constant uint64_t & nb01,
|
| 5550 |
+
constant uint64_t & nb02,
|
| 5551 |
+
constant int64_t & ne12,
|
| 5552 |
+
constant uint64_t & nb10,
|
| 5553 |
+
constant uint64_t & nb11,
|
| 5554 |
+
constant uint64_t & nb12,
|
| 5555 |
+
constant int64_t & ne0,
|
| 5556 |
+
constant int64_t & ne1,
|
| 5557 |
+
constant uint & r2,
|
| 5558 |
+
constant uint & r3,
|
| 5559 |
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
| 5560 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5561 |
+
uint tiitg[[thread_index_in_threadgroup]],
|
| 5562 |
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5563 |
|
| 5564 |
+
threadgroup T * sa = (threadgroup T *)(shared_memory);
|
| 5565 |
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
| 5566 |
|
| 5567 |
const uint r0 = tgpig.y;
|
|
|
|
| 5576 |
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
| 5577 |
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
| 5578 |
|
| 5579 |
+
simdgroup_T8x8 ma[4];
|
| 5580 |
simdgroup_float8x8 mb[2];
|
| 5581 |
simdgroup_float8x8 c_res[8];
|
| 5582 |
for (int i = 0; i < 8; i++){
|
|
|
|
| 5599 |
|
| 5600 |
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
| 5601 |
// load data and store to threadgroup memory
|
| 5602 |
+
T4x4 temp_a;
|
| 5603 |
dequantize_func(x, il, temp_a);
|
| 5604 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5605 |
|
|
|
|
| 5619 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5620 |
|
| 5621 |
// load matrices from threadgroup memory and conduct outer products
|
| 5622 |
+
threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
| 5623 |
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
| 5624 |
|
| 5625 |
#pragma unroll(4)
|
|
|
|
| 5795 |
}
|
| 5796 |
}
|
| 5797 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5798 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
| 5799 |
kernel void kernel_mul_mm_id(
|
| 5800 |
device const uchar * src0s,
|
|
|
|
| 5875 |
// get rows
|
| 5876 |
//
|
| 5877 |
|
| 5878 |
+
typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
|
| 5879 |
+
|
| 5880 |
+
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
|
| 5881 |
+
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
|
| 5882 |
+
|
| 5883 |
+
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
| 5884 |
+
|
| 5885 |
+
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
| 5886 |
+
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
| 5887 |
+
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
| 5888 |
+
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
|
| 5889 |
+
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
|
| 5890 |
+
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
|
| 5891 |
+
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
|
| 5892 |
+
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
|
| 5893 |
+
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 5894 |
+
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 5895 |
+
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
| 5896 |
+
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
| 5897 |
+
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
| 5898 |
+
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
| 5899 |
+
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
| 5900 |
+
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
| 5901 |
+
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
| 5902 |
+
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
| 5903 |
+
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5904 |
|
| 5905 |
//
|
| 5906 |
// matrix-matrix multiplication
|
| 5907 |
//
|
| 5908 |
|
| 5909 |
+
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
|
| 5910 |
+
|
| 5911 |
+
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
| 5912 |
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
| 5913 |
+
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
| 5914 |
+
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
| 5915 |
+
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
| 5916 |
+
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
| 5917 |
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
| 5918 |
+
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
| 5919 |
+
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
| 5920 |
+
// Explicit instantiations of the templated matrix-matrix multiply kernel for
// the remaining quantized block types. Each [[host_name]] attribute exports the
// instantiation under the pipeline name the host (ggml-metal.m) looks up.
// All use QK_NL elements per quant block except iq4_nl, whose dequantizer
// consumes 2 elements at a time.
// NOTE(review): the stray diff-artifact lines ("|", "+", "| 5921 |") that were
// interleaved here are not valid Metal and have been removed.
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;

//
// indirect matrix-matrix multiplication
|
|
|
|
| 6065 |
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
|
| 6066 |
}
|
| 6067 |
|
| 6068 |
// Function type of the templated matrix-vector wrapper (mmv_fn); used below to
// stamp out the indirect (expert-routed) mul_mv kernels from one template.
// NOTE(review): stray diff-artifact lines removed — they are not valid Metal.
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
template<mul_mv_impl_fn_t impl_fn>
|
| 6071 |
kernel void kernel_mul_mv_id(
|
|
|
|
| 6143 |
sgitg);
|
| 6144 |
}
|
| 6145 |
|
| 6146 |
// Kernel type of the templated indirect matrix-vector multiply, followed by
// explicit instantiations — one per supported src0 type. Each [[host_name]]
// attribute exports the instantiation under the pipeline name the host
// (ggml-metal.m) looks up. Non-quantized and simple-quant types route through
// the generic mul_vec_q_n_f32_impl / kernel_mul_mv_impl templates; K-quants and
// i-quants use their dedicated *_impl functions.
// NOTE(review): the interleaved diff-artifact lines ("|", "+", "| 6147 |")
// have been removed — they are not valid Metal.
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;

template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;