metal : support permuted matrix multiplications (llama/10033)
* metal : support permuted matrix multiplications
ggml-ci
* cont : use nb01 directly for row steps
ggml-ci
* cont : add comments [no ci]
* metal : minor refactor
* metal : minor
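
The "use nb01 directly for row steps" item relies on ggml's stride convention: every tensor carries per-dimension byte strides nb0..nb3 alongside the element counts ne0..ne3, so a permuted (e.g. transposed) view is the same data with swapped strides rather than a copy. A minimal C sketch of that addressing (not taken from the patch; the helper name is hypothetical):

#include <stddef.h>
#include <stdint.h>

// Sketch of ggml-style addressing: nb1/nb2/nb3 are byte strides per dimension.
// Stepping by nb1 instead of assuming ne0*sizeof(type) per row is what lets
// the same kernel walk both contiguous and permuted views of src0/src1.
static inline const char * row_ptr_sketch(
        const char * data,                      // tensor data base pointer
        int64_t i1, int64_t i2, int64_t i3,     // row and batch indices
        size_t nb1, size_t nb2, size_t nb3) {   // byte strides per dimension
    return data + i1*nb1 + i2*nb2 + i3*nb3;
}

With the additional strides (nb03 and nb10..nb13) now passed to the Metal kernels in the diff below, both the mat-mat and mat-vec paths can address rows this way instead of assuming contiguous layouts.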
- ggml/src/ggml-metal.m +42 -33
- ggml/src/ggml-metal.metal +380 -196
ggml/src/ggml-metal.m
CHANGED
|
@@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node(
|
|
| 1015 |
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
| 1016 |
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
| 1017 |
|
| 1018 |
-
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
|
| 1022 |
-
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
|
|
|
|
|
|
| 1031 |
|
| 1032 |
id<MTLDevice> device = ctx_dev->mtl_device;
|
| 1033 |
|
|
@@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node(
|
|
| 1810 |
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
| 1811 |
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
| 1812 |
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
| 1813 |
-
[encoder setBytes:&
|
| 1814 |
-
[encoder setBytes:&
|
| 1815 |
-
[encoder setBytes:&
|
| 1816 |
-
[encoder setBytes:&
|
| 1817 |
-
[encoder setBytes:&
|
| 1818 |
-
[encoder setBytes:&
|
| 1819 |
-
[encoder setBytes:&
|
| 1820 |
-
[encoder setBytes:&
|
|
|
|
|
|
|
| 1821 |
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
| 1822 |
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
| 1823 |
} else {
|
|
@@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node(
|
|
| 1986 |
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
| 1987 |
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
| 1988 |
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
| 1989 |
-
[encoder setBytes:&
|
| 1990 |
-
[encoder setBytes:&
|
| 1991 |
-
[encoder setBytes:&
|
| 1992 |
-
[encoder setBytes:&
|
| 1993 |
-
[encoder setBytes:&
|
| 1994 |
-
[encoder setBytes:&
|
| 1995 |
-
[encoder setBytes:&
|
| 1996 |
-
[encoder setBytes:&
|
| 1997 |
-
[encoder setBytes:&
|
| 1998 |
-
[encoder setBytes:&
|
|
|
|
|
|
|
| 1999 |
|
| 2000 |
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
| 2001 |
-
|
| 2002 |
-
|
| 2003 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2004 |
}
|
| 2005 |
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
|
@@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node(
|
|
| 2048 |
|
| 2049 |
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
| 2050 |
|
|
|
|
|
|
|
|
|
|
| 2051 |
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
| 2052 |
// to the matrix-vector kernel
|
| 2053 |
// ne20 = n_used_experts
|
|
|
|
| 1015 |
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
| 1016 |
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
| 1017 |
|
| 1018 |
+
#if 0
|
| 1019 |
+
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
| 1020 |
+
if (src0) {
|
| 1021 |
+
GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
|
| 1022 |
+
ggml_is_contiguous(src0), src0->name);
|
| 1023 |
+
}
|
| 1024 |
+
if (src1) {
|
| 1025 |
+
GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
|
| 1026 |
+
ggml_is_contiguous(src1), src1->name);
|
| 1027 |
+
}
|
| 1028 |
+
if (dst) {
|
| 1029 |
+
GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
|
| 1030 |
+
dst->name);
|
| 1031 |
+
}
|
| 1032 |
+
#endif
|
| 1033 |
|
| 1034 |
id<MTLDevice> device = ctx_dev->mtl_device;
|
| 1035 |
|
|
|
|
| 1812 |
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
| 1813 |
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
| 1814 |
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
| 1815 |
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
|
| 1816 |
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
| 1817 |
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
|
| 1818 |
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
|
| 1819 |
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
|
| 1820 |
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
|
| 1821 |
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
| 1822 |
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
| 1823 |
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
|
| 1824 |
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
|
| 1825 |
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
| 1826 |
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
| 1827 |
} else {
|
|
|
|
| 1990 |
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
| 1991 |
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
| 1992 |
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
| 1993 |
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
| 1994 |
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
| 1995 |
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
| 1996 |
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
| 1997 |
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
|
| 1998 |
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
|
| 1999 |
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
|
| 2000 |
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
|
| 2001 |
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
| 2002 |
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
|
| 2003 |
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
|
| 2004 |
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
|
| 2005 |
|
| 2006 |
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
| 2007 |
+
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
| 2008 |
+
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
| 2009 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2010 |
}
|
| 2011 |
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
|
|
|
| 2054 |
|
| 2055 |
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
| 2056 |
|
| 2057 |
+
GGML_ASSERT(ne03 == 1);
|
| 2058 |
+
GGML_ASSERT(ne13 == 1);
|
| 2059 |
+
|
| 2060 |
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
| 2061 |
// to the matrix-vector kernel
|
| 2062 |
// ne20 = n_used_experts
|
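The shader-side changes in ggml-metal.metal below apply the same scheme: the mat-vec kernels now receive the full stride set and derive their src0/src1 row offsets from it, together with the dim-2/dim-3 broadcast ratios r2 and r3. A rough C sketch of that offset pattern (illustrative only; names assumed from the ggml convention, not the exact kernel code):

#include <stdint.h>

// Illustrative only: locating input rows when src0/src1 may be permuted.
// All nb* strides are in bytes; r2/r3 broadcast one src0 slice across
// several src1 slices along dims 2 and 3.
static void locate_rows_sketch(
        const char * src0, const char * src1,
        int64_t first_row, int64_t r1, int64_t i12, int64_t i13,
        uint64_t nb01, uint64_t nb02, uint64_t nb03,
        uint64_t nb11, uint64_t nb12, uint64_t nb13,
        uint32_t r2, uint32_t r3,
        const char ** x, const char ** y) {
    // src0 row: step by nb01 per row, scale batch indices by the broadcast ratios
    *x = src0 + first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
    // src1 row: permuted layouts are handled the same way via nb11/nb12/nb13
    *y = src1 + r1*nb11 + i12*nb12 + i13*nb13;
}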
ggml/src/ggml-metal.metal
CHANGED
|
@@ -777,10 +777,10 @@ kernel void kernel_ssm_conv_f32(
|
|
| 777 |
const int64_t i3 = tgpig.z;
|
| 778 |
|
| 779 |
const int64_t nc = ne10;
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
|
| 785 |
device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
|
| 786 |
device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
|
|
@@ -834,9 +834,9 @@ kernel void kernel_ssm_scan_f32(
|
|
| 834 |
const int64_t i3 = tgpig.y;
|
| 835 |
|
| 836 |
const int64_t nc = d_state;
|
| 837 |
-
|
| 838 |
const int64_t n_t = n_seq_tokens;
|
| 839 |
-
|
| 840 |
|
| 841 |
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
| 842 |
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
|
|
@@ -1064,17 +1064,18 @@ kernel void kernel_group_norm(
|
|
| 1064 |
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
| 1065 |
float d = qb_curr->d;
|
| 1066 |
|
| 1067 |
-
|
| 1068 |
|
| 1069 |
-
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
| 1070 |
|
| 1071 |
-
for (int i = 0; i < 8; i+=2) {
|
| 1072 |
-
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
| 1073 |
-
|
| 1074 |
-
acc[
|
| 1075 |
-
|
| 1076 |
}
|
| 1077 |
-
|
|
|
|
| 1078 |
}
|
| 1079 |
|
| 1080 |
// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
@@ -1085,17 +1086,18 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|
| 1085 |
float d = qb_curr->d;
|
| 1086 |
float m = qb_curr->m;
|
| 1087 |
|
| 1088 |
-
|
| 1089 |
|
| 1090 |
-
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
| 1091 |
|
| 1092 |
for (int i = 0; i < 8; i+=2) {
|
| 1093 |
-
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
| 1094 |
-
|
| 1095 |
-
acc[
|
| 1096 |
-
|
| 1097 |
}
|
| 1098 |
-
|
|
|
|
| 1099 |
}
|
| 1100 |
|
| 1101 |
// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
@@ -1105,18 +1107,19 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|
| 1105 |
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
|
| 1106 |
float d = qb_curr->d;
|
| 1107 |
|
| 1108 |
-
|
| 1109 |
|
| 1110 |
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
|
| 1111 |
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
| 1112 |
|
| 1113 |
for (int i = 0; i < 8; i+=2) {
|
| 1114 |
-
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
| 1115 |
-
|
| 1116 |
-
acc[
|
| 1117 |
-
|
| 1118 |
}
|
| 1119 |
-
|
|
|
|
| 1120 |
}
|
| 1121 |
|
| 1122 |
// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
@@ -1127,18 +1130,19 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
| 1127 |
float d = qb_curr->d;
|
| 1128 |
float m = qb_curr->m;
|
| 1129 |
|
| 1130 |
-
|
| 1131 |
|
| 1132 |
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
|
| 1133 |
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
| 1134 |
|
| 1135 |
for (int i = 0; i < 8; i+=2) {
|
| 1136 |
-
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
| 1137 |
-
|
| 1138 |
-
acc[
|
| 1139 |
-
|
| 1140 |
}
|
| 1141 |
-
|
|
|
|
| 1142 |
}
|
| 1143 |
|
| 1144 |
// putting them in the kernel cause a significant performance penalty
|
|
@@ -1156,14 +1160,22 @@ void mul_vec_q_n_f32_impl(
|
|
| 1156 |
int64_t ne00,
|
| 1157 |
int64_t ne01,
|
| 1158 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 1159 |
int64_t ne10,
|
| 1160 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 1161 |
int64_t ne0,
|
| 1162 |
int64_t ne1,
|
| 1163 |
uint r2,
|
| 1164 |
uint r3,
|
| 1165 |
threadgroup int8_t * shared_values,
|
| 1166 |
-
|
|
|
|
|
|
|
| 1167 |
const int nb = ne00/QK4_0;
|
| 1168 |
|
| 1169 |
const int r0 = tgpig.x;
|
|
@@ -1175,10 +1187,19 @@ void mul_vec_q_n_f32_impl(
|
|
| 1175 |
const uint i12 = im%ne12;
|
| 1176 |
const uint i13 = im/ne12;
|
| 1177 |
|
| 1178 |
-
|
|
|
|
| 1179 |
|
| 1180 |
-
|
| 1181 |
-
device const float * y = (device const float *)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1182 |
|
| 1183 |
float yl[16]; // src1 vector cache
|
| 1184 |
float sumf[nr] = {0.f};
|
|
@@ -1190,19 +1211,22 @@ void mul_vec_q_n_f32_impl(
|
|
| 1190 |
|
| 1191 |
// each thread in a SIMD group deals with half a block.
|
| 1192 |
for (int ib = ix; ib < nb; ib += nw/2) {
|
| 1193 |
-
float sumy = 0;
|
|
|
|
|
|
|
| 1194 |
for (int i = 0; i < 8; i += 2) {
|
| 1195 |
-
sumy
|
| 1196 |
-
yl[i+0] = yb[i+
|
| 1197 |
-
yl[i+1] = yb[i+
|
| 1198 |
|
| 1199 |
-
sumy
|
| 1200 |
-
yl[i+8] = yb[i+16]/16.f;
|
| 1201 |
-
yl[i+9] = yb[i+17]/4096.f;
|
| 1202 |
}
|
| 1203 |
|
|
|
|
| 1204 |
for (int row = 0; row < nr; row++) {
|
| 1205 |
-
sumf[row] += block_q_n_dot_y(
|
| 1206 |
}
|
| 1207 |
|
| 1208 |
yb += QK4_0 * 16;
|
|
@@ -1226,12 +1250,14 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
| 1226 |
constant uint64_t & nb00,
|
| 1227 |
constant uint64_t & nb01,
|
| 1228 |
constant uint64_t & nb02,
|
|
|
|
| 1229 |
constant int64_t & ne10,
|
| 1230 |
constant int64_t & ne11,
|
| 1231 |
constant int64_t & ne12,
|
| 1232 |
constant uint64_t & nb10,
|
| 1233 |
constant uint64_t & nb11,
|
| 1234 |
constant uint64_t & nb12,
|
|
|
|
| 1235 |
constant int64_t & ne0,
|
| 1236 |
constant int64_t & ne1,
|
| 1237 |
constant uint & r2,
|
|
@@ -1239,7 +1265,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
| 1239 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1240 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1241 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1242 |
-
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1243 |
}
|
| 1244 |
|
| 1245 |
kernel void kernel_mul_mv_q4_1_f32(
|
|
@@ -1252,12 +1278,14 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
| 1252 |
constant uint64_t & nb00,
|
| 1253 |
constant uint64_t & nb01,
|
| 1254 |
constant uint64_t & nb02,
|
|
|
|
| 1255 |
constant int64_t & ne10,
|
| 1256 |
constant int64_t & ne11,
|
| 1257 |
constant int64_t & ne12,
|
| 1258 |
constant uint64_t & nb10,
|
| 1259 |
constant uint64_t & nb11,
|
| 1260 |
constant uint64_t & nb12,
|
|
|
|
| 1261 |
constant int64_t & ne0,
|
| 1262 |
constant int64_t & ne1,
|
| 1263 |
constant uint & r2,
|
|
@@ -1265,7 +1293,7 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
| 1265 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1266 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1267 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1268 |
-
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1269 |
}
|
| 1270 |
|
| 1271 |
kernel void kernel_mul_mv_q5_0_f32(
|
|
@@ -1278,12 +1306,14 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
| 1278 |
constant uint64_t & nb00,
|
| 1279 |
constant uint64_t & nb01,
|
| 1280 |
constant uint64_t & nb02,
|
|
|
|
| 1281 |
constant int64_t & ne10,
|
| 1282 |
constant int64_t & ne11,
|
| 1283 |
constant int64_t & ne12,
|
| 1284 |
constant uint64_t & nb10,
|
| 1285 |
constant uint64_t & nb11,
|
| 1286 |
constant uint64_t & nb12,
|
|
|
|
| 1287 |
constant int64_t & ne0,
|
| 1288 |
constant int64_t & ne1,
|
| 1289 |
constant uint & r2,
|
|
@@ -1291,7 +1321,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
| 1291 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1292 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1293 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1294 |
-
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1295 |
}
|
| 1296 |
|
| 1297 |
kernel void kernel_mul_mv_q5_1_f32(
|
|
@@ -1304,12 +1334,14 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
| 1304 |
constant uint64_t & nb00,
|
| 1305 |
constant uint64_t & nb01,
|
| 1306 |
constant uint64_t & nb02,
|
|
|
|
| 1307 |
constant int64_t & ne10,
|
| 1308 |
constant int64_t & ne11,
|
| 1309 |
constant int64_t & ne12,
|
| 1310 |
constant uint64_t & nb10,
|
| 1311 |
constant uint64_t & nb11,
|
| 1312 |
constant uint64_t & nb12,
|
|
|
|
| 1313 |
constant int64_t & ne0,
|
| 1314 |
constant int64_t & ne1,
|
| 1315 |
constant uint & r2,
|
|
@@ -1317,7 +1349,7 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
| 1317 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1318 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1319 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1320 |
-
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1321 |
}
|
| 1322 |
|
| 1323 |
|
|
@@ -1330,8 +1362,14 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
| 1330 |
int64_t ne00,
|
| 1331 |
int64_t ne01,
|
| 1332 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 1333 |
int64_t ne10,
|
| 1334 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 1335 |
int64_t ne0,
|
| 1336 |
int64_t ne1,
|
| 1337 |
uint r2,
|
|
@@ -1354,10 +1392,19 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
| 1354 |
const uint i12 = im%ne12;
|
| 1355 |
const uint i13 = im/ne12;
|
| 1356 |
|
| 1357 |
-
|
|
|
|
| 1358 |
|
| 1359 |
-
|
| 1360 |
-
device const float * y = (device const float *)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1361 |
|
| 1362 |
float yl[NB_Q8_0];
|
| 1363 |
float sumf[nr]={0.f};
|
|
@@ -1374,12 +1421,12 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
| 1374 |
}
|
| 1375 |
|
| 1376 |
for (int row = 0; row < nr; row++) {
|
| 1377 |
-
device const int8_t * qs =
|
| 1378 |
float sumq = 0.f;
|
| 1379 |
for (int iq = 0; iq < NB_Q8_0; ++iq) {
|
| 1380 |
sumq += qs[iq] * yl[iq];
|
| 1381 |
}
|
| 1382 |
-
sumf[row] += sumq*
|
| 1383 |
}
|
| 1384 |
|
| 1385 |
yb += NB_Q8_0 * nw;
|
|
@@ -1404,12 +1451,14 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
| 1404 |
constant uint64_t & nb00,
|
| 1405 |
constant uint64_t & nb01,
|
| 1406 |
constant uint64_t & nb02,
|
|
|
|
| 1407 |
constant int64_t & ne10,
|
| 1408 |
constant int64_t & ne11,
|
| 1409 |
constant int64_t & ne12,
|
| 1410 |
constant uint64_t & nb10,
|
| 1411 |
constant uint64_t & nb11,
|
| 1412 |
constant uint64_t & nb12,
|
|
|
|
| 1413 |
constant int64_t & ne0,
|
| 1414 |
constant int64_t & ne1,
|
| 1415 |
constant uint & r2,
|
|
@@ -1417,7 +1466,7 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
| 1417 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1418 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1419 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1420 |
-
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1421 |
}
|
| 1422 |
|
| 1423 |
#define N_MV_T_T 4
|
|
@@ -1433,12 +1482,14 @@ void kernel_mul_mv_impl(
|
|
| 1433 |
uint64_t nb00,
|
| 1434 |
uint64_t nb01,
|
| 1435 |
uint64_t nb02,
|
|
|
|
| 1436 |
int64_t ne10,
|
| 1437 |
int64_t ne11,
|
| 1438 |
int64_t ne12,
|
| 1439 |
uint64_t nb10,
|
| 1440 |
uint64_t nb11,
|
| 1441 |
uint64_t nb12,
|
|
|
|
| 1442 |
int64_t ne0,
|
| 1443 |
int64_t ne1,
|
| 1444 |
uint r2,
|
|
@@ -1452,7 +1503,7 @@ void kernel_mul_mv_impl(
|
|
| 1452 |
const uint i12 = im%ne12;
|
| 1453 |
const uint i13 = im/ne12;
|
| 1454 |
|
| 1455 |
-
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*
|
| 1456 |
|
| 1457 |
device const T0 * x = (device const T0 *) (src0 + offset0);
|
| 1458 |
|
|
@@ -1463,7 +1514,9 @@ void kernel_mul_mv_impl(
|
|
| 1463 |
break;
|
| 1464 |
}
|
| 1465 |
|
| 1466 |
-
|
|
|
|
|
|
|
| 1467 |
|
| 1468 |
float sumf = 0;
|
| 1469 |
for (int i = tiisg; i < ne00; i += 32) {
|
|
@@ -1483,7 +1536,9 @@ void kernel_mul_mv_impl(
|
|
| 1483 |
break;
|
| 1484 |
}
|
| 1485 |
|
| 1486 |
-
|
|
|
|
|
|
|
| 1487 |
device const T14 * y4 = (device const T14 *) y;
|
| 1488 |
|
| 1489 |
float sumf = 0;
|
|
@@ -1511,12 +1566,14 @@ kernel void kernel_mul_mv(
|
|
| 1511 |
constant uint64_t & nb00,
|
| 1512 |
constant uint64_t & nb01,
|
| 1513 |
constant uint64_t & nb02,
|
|
|
|
| 1514 |
constant int64_t & ne10,
|
| 1515 |
constant int64_t & ne11,
|
| 1516 |
constant int64_t & ne12,
|
| 1517 |
constant uint64_t & nb10,
|
| 1518 |
constant uint64_t & nb11,
|
| 1519 |
constant uint64_t & nb12,
|
|
|
|
| 1520 |
constant int64_t & ne0,
|
| 1521 |
constant int64_t & ne1,
|
| 1522 |
constant uint & r2,
|
|
@@ -1533,12 +1590,14 @@ kernel void kernel_mul_mv(
|
|
| 1533 |
nb00,
|
| 1534 |
nb01,
|
| 1535 |
nb02,
|
|
|
|
| 1536 |
ne10,
|
| 1537 |
ne11,
|
| 1538 |
ne12,
|
| 1539 |
nb10,
|
| 1540 |
nb11,
|
| 1541 |
nb12,
|
|
|
|
| 1542 |
ne0,
|
| 1543 |
ne1,
|
| 1544 |
r2,
|
|
@@ -1564,12 +1623,14 @@ kernel void kernel_mul_mv_1row(
|
|
| 1564 |
constant uint64_t & nb00,
|
| 1565 |
constant uint64_t & nb01,
|
| 1566 |
constant uint64_t & nb02,
|
|
|
|
| 1567 |
constant int64_t & ne10,
|
| 1568 |
constant int64_t & ne11,
|
| 1569 |
constant int64_t & ne12,
|
| 1570 |
constant uint64_t & nb10,
|
| 1571 |
constant uint64_t & nb11,
|
| 1572 |
constant uint64_t & nb12,
|
|
|
|
| 1573 |
constant int64_t & ne0,
|
| 1574 |
constant int64_t & ne1,
|
| 1575 |
constant uint & r2,
|
|
@@ -1584,10 +1645,11 @@ kernel void kernel_mul_mv_1row(
|
|
| 1584 |
const uint i12 = im%ne12;
|
| 1585 |
const uint i13 = im/ne12;
|
| 1586 |
|
| 1587 |
-
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*
|
|
|
|
| 1588 |
|
| 1589 |
device const T * x = (device const T *) (src0 + offset0);
|
| 1590 |
-
device const float * y = (device const float *) (src1 +
|
| 1591 |
|
| 1592 |
float sumf = 0;
|
| 1593 |
if (ne00 < 128) {
|
|
@@ -1631,12 +1693,14 @@ kernel void kernel_mul_mv_l4(
|
|
| 1631 |
constant uint64_t & nb00,
|
| 1632 |
constant uint64_t & nb01,
|
| 1633 |
constant uint64_t & nb02,
|
|
|
|
| 1634 |
constant int64_t & ne10,
|
| 1635 |
constant int64_t & ne11,
|
| 1636 |
constant int64_t & ne12,
|
| 1637 |
constant uint64_t & nb10,
|
| 1638 |
constant uint64_t & nb11,
|
| 1639 |
constant uint64_t & nb12,
|
|
|
|
| 1640 |
constant int64_t & ne0,
|
| 1641 |
constant int64_t & ne1,
|
| 1642 |
constant uint & r2,
|
|
@@ -1651,12 +1715,14 @@ kernel void kernel_mul_mv_l4(
|
|
| 1651 |
const uint i12 = im%ne12;
|
| 1652 |
const uint i13 = im/ne12;
|
| 1653 |
|
| 1654 |
-
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*
|
| 1655 |
|
| 1656 |
device const T4 * x4 = (device const T4 *) (src0 + offset0);
|
| 1657 |
|
| 1658 |
for (int r1 = 0; r1 < nrows; ++r1) {
|
| 1659 |
-
|
|
|
|
|
|
|
| 1660 |
|
| 1661 |
float sumf = 0;
|
| 1662 |
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
@@ -3416,8 +3482,14 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
| 3416 |
int64_t ne00,
|
| 3417 |
int64_t ne01,
|
| 3418 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 3419 |
int64_t ne10,
|
| 3420 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 3421 |
int64_t ne0,
|
| 3422 |
int64_t ne1,
|
| 3423 |
uint r2,
|
|
@@ -3433,21 +3505,19 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
| 3433 |
const int im = tgpig.z;
|
| 3434 |
|
| 3435 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 3436 |
-
const int ib_row = first_row * nb;
|
| 3437 |
|
| 3438 |
const uint i12 = im%ne12;
|
| 3439 |
const uint i13 = im/ne12;
|
| 3440 |
|
| 3441 |
-
const uint offset0 = (i12/r2)*
|
|
|
|
| 3442 |
|
| 3443 |
-
device const block_q2_K * x = (device const block_q2_K *)
|
| 3444 |
-
device const float * y = (device const float *)
|
| 3445 |
|
| 3446 |
float yl[32];
|
| 3447 |
float sumf[N_DST]={0.f}, all_sum;
|
| 3448 |
|
| 3449 |
-
const int step = sizeof(block_q2_K) * nb;
|
| 3450 |
-
|
| 3451 |
const int ix = tiisg/8; // 0...3
|
| 3452 |
const int it = tiisg%8; // 0...7
|
| 3453 |
const int iq = it/4; // 0 or 1
|
|
@@ -3492,9 +3562,9 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
| 3492 |
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
|
| 3493 |
dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
|
| 3494 |
|
| 3495 |
-
qs +=
|
| 3496 |
-
sc +=
|
| 3497 |
-
dh +=
|
| 3498 |
}
|
| 3499 |
|
| 3500 |
y4 += 4 * QK_K;
|
|
@@ -3519,12 +3589,14 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
| 3519 |
constant uint64_t & nb00,
|
| 3520 |
constant uint64_t & nb01,
|
| 3521 |
constant uint64_t & nb02,
|
|
|
|
| 3522 |
constant int64_t & ne10,
|
| 3523 |
constant int64_t & ne11,
|
| 3524 |
constant int64_t & ne12,
|
| 3525 |
constant uint64_t & nb10,
|
| 3526 |
constant uint64_t & nb11,
|
| 3527 |
constant uint64_t & nb12,
|
|
|
|
| 3528 |
constant int64_t & ne0,
|
| 3529 |
constant int64_t & ne1,
|
| 3530 |
constant uint & r2,
|
|
@@ -3533,7 +3605,7 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
| 3533 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3534 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3535 |
|
| 3536 |
-
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 3537 |
}
|
| 3538 |
|
| 3539 |
void kernel_mul_mv_q3_K_f32_impl(
|
|
@@ -3543,8 +3615,14 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
| 3543 |
int64_t ne00,
|
| 3544 |
int64_t ne01,
|
| 3545 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 3546 |
int64_t ne10,
|
| 3547 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 3548 |
int64_t ne0,
|
| 3549 |
int64_t ne1,
|
| 3550 |
uint r2,
|
|
@@ -3565,10 +3643,11 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
| 3565 |
const uint i12 = im%ne12;
|
| 3566 |
const uint i13 = im/ne12;
|
| 3567 |
|
| 3568 |
-
const uint offset0 = (i12/r2)*
|
|
|
|
| 3569 |
|
| 3570 |
-
device const block_q3_K * x = (device const block_q3_K *)
|
| 3571 |
-
device const float * yy = (device const float *)
|
| 3572 |
|
| 3573 |
float yl[32];
|
| 3574 |
|
|
@@ -3608,8 +3687,6 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
| 3608 |
const int q_offset = 32*ip + l0;
|
| 3609 |
const int y_offset = 128*ip + 32*il + l0;
|
| 3610 |
|
| 3611 |
-
const int step = sizeof(block_q3_K) * nb / 2;
|
| 3612 |
-
|
| 3613 |
device const float * y1 = yy + ix*QK_K + y_offset;
|
| 3614 |
|
| 3615 |
uint32_t scales32, aux32;
|
|
@@ -3619,7 +3696,6 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
| 3619 |
float sumf1[2] = {0.f};
|
| 3620 |
float sumf2[2] = {0.f};
|
| 3621 |
for (int i = ix; i < nb; i += 4) {
|
| 3622 |
-
|
| 3623 |
for (int l = 0; l < 8; ++l) {
|
| 3624 |
yl[l+ 0] = y1[l+ 0];
|
| 3625 |
yl[l+ 8] = y1[l+16];
|
|
@@ -3633,7 +3709,6 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
| 3633 |
device const half * dh = &x[i].d;
|
| 3634 |
|
| 3635 |
for (int row = 0; row < 2; ++row) {
|
| 3636 |
-
|
| 3637 |
const float d_all = (float)dh[0];
|
| 3638 |
|
| 3639 |
scales16[0] = a[4];
|
|
@@ -3673,15 +3748,13 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
| 3673 |
sumf1[row] += d1 * (scales[1] - 32);
|
| 3674 |
sumf2[row] += d2 * (scales[3] - 32);
|
| 3675 |
|
| 3676 |
-
q +=
|
| 3677 |
-
h +=
|
| 3678 |
-
a +=
|
| 3679 |
-
dh +=
|
| 3680 |
-
|
| 3681 |
}
|
| 3682 |
|
| 3683 |
y1 += 4 * QK_K;
|
| 3684 |
-
|
| 3685 |
}
|
| 3686 |
|
| 3687 |
for (int row = 0; row < 2; ++row) {
|
|
@@ -3706,12 +3779,14 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
| 3706 |
constant uint64_t & nb00,
|
| 3707 |
constant uint64_t & nb01,
|
| 3708 |
constant uint64_t & nb02,
|
|
|
|
| 3709 |
constant int64_t & ne10,
|
| 3710 |
constant int64_t & ne11,
|
| 3711 |
constant int64_t & ne12,
|
| 3712 |
constant uint64_t & nb10,
|
| 3713 |
constant uint64_t & nb11,
|
| 3714 |
constant uint64_t & nb12,
|
|
|
|
| 3715 |
constant int64_t & ne0,
|
| 3716 |
constant int64_t & ne1,
|
| 3717 |
constant uint & r2,
|
|
@@ -3720,7 +3795,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
| 3720 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3721 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3722 |
|
| 3723 |
-
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 3724 |
}
|
| 3725 |
|
| 3726 |
void kernel_mul_mv_q4_K_f32_impl(
|
|
@@ -3730,8 +3805,14 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
| 3730 |
int64_t ne00,
|
| 3731 |
int64_t ne01,
|
| 3732 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 3733 |
int64_t ne10,
|
| 3734 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 3735 |
int64_t ne0,
|
| 3736 |
int64_t ne1,
|
| 3737 |
uint r2,
|
|
@@ -3756,29 +3837,26 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
| 3756 |
const int im = tgpig.z;
|
| 3757 |
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 3758 |
const int first_row = r0 * N_DST;
|
| 3759 |
-
const int ib_row = first_row * nb;
|
| 3760 |
|
| 3761 |
const uint i12 = im%ne12;
|
| 3762 |
const uint i13 = im/ne12;
|
| 3763 |
|
| 3764 |
-
const uint offset0 = (i12/r2)*
|
|
|
|
| 3765 |
|
| 3766 |
-
device const block_q4_K * x = (device const block_q4_K *)
|
| 3767 |
-
device const float * y = (device const float *)
|
| 3768 |
|
| 3769 |
float yl[16];
|
| 3770 |
float yh[16];
|
| 3771 |
float sumf[N_DST]={0.f}, all_sum;
|
| 3772 |
|
| 3773 |
-
const int step = sizeof(block_q4_K) * nb / 2;
|
| 3774 |
-
|
| 3775 |
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
| 3776 |
|
| 3777 |
uint16_t sc16[4];
|
| 3778 |
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
| 3779 |
|
| 3780 |
for (int ib = ix; ib < nb; ib += 4) {
|
| 3781 |
-
|
| 3782 |
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
| 3783 |
for (int i = 0; i < 8; ++i) {
|
| 3784 |
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
|
|
@@ -3792,7 +3870,6 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
| 3792 |
device const half * dh = &x[ib].d;
|
| 3793 |
|
| 3794 |
for (int row = 0; row < N_DST; row++) {
|
| 3795 |
-
|
| 3796 |
sc16[0] = sc[0] & kmask1;
|
| 3797 |
sc16[1] = sc[2] & kmask1;
|
| 3798 |
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
|
@@ -3821,9 +3898,9 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
| 3821 |
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
|
| 3822 |
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
| 3823 |
|
| 3824 |
-
q1 +=
|
| 3825 |
-
sc +=
|
| 3826 |
-
dh +=
|
| 3827 |
}
|
| 3828 |
|
| 3829 |
y4 += 4 * QK_K;
|
|
@@ -3848,12 +3925,14 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
| 3848 |
constant uint64_t & nb00,
|
| 3849 |
constant uint64_t & nb01,
|
| 3850 |
constant uint64_t & nb02,
|
|
|
|
| 3851 |
constant int64_t & ne10,
|
| 3852 |
constant int64_t & ne11,
|
| 3853 |
constant int64_t & ne12,
|
| 3854 |
constant uint64_t & nb10,
|
| 3855 |
constant uint64_t & nb11,
|
| 3856 |
constant uint64_t & nb12,
|
|
|
|
| 3857 |
constant int64_t & ne0,
|
| 3858 |
constant int64_t & ne1,
|
| 3859 |
constant uint & r2,
|
|
@@ -3862,7 +3941,7 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
| 3862 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3863 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3864 |
|
| 3865 |
-
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 3866 |
}
|
| 3867 |
|
| 3868 |
void kernel_mul_mv_q5_K_f32_impl(
|
|
@@ -3872,8 +3951,14 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
| 3872 |
int64_t ne00,
|
| 3873 |
int64_t ne01,
|
| 3874 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 3875 |
int64_t ne10,
|
| 3876 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 3877 |
int64_t ne0,
|
| 3878 |
int64_t ne1,
|
| 3879 |
uint r2,
|
|
@@ -3894,15 +3979,14 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
| 3894 |
const uint i12 = im%ne12;
|
| 3895 |
const uint i13 = im/ne12;
|
| 3896 |
|
| 3897 |
-
const uint offset0 = (i12/r2)*
|
|
|
|
| 3898 |
|
| 3899 |
-
device const block_q5_K * x = (device const block_q5_K *)
|
| 3900 |
-
device const float * yy = (device const float *)
|
| 3901 |
|
| 3902 |
float sumf[2]={0.f};
|
| 3903 |
|
| 3904 |
-
const int step = sizeof(block_q5_K) * nb;
|
| 3905 |
-
|
| 3906 |
float yl[16], yh[16];
|
| 3907 |
|
| 3908 |
const uint16_t kmask1 = 0x3f3f;
|
|
@@ -3930,7 +4014,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
| 3930 |
device const float * y1 = yy + ix*QK_K + y_offset;
|
| 3931 |
|
| 3932 |
for (int i = ix; i < nb; i += 4) {
|
| 3933 |
-
|
| 3934 |
device const uint8_t * q1 = x[i].qs + q_offset;
|
| 3935 |
device const uint8_t * qh = x[i].qh + l0;
|
| 3936 |
device const half * dh = &x[i].d;
|
|
@@ -3946,7 +4029,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
| 3946 |
}
|
| 3947 |
|
| 3948 |
for (int row = 0; row < 2; ++row) {
|
| 3949 |
-
|
| 3950 |
device const uint8_t * q2 = q1 + 64;
|
| 3951 |
|
| 3952 |
sc16[0] = a[0] & kmask1;
|
|
@@ -3975,15 +4057,13 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
| 3975 |
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
|
| 3976 |
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
| 3977 |
|
| 3978 |
-
q1 +=
|
| 3979 |
-
qh +=
|
| 3980 |
-
dh +=
|
| 3981 |
-
a +=
|
| 3982 |
-
|
| 3983 |
}
|
| 3984 |
|
| 3985 |
y1 += 4 * QK_K;
|
| 3986 |
-
|
| 3987 |
}
|
| 3988 |
|
| 3989 |
for (int row = 0; row < 2; ++row) {
|
|
@@ -4005,12 +4085,14 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
| 4005 |
constant uint64_t & nb00,
|
| 4006 |
constant uint64_t & nb01,
|
| 4007 |
constant uint64_t & nb02,
|
|
|
|
| 4008 |
constant int64_t & ne10,
|
| 4009 |
constant int64_t & ne11,
|
| 4010 |
constant int64_t & ne12,
|
| 4011 |
constant uint64_t & nb10,
|
| 4012 |
constant uint64_t & nb11,
|
| 4013 |
constant uint64_t & nb12,
|
|
|
|
| 4014 |
constant int64_t & ne0,
|
| 4015 |
constant int64_t & ne1,
|
| 4016 |
constant uint & r2,
|
|
@@ -4019,7 +4101,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
| 4019 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4020 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4021 |
|
| 4022 |
-
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 4023 |
}
|
| 4024 |
|
| 4025 |
void kernel_mul_mv_q6_K_f32_impl(
|
|
@@ -4029,8 +4111,14 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
| 4029 |
int64_t ne00,
|
| 4030 |
int64_t ne01,
|
| 4031 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 4032 |
int64_t ne10,
|
| 4033 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 4034 |
int64_t ne0,
|
| 4035 |
int64_t ne1,
|
| 4036 |
uint r2,
|
|
@@ -4056,10 +4144,11 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
| 4056 |
const uint i12 = im%ne12;
|
| 4057 |
const uint i13 = im/ne12;
|
| 4058 |
|
| 4059 |
-
const uint offset0 = (i12/r2)*
|
|
|
|
| 4060 |
|
| 4061 |
-
device const block_q6_K * x = (device const block_q6_K *)
|
| 4062 |
-
device const float * yy = (device const float *)
|
| 4063 |
|
| 4064 |
float sumf = 0;
|
| 4065 |
|
|
@@ -4115,12 +4204,14 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
| 4115 |
constant uint64_t & nb00,
|
| 4116 |
constant uint64_t & nb01,
|
| 4117 |
constant uint64_t & nb02,
|
|
|
|
| 4118 |
constant int64_t & ne10,
|
| 4119 |
constant int64_t & ne11,
|
| 4120 |
constant int64_t & ne12,
|
| 4121 |
constant uint64_t & nb10,
|
| 4122 |
constant uint64_t & nb11,
|
| 4123 |
constant uint64_t & nb12,
|
|
|
|
| 4124 |
constant int64_t & ne0,
|
| 4125 |
constant int64_t & ne1,
|
| 4126 |
constant uint & r2,
|
|
@@ -4129,7 +4220,7 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
| 4129 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4130 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4131 |
|
| 4132 |
-
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 4133 |
}
|
| 4134 |
|
| 4135 |
// ======================= "True" 2-bit
|
|
@@ -4141,8 +4232,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
| 4141 |
int64_t ne00,
|
| 4142 |
int64_t ne01,
|
| 4143 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 4144 |
int64_t ne10,
|
| 4145 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 4146 |
int64_t ne0,
|
| 4147 |
int64_t ne1,
|
| 4148 |
uint r2,
|
|
@@ -4158,15 +4255,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
| 4158 |
const int im = tgpig.z;
|
| 4159 |
|
| 4160 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 4161 |
-
const int ib_row = first_row * nb;
|
| 4162 |
|
| 4163 |
const uint i12 = im%ne12;
|
| 4164 |
const uint i13 = im/ne12;
|
| 4165 |
|
| 4166 |
-
const uint offset0 = (i12/r2)*
|
|
|
|
| 4167 |
|
| 4168 |
-
device const block_iq2_xxs * x = (device const block_iq2_xxs *)
|
| 4169 |
-
device const float * y = (device const float *)
|
| 4170 |
|
| 4171 |
float yl[32];
|
| 4172 |
float sumf[N_DST]={0.f}, all_sum;
|
|
@@ -4219,8 +4316,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
| 4219 |
}
|
| 4220 |
sumf[row] += d * sum;
|
| 4221 |
|
| 4222 |
-
dh +=
|
| 4223 |
-
q2 +=
|
| 4224 |
}
|
| 4225 |
|
| 4226 |
y4 += 32 * 32;
|
|
@@ -4245,12 +4342,14 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
|
|
| 4245 |
constant uint64_t & nb00,
|
| 4246 |
constant uint64_t & nb01,
|
| 4247 |
constant uint64_t & nb02,
|
|
|
|
| 4248 |
constant int64_t & ne10,
|
| 4249 |
constant int64_t & ne11,
|
| 4250 |
constant int64_t & ne12,
|
| 4251 |
constant uint64_t & nb10,
|
| 4252 |
constant uint64_t & nb11,
|
| 4253 |
constant uint64_t & nb12,
|
|
|
|
| 4254 |
constant int64_t & ne0,
|
| 4255 |
constant int64_t & ne1,
|
| 4256 |
constant uint & r2,
|
|
@@ -4260,7 +4359,7 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
|
|
| 4260 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4261 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4262 |
|
| 4263 |
-
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4264 |
}
|
| 4265 |
|
| 4266 |
void kernel_mul_mv_iq2_xs_f32_impl(
|
|
@@ -4270,8 +4369,14 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
| 4270 |
int64_t ne00,
|
| 4271 |
int64_t ne01,
|
| 4272 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 4273 |
int64_t ne10,
|
| 4274 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 4275 |
int64_t ne0,
|
| 4276 |
int64_t ne1,
|
| 4277 |
uint r2,
|
|
@@ -4287,15 +4392,15 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
| 4287 |
const int im = tgpig.z;
|
| 4288 |
|
| 4289 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 4290 |
-
const int ib_row = first_row * nb;
|
| 4291 |
|
| 4292 |
const uint i12 = im%ne12;
|
| 4293 |
const uint i13 = im/ne12;
|
| 4294 |
|
| 4295 |
-
const uint offset0 = (i12/r2)*
|
|
|
|
| 4296 |
|
| 4297 |
-
device const block_iq2_xs * x = (device const block_iq2_xs *)
|
| 4298 |
-
device const float * y = (device const float *)
|
| 4299 |
|
| 4300 |
float yl[32];
|
| 4301 |
float sumf[N_DST]={0.f}, all_sum;
|
|
@@ -4357,9 +4462,9 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
| 4357 |
}
|
| 4358 |
sumf[row] += d1 * sum1 + d2 * sum2;
|
| 4359 |
|
| 4360 |
-
dh +=
|
| 4361 |
-
q2 +=
|
| 4362 |
-
sc +=
|
| 4363 |
}
|
| 4364 |
|
| 4365 |
y4 += 32 * 32;
|
|
@@ -4384,12 +4489,14 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
|
| 4384 |
constant uint64_t & nb00,
|
| 4385 |
constant uint64_t & nb01,
|
| 4386 |
constant uint64_t & nb02,
|
|
|
|
| 4387 |
constant int64_t & ne10,
|
| 4388 |
constant int64_t & ne11,
|
| 4389 |
constant int64_t & ne12,
|
| 4390 |
constant uint64_t & nb10,
|
| 4391 |
constant uint64_t & nb11,
|
| 4392 |
constant uint64_t & nb12,
|
|
|
|
| 4393 |
constant int64_t & ne0,
|
| 4394 |
constant int64_t & ne1,
|
| 4395 |
constant uint & r2,
|
|
@@ -4399,7 +4506,7 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
|
| 4399 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4400 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4401 |
|
| 4402 |
-
kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4403 |
}
|
| 4404 |
|
| 4405 |
void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
@@ -4409,8 +4516,14 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
| 4409 |
int64_t ne00,
|
| 4410 |
int64_t ne01,
|
| 4411 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 4412 |
int64_t ne10,
|
| 4413 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 4414 |
int64_t ne0,
|
| 4415 |
int64_t ne1,
|
| 4416 |
uint r2,
|
|
@@ -4426,15 +4539,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
| 4426 |
const int im = tgpig.z;
|
| 4427 |
|
| 4428 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 4429 |
-
const int ib_row = first_row * nb;
|
| 4430 |
|
| 4431 |
const uint i12 = im%ne12;
|
| 4432 |
const uint i13 = im/ne12;
|
| 4433 |
|
| 4434 |
-
const uint offset0 = (i12/r2)*
|
|
|
|
| 4435 |
|
| 4436 |
-
device const block_iq3_xxs * x = (device const block_iq3_xxs *)
|
| 4437 |
-
device const float * y = (device const float *)
|
| 4438 |
|
| 4439 |
float yl[32];
|
| 4440 |
float sumf[N_DST]={0.f}, all_sum;
|
|
@@ -4489,9 +4602,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
| 4489 |
}
|
| 4490 |
sumf[row] += d * (sum[0] + sum[1]);
|
| 4491 |
|
| 4492 |
-
dh +=
|
| 4493 |
-
q3 +=
|
| 4494 |
-
gas +=
|
| 4495 |
}
|
| 4496 |
|
| 4497 |
y4 += 32 * 32;
|
|
@@ -4516,12 +4629,14 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
| 4516 |
constant uint64_t & nb00,
|
| 4517 |
constant uint64_t & nb01,
|
| 4518 |
constant uint64_t & nb02,
|
|
|
|
| 4519 |
constant int64_t & ne10,
|
| 4520 |
constant int64_t & ne11,
|
| 4521 |
constant int64_t & ne12,
|
| 4522 |
constant uint64_t & nb10,
|
| 4523 |
constant uint64_t & nb11,
|
| 4524 |
constant uint64_t & nb12,
|
|
|
|
| 4525 |
constant int64_t & ne0,
|
| 4526 |
constant int64_t & ne1,
|
| 4527 |
constant uint & r2,
|
|
@@ -4531,7 +4646,7 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
| 4531 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4532 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4533 |
|
| 4534 |
-
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4535 |
}
|
| 4536 |
|
| 4537 |
void kernel_mul_mv_iq3_s_f32_impl(
|
|
@@ -4541,8 +4656,14 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 4541 |
int64_t ne00,
|
| 4542 |
int64_t ne01,
|
| 4543 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 4544 |
int64_t ne10,
|
| 4545 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 4546 |
int64_t ne0,
|
| 4547 |
int64_t ne1,
|
| 4548 |
uint r2,
|
|
@@ -4558,15 +4679,15 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 4558 |
const int im = tgpig.z;
|
| 4559 |
|
| 4560 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 4561 |
-
const int ib_row = first_row * nb;
|
| 4562 |
|
| 4563 |
const uint i12 = im%ne12;
|
| 4564 |
const uint i13 = im/ne12;
|
| 4565 |
|
| 4566 |
-
const uint offset0 = (i12/r2)*
|
|
|
|
| 4567 |
|
| 4568 |
-
device const block_iq3_s * x = (device const block_iq3_s *)
|
| 4569 |
-
device const float * y = (device const float *)
|
| 4570 |
|
| 4571 |
float yl[32];
|
| 4572 |
float sumf[N_DST]={0.f}, all_sum;
|
|
@@ -4619,11 +4740,11 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 4619 |
}
|
| 4620 |
sumf[row] += d * (sum[0] + sum[1]);
|
| 4621 |
|
| 4622 |
-
dh
|
| 4623 |
-
qs
|
| 4624 |
-
qh
|
| 4625 |
-
sc
|
| 4626 |
-
signs +=
|
| 4627 |
}
|
| 4628 |
|
| 4629 |
y4 += 32 * 32;
|
|
@@ -4648,12 +4769,14 @@ kernel void kernel_mul_mv_iq3_s_f32(
|
|
| 4648 |
constant uint64_t & nb00,
|
| 4649 |
constant uint64_t & nb01,
|
| 4650 |
constant uint64_t & nb02,
|
|
|
|
| 4651 |
constant int64_t & ne10,
|
| 4652 |
constant int64_t & ne11,
|
| 4653 |
constant int64_t & ne12,
|
| 4654 |
constant uint64_t & nb10,
|
| 4655 |
constant uint64_t & nb11,
|
| 4656 |
constant uint64_t & nb12,
|
|
|
|
| 4657 |
constant int64_t & ne0,
|
| 4658 |
constant int64_t & ne1,
|
| 4659 |
constant uint & r2,
|
|
@@ -4663,7 +4786,7 @@ kernel void kernel_mul_mv_iq3_s_f32(
|
|
| 4663 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4664 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4665 |
|
| 4666 |
-
kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4667 |
}
|
| 4668 |
|
| 4669 |
void kernel_mul_mv_iq2_s_f32_impl(
|
|
@@ -4673,8 +4796,14 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
| 4673 |
int64_t ne00,
|
| 4674 |
int64_t ne01,
|
| 4675 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 4676 |
int64_t ne10,
|
| 4677 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 4678 |
int64_t ne0,
|
| 4679 |
int64_t ne1,
|
| 4680 |
uint r2,
|
|
@@ -4690,15 +4819,15 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
| 4690 |
const int im = tgpig.z;
|
| 4691 |
|
| 4692 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 4693 |
-
const int ib_row = first_row * nb;
|
| 4694 |
|
| 4695 |
const uint i12 = im%ne12;
|
| 4696 |
const uint i13 = im/ne12;
|
| 4697 |
|
| 4698 |
-
const uint offset0 = (i12/r2)*
|
|
|
|
| 4699 |
|
| 4700 |
-
device const block_iq2_s * x = (device const block_iq2_s *)
|
| 4701 |
-
device const float * y = (device const float *)
|
| 4702 |
|
| 4703 |
float yl[32];
|
| 4704 |
float sumf[N_DST]={0.f}, all_sum;
|
|
@@ -4752,11 +4881,11 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
| 4752 |
}
|
| 4753 |
sumf[row] += d1 * sum[0] + d2 * sum[1];
|
| 4754 |
|
| 4755 |
-
dh
|
| 4756 |
-
qs
|
| 4757 |
-
qh
|
| 4758 |
-
sc
|
| 4759 |
-
signs +=
|
| 4760 |
}
|
| 4761 |
|
| 4762 |
y4 += 32 * 32;
|
|
@@ -4781,12 +4910,14 @@ kernel void kernel_mul_mv_iq2_s_f32(
|
|
| 4781 |
constant uint64_t & nb00,
|
| 4782 |
constant uint64_t & nb01,
|
| 4783 |
constant uint64_t & nb02,
|
|
|
|
| 4784 |
constant int64_t & ne10,
|
| 4785 |
constant int64_t & ne11,
|
| 4786 |
constant int64_t & ne12,
|
| 4787 |
constant uint64_t & nb10,
|
| 4788 |
constant uint64_t & nb11,
|
| 4789 |
constant uint64_t & nb12,
|
|
|
|
| 4790 |
constant int64_t & ne0,
|
| 4791 |
constant int64_t & ne1,
|
| 4792 |
constant uint & r2,
|
|
@@ -4796,7 +4927,7 @@ kernel void kernel_mul_mv_iq2_s_f32(
|
|
| 4796 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4797 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4798 |
|
| 4799 |
-
kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4800 |
}
|
| 4801 |
|
| 4802 |
void kernel_mul_mv_iq1_s_f32_impl(
|
|
@@ -4806,8 +4937,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 4806 |
int64_t ne00,
|
| 4807 |
int64_t ne01,
|
| 4808 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 4809 |
int64_t ne10,
|
| 4810 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 4811 |
int64_t ne0,
|
| 4812 |
int64_t ne1,
|
| 4813 |
uint r2,
|
|
@@ -4823,14 +4960,15 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 4823 |
const int im = tgpig.z;
|
| 4824 |
|
| 4825 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 4826 |
-
const int ib_row = first_row * nb;
|
| 4827 |
|
| 4828 |
const uint i12 = im%ne12;
|
| 4829 |
const uint i13 = im/ne12;
|
| 4830 |
|
| 4831 |
-
const uint offset0 = (i12/r2)*
|
| 4832 |
-
|
| 4833 |
-
|
|
|
|
|
|
|
| 4834 |
|
| 4835 |
float yl[32];
|
| 4836 |
float sumf[N_DST]={0.f}, all_sum;
|
|
@@ -4873,9 +5011,9 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 4873 |
}
|
| 4874 |
sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
|
| 4875 |
|
| 4876 |
-
dh +=
|
| 4877 |
-
qs +=
|
| 4878 |
-
qh +=
|
| 4879 |
}
|
| 4880 |
|
| 4881 |
y4 += 32 * 32;
|
|
@@ -4896,8 +5034,14 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
| 4896 |
int64_t ne00,
|
| 4897 |
int64_t ne01,
|
| 4898 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 4899 |
int64_t ne10,
|
| 4900 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 4901 |
int64_t ne0,
|
| 4902 |
int64_t ne1,
|
| 4903 |
uint r2,
|
|
@@ -4913,14 +5057,15 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
| 4913 |
const int im = tgpig.z;
|
| 4914 |
|
| 4915 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 4916 |
-
const int ib_row = first_row * nb;
|
| 4917 |
|
| 4918 |
const uint i12 = im%ne12;
|
| 4919 |
const uint i13 = im/ne12;
|
| 4920 |
|
| 4921 |
-
const uint offset0 = (i12/r2)*
|
| 4922 |
-
|
| 4923 |
-
|
|
|
|
|
|
|
| 4924 |
|
| 4925 |
float yl[32];
|
| 4926 |
float sumf[N_DST]={0.f}, all_sum;
|
|
@@ -4972,9 +5117,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
| 4972 |
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
|
| 4973 |
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
|
| 4974 |
|
| 4975 |
-
sc +=
|
| 4976 |
-
qs +=
|
| 4977 |
-
qh +=
|
| 4978 |
}
|
| 4979 |
|
| 4980 |
y4 += 32 * 32;
|
|
@@ -4995,8 +5140,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
| 4995 |
int64_t ne00,
|
| 4996 |
int64_t ne01,
|
| 4997 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 4998 |
int64_t ne10,
|
| 4999 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 5000 |
int64_t ne0,
|
| 5001 |
int64_t ne1,
|
| 5002 |
uint r2,
|
|
@@ -5012,14 +5163,15 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
| 5012 |
const int r1 = tgpig.y;
|
| 5013 |
const int im = tgpig.z;
|
| 5014 |
const int first_row = (r0 * 2 + sgitg) * 2;
|
| 5015 |
-
const int ib_row = first_row * nb;
|
| 5016 |
|
| 5017 |
const uint i12 = im%ne12;
|
| 5018 |
const uint i13 = im/ne12;
|
| 5019 |
|
| 5020 |
-
const uint offset0 = (i12/r2)*
|
| 5021 |
-
|
| 5022 |
-
|
|
|
|
|
|
|
| 5023 |
|
| 5024 |
const int ix = tiisg/2; // 0...15
|
| 5025 |
const int it = tiisg%2; // 0 or 1
|
|
@@ -5089,8 +5241,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
| 5089 |
int64_t ne00,
|
| 5090 |
int64_t ne01,
|
| 5091 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 5092 |
int64_t ne10,
|
| 5093 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 5094 |
int64_t ne0,
|
| 5095 |
int64_t ne1,
|
| 5096 |
uint r2,
|
|
@@ -5106,14 +5264,15 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
| 5106 |
const int r1 = tgpig.y;
|
| 5107 |
const int im = tgpig.z;
|
| 5108 |
const int first_row = (r0 * 2 + sgitg) * 2;
|
| 5109 |
-
const int ib_row = first_row * nb;
|
| 5110 |
|
| 5111 |
const uint i12 = im%ne12;
|
| 5112 |
const uint i13 = im/ne12;
|
| 5113 |
|
| 5114 |
-
const uint offset0 = (i12/r2)*
|
| 5115 |
-
|
| 5116 |
-
|
|
|
|
|
|
|
| 5117 |
|
| 5118 |
const int ix = tiisg/16; // 0 or 1
|
| 5119 |
const int it = tiisg%16; // 0...15
|
|
@@ -5188,12 +5347,14 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
| 5188 |
constant uint64_t & nb00,
|
| 5189 |
constant uint64_t & nb01,
|
| 5190 |
constant uint64_t & nb02,
|
|
|
|
| 5191 |
constant int64_t & ne10,
|
| 5192 |
constant int64_t & ne11,
|
| 5193 |
constant int64_t & ne12,
|
| 5194 |
constant uint64_t & nb10,
|
| 5195 |
constant uint64_t & nb11,
|
| 5196 |
constant uint64_t & nb12,
|
|
|
|
| 5197 |
constant int64_t & ne0,
|
| 5198 |
constant int64_t & ne1,
|
| 5199 |
constant uint & r2,
|
|
@@ -5202,7 +5363,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
| 5202 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 5203 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5204 |
|
| 5205 |
-
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 5206 |
}
|
| 5207 |
|
| 5208 |
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
|
@@ -5216,12 +5377,14 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
|
| 5216 |
constant uint64_t & nb00,
|
| 5217 |
constant uint64_t & nb01,
|
| 5218 |
constant uint64_t & nb02,
|
|
|
|
| 5219 |
constant int64_t & ne10,
|
| 5220 |
constant int64_t & ne11,
|
| 5221 |
constant int64_t & ne12,
|
| 5222 |
constant uint64_t & nb10,
|
| 5223 |
constant uint64_t & nb11,
|
| 5224 |
constant uint64_t & nb12,
|
|
|
|
| 5225 |
constant int64_t & ne0,
|
| 5226 |
constant int64_t & ne1,
|
| 5227 |
constant uint & r2,
|
|
@@ -5230,7 +5393,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
|
| 5230 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 5231 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5232 |
|
| 5233 |
-
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 5234 |
}
|
| 5235 |
|
| 5236 |
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
|
@@ -5244,12 +5407,14 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
|
| 5244 |
constant uint64_t & nb00,
|
| 5245 |
constant uint64_t & nb01,
|
| 5246 |
constant uint64_t & nb02,
|
|
|
|
| 5247 |
constant int64_t & ne10,
|
| 5248 |
constant int64_t & ne11,
|
| 5249 |
constant int64_t & ne12,
|
| 5250 |
constant uint64_t & nb10,
|
| 5251 |
constant uint64_t & nb11,
|
| 5252 |
constant uint64_t & nb12,
|
|
|
|
| 5253 |
constant int64_t & ne0,
|
| 5254 |
constant int64_t & ne1,
|
| 5255 |
constant uint & r2,
|
|
@@ -5259,7 +5424,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
|
| 5259 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 5260 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5261 |
|
| 5262 |
-
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 5263 |
}
|
| 5264 |
|
| 5265 |
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
|
|
@@ -5273,12 +5438,14 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
| 5273 |
constant uint64_t & nb00,
|
| 5274 |
constant uint64_t & nb01,
|
| 5275 |
constant uint64_t & nb02,
|
|
|
|
| 5276 |
constant int64_t & ne10,
|
| 5277 |
constant int64_t & ne11,
|
| 5278 |
constant int64_t & ne12,
|
| 5279 |
constant uint64_t & nb10,
|
| 5280 |
constant uint64_t & nb11,
|
| 5281 |
constant uint64_t & nb12,
|
|
|
|
| 5282 |
constant int64_t & ne0,
|
| 5283 |
constant int64_t & ne1,
|
| 5284 |
constant uint & r2,
|
|
@@ -5288,7 +5455,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
| 5288 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 5289 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5290 |
|
| 5291 |
-
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 5292 |
}
|
| 5293 |
|
| 5294 |
//============================= templates and their specializations =============================
|
|
@@ -5833,10 +6000,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
| 5833 |
constant int64_t & ne02,
|
| 5834 |
constant uint64_t & nb01,
|
| 5835 |
constant uint64_t & nb02,
|
|
|
|
| 5836 |
constant int64_t & ne12,
|
| 5837 |
constant uint64_t & nb10,
|
| 5838 |
constant uint64_t & nb11,
|
| 5839 |
constant uint64_t & nb12,
|
|
|
|
| 5840 |
constant int64_t & ne0,
|
| 5841 |
constant int64_t & ne1,
|
| 5842 |
constant uint & r2,
|
|
@@ -5873,12 +6042,13 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
| 5873 |
const uint i12 = im%ne12;
|
| 5874 |
const uint i13 = im/ne12;
|
| 5875 |
|
| 5876 |
-
uint offset0 = (i12/r2)*nb02 + (i13/r3)*
|
| 5877 |
ushort offset1 = il/nl;
|
| 5878 |
|
| 5879 |
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
| 5880 |
device const float * y = (device const float *)(src1
|
| 5881 |
-
+
|
|
|
|
| 5882 |
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
|
| 5883 |
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
| 5884 |
|
|
@@ -6257,12 +6427,14 @@ typedef void (kernel_mul_mv_impl_t)(
|
|
| 6257 |
uint64_t nb00,
|
| 6258 |
uint64_t nb01,
|
| 6259 |
uint64_t nb02,
|
|
|
|
| 6260 |
int64_t ne10,
|
| 6261 |
int64_t ne11,
|
| 6262 |
int64_t ne12,
|
| 6263 |
uint64_t nb10,
|
| 6264 |
uint64_t nb11,
|
| 6265 |
uint64_t nb12,
|
|
|
|
| 6266 |
int64_t ne0,
|
| 6267 |
int64_t ne1,
|
| 6268 |
uint r2,
|
|
@@ -6277,8 +6449,14 @@ typedef void (kernel_mul_mv2_impl_t)(
|
|
| 6277 |
int64_t ne00,
|
| 6278 |
int64_t ne01,
|
| 6279 |
int64_t ne02,
|
|
|
|
|
|
|
|
|
|
| 6280 |
int64_t ne10,
|
| 6281 |
int64_t ne12,
|
|
|
|
|
|
|
|
|
|
| 6282 |
int64_t ne0,
|
| 6283 |
int64_t ne1,
|
| 6284 |
uint r2,
|
|
@@ -6299,6 +6477,7 @@ void mmv_fn(
|
|
| 6299 |
uint64_t nb00,
|
| 6300 |
uint64_t nb01,
|
| 6301 |
uint64_t nb02,
|
|
|
|
| 6302 |
int64_t ne10,
|
| 6303 |
int64_t ne11,
|
| 6304 |
int64_t ne12,
|
|
@@ -6306,6 +6485,7 @@ void mmv_fn(
|
|
| 6306 |
uint64_t nb10,
|
| 6307 |
uint64_t nb11,
|
| 6308 |
uint64_t nb12,
|
|
|
|
| 6309 |
int64_t ne0,
|
| 6310 |
int64_t ne1,
|
| 6311 |
uint64_t nb1,
|
|
@@ -6316,7 +6496,7 @@ void mmv_fn(
|
|
| 6316 |
uint tiitg,
|
| 6317 |
uint tiisg,
|
| 6318 |
uint sgitg) {
|
| 6319 |
-
impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
|
| 6320 |
}
|
| 6321 |
|
| 6322 |
template<kernel_mul_mv2_impl_t impl_fn>
|
|
@@ -6330,6 +6510,7 @@ void mmv_fn(
|
|
| 6330 |
uint64_t nb00,
|
| 6331 |
uint64_t nb01,
|
| 6332 |
uint64_t nb02,
|
|
|
|
| 6333 |
int64_t ne10,
|
| 6334 |
int64_t ne11,
|
| 6335 |
int64_t ne12,
|
|
@@ -6337,6 +6518,7 @@ void mmv_fn(
|
|
| 6337 |
uint64_t nb10,
|
| 6338 |
uint64_t nb11,
|
| 6339 |
uint64_t nb12,
|
|
|
|
| 6340 |
int64_t ne0,
|
| 6341 |
int64_t ne1,
|
| 6342 |
uint64_t nb1,
|
|
@@ -6347,7 +6529,7 @@ void mmv_fn(
|
|
| 6347 |
uint tiitg,
|
| 6348 |
uint tiisg,
|
| 6349 |
uint sgitg) {
|
| 6350 |
-
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
|
| 6351 |
}
|
| 6352 |
|
| 6353 |
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
|
|
@@ -6396,8 +6578,8 @@ kernel void kernel_mul_mv_id(
|
|
| 6396 |
const int64_t i2 = i12;
|
| 6397 |
|
| 6398 |
device const char * src0_cur = src0s + i02*nb02;
|
| 6399 |
-
device const char * src1_cur = src1
|
| 6400 |
-
device float *
|
| 6401 |
|
| 6402 |
impl_fn(
|
| 6403 |
/* src0 */ src0_cur,
|
|
@@ -6405,19 +6587,21 @@ kernel void kernel_mul_mv_id(
|
|
| 6405 |
/* dst */ dst_cur,
|
| 6406 |
/* ne00 */ ne00,
|
| 6407 |
/* ne01 */ ne01,
|
| 6408 |
-
/* ne02 */ 1
|
| 6409 |
/* nb00 */ nb00,
|
| 6410 |
/* nb01 */ nb01,
|
| 6411 |
/* nb02 */ nb02,
|
|
|
|
| 6412 |
/* ne10 */ ne10,
|
| 6413 |
-
/* ne11 */ 1
|
| 6414 |
-
/* ne12 */ 1
|
| 6415 |
-
/* ne13 */ 1
|
| 6416 |
/* nb10 */ nb10,
|
| 6417 |
/* nb11 */ nb11,
|
| 6418 |
/* nb12 */ nb12,
|
|
|
|
| 6419 |
/* ne0 */ ne0,
|
| 6420 |
-
/* ne1 */ 1
|
| 6421 |
/* nb1 */ nb1,
|
| 6422 |
/* r2 */ 1,
|
| 6423 |
/* r3 */ 1,
|
|
|
|
| 777 |
const int64_t i3 = tgpig.z;
|
| 778 |
|
| 779 |
const int64_t nc = ne10;
|
| 780 |
+
//const int64_t ncs = ne00;
|
| 781 |
+
//const int64_t nr = ne01;
|
| 782 |
+
//const int64_t n_t = ne1;
|
| 783 |
+
//const int64_t n_s = ne2;
|
| 784 |
|
| 785 |
device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
|
| 786 |
device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
|
|
|
|
| 834 |
const int64_t i3 = tgpig.y;
|
| 835 |
|
| 836 |
const int64_t nc = d_state;
|
| 837 |
+
//const int64_t nr = d_inner;
|
| 838 |
const int64_t n_t = n_seq_tokens;
|
| 839 |
+
//const int64_t n_s = n_seqs;
|
| 840 |
|
| 841 |
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
| 842 |
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
|
|
|
|
| 1064 |
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
| 1065 |
float d = qb_curr->d;
|
| 1066 |
|
| 1067 |
+
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
|
| 1068 |
|
| 1069 |
+
device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
|
| 1070 |
|
| 1071 |
+
for (int i = 0; i < 8; i += 2) {
|
| 1072 |
+
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
|
| 1073 |
+
acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
| 1074 |
+
acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
|
| 1075 |
+
acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
|
| 1076 |
}
|
| 1077 |
+
|
| 1078 |
+
return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
|
| 1079 |
}
|
| 1080 |
|
| 1081 |
// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
|
|
| 1086 |
float d = qb_curr->d;
|
| 1087 |
float m = qb_curr->m;
|
| 1088 |
|
| 1089 |
+
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
|
| 1090 |
|
| 1091 |
+
device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
|
| 1092 |
|
| 1093 |
for (int i = 0; i < 8; i+=2) {
|
| 1094 |
+
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
|
| 1095 |
+
acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
| 1096 |
+
acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
|
| 1097 |
+
acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
|
| 1098 |
}
|
| 1099 |
+
|
| 1100 |
+
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
|
| 1101 |
}
|
| 1102 |
|
| 1103 |
// function for calculating the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
|
|
|
|
| 1107 |
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
|
| 1108 |
float d = qb_curr->d;
|
| 1109 |
|
| 1110 |
+
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
|
| 1111 |
|
| 1112 |
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
|
| 1113 |
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
| 1114 |
|
| 1115 |
for (int i = 0; i < 8; i+=2) {
|
| 1116 |
+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
|
| 1117 |
+
acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
| 1118 |
+
acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
|
| 1119 |
+
acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
| 1120 |
}
|
| 1121 |
+
|
| 1122 |
+
return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
|
| 1123 |
}
|
| 1124 |
|
| 1125 |
// function for calculating the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
|
|
|
|
| 1130 |
float d = qb_curr->d;
|
| 1131 |
float m = qb_curr->m;
|
| 1132 |
|
| 1133 |
+
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
|
| 1134 |
|
| 1135 |
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
|
| 1136 |
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
| 1137 |
|
| 1138 |
for (int i = 0; i < 8; i+=2) {
|
| 1139 |
+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
|
| 1140 |
+
acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
| 1141 |
+
acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
|
| 1142 |
+
acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
| 1143 |
}
|
| 1144 |
+
|
| 1145 |
+
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
|
| 1146 |
}
|
| 1147 |
|
| 1148 |
// putting them in the kernel causes a significant performance penalty
|
|
|
|
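The four block_q_n_dot_y overloads above avoid per-lane bit shifts: the caller pre-scales the cached y values (yl[i + 1] by 1/256, yl[i + 8] by 1/16, yl[i + 9] by 1/4096), so multiplying by the raw masked lanes (q & 0x000F, 0x0F00, 0x00F0, 0xF000) already yields y times the 4-bit quant, and the sumy * -8.f / -16.f terms fold in the implicit offset of the q4_0 / q5_0 formats, since sum_i y_i*(q_i - 8) = sum_i y_i*q_i - 8*sumy. A minimal host-side sketch of that identity (plain C++ with made-up values, not kernel code):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint16_t q  = 0x3A05;   // packed 4-bit quants: 5 in the low nibble, 0xA at bits 8..11
        const float    y0 = 0.25f;    // value paired with the low nibble
        const float    y1 = 1.50f;    // value paired with the nibble at bits 8..11

        // shift-based reference: extract each nibble, then multiply
        const float ref = y0 * float( q       & 0xF)
                        + y1 * float((q >> 8) & 0xF);

        // mask-based form used by the kernel: pre-scale y instead of shifting q
        const float acc = y0            * float(q & 0x000F)
                        + (y1 / 256.0f) * float(q & 0x0F00);

        assert(ref == acc); // exact here: small integers times powers of two
        return 0;
    }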
| 1160 |
int64_t ne00,
|
| 1161 |
int64_t ne01,
|
| 1162 |
int64_t ne02,
|
| 1163 |
+
uint64_t nb01,
|
| 1164 |
+
uint64_t nb02,
|
| 1165 |
+
uint64_t nb03,
|
| 1166 |
int64_t ne10,
|
| 1167 |
int64_t ne12,
|
| 1168 |
+
uint64_t nb11,
|
| 1169 |
+
uint64_t nb12,
|
| 1170 |
+
uint64_t nb13,
|
| 1171 |
int64_t ne0,
|
| 1172 |
int64_t ne1,
|
| 1173 |
uint r2,
|
| 1174 |
uint r3,
|
| 1175 |
threadgroup int8_t * shared_values,
|
| 1176 |
+
uint3 tgpig,
|
| 1177 |
+
uint tiisg,
|
| 1178 |
+
uint sgitg) {
|
| 1179 |
const int nb = ne00/QK4_0;
|
| 1180 |
|
| 1181 |
const int r0 = tgpig.x;
|
|
|
|
| 1187 |
const uint i12 = im%ne12;
|
| 1188 |
const uint i13 = im/ne12;
|
| 1189 |
|
| 1190 |
+
//const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 1191 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 1192 |
|
| 1193 |
+
//device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0);
|
| 1194 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 1195 |
+
|
| 1196 |
+
// pointers to src0 rows
|
| 1197 |
+
device const block_q_type * ax[nr];
|
| 1198 |
+
for (int row = 0; row < nr; ++row) {
|
| 1199 |
+
const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 1200 |
+
|
| 1201 |
+
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
|
| 1202 |
+
}
|
| 1203 |
|
| 1204 |
float yl[16]; // src1 vector cache
|
| 1205 |
float sumf[nr] = {0.f};
|
|
|
|
| 1211 |
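The row loop above is the heart of the permuted-src0 support in mul_vec_q_n_f32_impl: rather than taking one base pointer and assuming the nr rows handled by a SIMD group sit back to back, every row gets its own base pointer built from the byte stride nb01 (plus the nb02/nb03 batch strides divided by the broadcast ratios r2/r3), which is the "use nb01 directly for row steps" part of the commit. A stand-alone sketch of the same addressing in plain C++ (hypothetical helper, Metal address spaces dropped):

    #include <cstdint>

    // build per-row base pointers into a possibly permuted src0 tensor;
    // nb01/nb02/nb03 are byte strides, r2/r3 are batch broadcast ratios
    template <typename BlockT, int NR>
    void src0_row_pointers(const char * src0, int first_row,
                           uint64_t nb01, uint64_t nb02, uint64_t nb03,
                           uint32_t i12, uint32_t i13, uint32_t r2, uint32_t r3,
                           const BlockT * (&ax)[NR]) {
        for (int row = 0; row < NR; ++row) {
            const uint64_t offset0 = (uint64_t)(first_row + row)*nb01
                                   + (uint64_t)(i12/r2)*nb02
                                   + (uint64_t)(i13/r3)*nb03;
            ax[row] = reinterpret_cast<const BlockT *>(src0 + offset0);
        }
    }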
|
| 1212 |
// each thread in a SIMD group deals with half a block.
|
| 1213 |
for (int ib = ix; ib < nb; ib += nw/2) {
|
| 1214 |
+
float sumy[2] = { 0.f, 0.f };
|
| 1215 |
+
|
| 1216 |
+
#pragma unroll
|
| 1217 |
for (int i = 0; i < 8; i += 2) {
|
| 1218 |
+
sumy[0] += yb[i + 0] + yb[i + 1];
|
| 1219 |
+
yl[i + 0] = yb[i + 0];
|
| 1220 |
+
yl[i + 1] = yb[i + 1]/256.f;
|
| 1221 |
|
| 1222 |
+
sumy[1] += yb[i + 16] + yb[i + 17];
|
| 1223 |
+
yl[i + 8] = yb[i + 16]/16.f;
|
| 1224 |
+
yl[i + 9] = yb[i + 17]/4096.f;
|
| 1225 |
}
|
| 1226 |
|
| 1227 |
+
#pragma unroll
|
| 1228 |
for (int row = 0; row < nr; row++) {
|
| 1229 |
+
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
|
| 1230 |
}
|
| 1231 |
|
| 1232 |
yb += QK4_0 * 16;
|
|
|
|
| 1250 |
constant uint64_t & nb00,
|
| 1251 |
constant uint64_t & nb01,
|
| 1252 |
constant uint64_t & nb02,
|
| 1253 |
+
constant uint64_t & nb03,
|
| 1254 |
constant int64_t & ne10,
|
| 1255 |
constant int64_t & ne11,
|
| 1256 |
constant int64_t & ne12,
|
| 1257 |
constant uint64_t & nb10,
|
| 1258 |
constant uint64_t & nb11,
|
| 1259 |
constant uint64_t & nb12,
|
| 1260 |
+
constant uint64_t & nb13,
|
| 1261 |
constant int64_t & ne0,
|
| 1262 |
constant int64_t & ne1,
|
| 1263 |
constant uint & r2,
|
|
|
|
| 1265 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1266 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1267 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1268 |
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1269 |
}
|
| 1270 |
|
| 1271 |
kernel void kernel_mul_mv_q4_1_f32(
|
|
|
|
| 1278 |
constant uint64_t & nb00,
|
| 1279 |
constant uint64_t & nb01,
|
| 1280 |
constant uint64_t & nb02,
|
| 1281 |
+
constant uint64_t & nb03,
|
| 1282 |
constant int64_t & ne10,
|
| 1283 |
constant int64_t & ne11,
|
| 1284 |
constant int64_t & ne12,
|
| 1285 |
constant uint64_t & nb10,
|
| 1286 |
constant uint64_t & nb11,
|
| 1287 |
constant uint64_t & nb12,
|
| 1288 |
+
constant uint64_t & nb13,
|
| 1289 |
constant int64_t & ne0,
|
| 1290 |
constant int64_t & ne1,
|
| 1291 |
constant uint & r2,
|
|
|
|
| 1293 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1294 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1295 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1296 |
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1297 |
}
|
| 1298 |
|
| 1299 |
kernel void kernel_mul_mv_q5_0_f32(
|
|
|
|
| 1306 |
constant uint64_t & nb00,
|
| 1307 |
constant uint64_t & nb01,
|
| 1308 |
constant uint64_t & nb02,
|
| 1309 |
+
constant uint64_t & nb03,
|
| 1310 |
constant int64_t & ne10,
|
| 1311 |
constant int64_t & ne11,
|
| 1312 |
constant int64_t & ne12,
|
| 1313 |
constant uint64_t & nb10,
|
| 1314 |
constant uint64_t & nb11,
|
| 1315 |
constant uint64_t & nb12,
|
| 1316 |
+
constant uint64_t & nb13,
|
| 1317 |
constant int64_t & ne0,
|
| 1318 |
constant int64_t & ne1,
|
| 1319 |
constant uint & r2,
|
|
|
|
| 1321 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1322 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1323 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1324 |
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1325 |
}
|
| 1326 |
|
| 1327 |
kernel void kernel_mul_mv_q5_1_f32(
|
|
|
|
| 1334 |
constant uint64_t & nb00,
|
| 1335 |
constant uint64_t & nb01,
|
| 1336 |
constant uint64_t & nb02,
|
| 1337 |
+
constant uint64_t & nb03,
|
| 1338 |
constant int64_t & ne10,
|
| 1339 |
constant int64_t & ne11,
|
| 1340 |
constant int64_t & ne12,
|
| 1341 |
constant uint64_t & nb10,
|
| 1342 |
constant uint64_t & nb11,
|
| 1343 |
constant uint64_t & nb12,
|
| 1344 |
+
constant uint64_t & nb13,
|
| 1345 |
constant int64_t & ne0,
|
| 1346 |
constant int64_t & ne1,
|
| 1347 |
constant uint & r2,
|
|
|
|
| 1349 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1350 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1351 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1352 |
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1353 |
}
|
| 1354 |
|
| 1355 |
|
|
|
|
| 1362 |
int64_t ne00,
|
| 1363 |
int64_t ne01,
|
| 1364 |
int64_t ne02,
|
| 1365 |
+
uint64_t nb01,
|
| 1366 |
+
uint64_t nb02,
|
| 1367 |
+
uint64_t nb03,
|
| 1368 |
int64_t ne10,
|
| 1369 |
int64_t ne12,
|
| 1370 |
+
uint64_t nb11,
|
| 1371 |
+
uint64_t nb12,
|
| 1372 |
+
uint64_t nb13,
|
| 1373 |
int64_t ne0,
|
| 1374 |
int64_t ne1,
|
| 1375 |
uint r2,
|
|
|
|
| 1392 |
const uint i12 = im%ne12;
|
| 1393 |
const uint i13 = im/ne12;
|
| 1394 |
|
| 1395 |
+
//const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 1396 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 1397 |
|
| 1398 |
+
//device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0);
|
| 1399 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 1400 |
+
|
| 1401 |
+
// pointers to src0 rows
|
| 1402 |
+
device const block_q8_0 * ax[nr];
|
| 1403 |
+
for (int row = 0; row < nr; ++row) {
|
| 1404 |
+
const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 1405 |
+
|
| 1406 |
+
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
|
| 1407 |
+
}
|
| 1408 |
|
| 1409 |
float yl[NB_Q8_0];
|
| 1410 |
float sumf[nr]={0.f};
|
|
|
|
| 1421 |
}
|
| 1422 |
|
| 1423 |
for (int row = 0; row < nr; row++) {
|
| 1424 |
+
device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il;
|
| 1425 |
float sumq = 0.f;
|
| 1426 |
for (int iq = 0; iq < NB_Q8_0; ++iq) {
|
| 1427 |
sumq += qs[iq] * yl[iq];
|
| 1428 |
}
|
| 1429 |
+
sumf[row] += sumq*ax[row][ib].d;
|
| 1430 |
}
|
| 1431 |
|
| 1432 |
yb += NB_Q8_0 * nw;
|
|
|
|
| 1451 |
constant uint64_t & nb00,
|
| 1452 |
constant uint64_t & nb01,
|
| 1453 |
constant uint64_t & nb02,
|
| 1454 |
+
constant uint64_t & nb03,
|
| 1455 |
constant int64_t & ne10,
|
| 1456 |
constant int64_t & ne11,
|
| 1457 |
constant int64_t & ne12,
|
| 1458 |
constant uint64_t & nb10,
|
| 1459 |
constant uint64_t & nb11,
|
| 1460 |
constant uint64_t & nb12,
|
| 1461 |
+
constant uint64_t & nb13,
|
| 1462 |
constant int64_t & ne0,
|
| 1463 |
constant int64_t & ne1,
|
| 1464 |
constant uint & r2,
|
|
|
|
| 1466 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1467 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 1468 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 1469 |
+
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
|
| 1470 |
}
|
| 1471 |
|
| 1472 |
#define N_MV_T_T 4
|
|
|
|
| 1482 |
uint64_t nb00,
|
| 1483 |
uint64_t nb01,
|
| 1484 |
uint64_t nb02,
|
| 1485 |
+
uint64_t nb03,
|
| 1486 |
int64_t ne10,
|
| 1487 |
int64_t ne11,
|
| 1488 |
int64_t ne12,
|
| 1489 |
uint64_t nb10,
|
| 1490 |
uint64_t nb11,
|
| 1491 |
uint64_t nb12,
|
| 1492 |
+
uint64_t nb13,
|
| 1493 |
int64_t ne0,
|
| 1494 |
int64_t ne1,
|
| 1495 |
uint r2,
|
|
|
|
| 1503 |
const uint i12 = im%ne12;
|
| 1504 |
const uint i13 = im/ne12;
|
| 1505 |
|
| 1506 |
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 1507 |
|
| 1508 |
device const T0 * x = (device const T0 *) (src0 + offset0);
|
| 1509 |
|
|
|
|
| 1514 |
break;
|
| 1515 |
}
|
| 1516 |
|
| 1517 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 1518 |
+
|
| 1519 |
+
device const T1 * y = (device const T1 *) (src1 + offset1);
|
| 1520 |
|
| 1521 |
float sumf = 0;
|
| 1522 |
for (int i = tiisg; i < ne00; i += 32) {
|
|
|
|
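The F32/F16 kernel_mul_mv_impl follows the same pattern once per thread: src0 is addressed through its own byte strides nb01/nb02/nb03 and src1 through nb11/nb12/nb13, so neither operand needs to be contiguous, and the i12/r2, i13/r3 divisions implement batch broadcasting (r2 and r3 are, as elsewhere in ggml, the host-side ratios ne12/ne02 and ne13/ne03, so several src1 batches can reuse one src0 matrix). A compact worked example of the index math with made-up shapes and strides:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t ne12 = 4, ne02 = 1;
        const uint32_t r2   = ne12/ne02;              // broadcast ratio: 4 src1 batches per src0 batch

        const uint64_t nb01 = 4096, nb02 = 4096*64;   // src0 byte strides (row, batch)
        const uint64_t nb11 = 512,  nb12 = 512*32;    // src1 byte strides (row, batch)

        const uint32_t r0 = 3, r1 = 7;                // rows picked by the threadgroup

        for (uint32_t i12 = 0; i12 < ne12; ++i12) {
            const uint64_t offset0 = r0*nb01 + (i12/r2)*nb02;   // stays in src0 batch 0
            const uint64_t offset1 = r1*nb11 +  i12    *nb12;   // walks the src1 batches
            printf("i12=%u  src0+%llu  src1+%llu\n", i12,
                   (unsigned long long) offset0, (unsigned long long) offset1);
        }
        return 0;
    }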
| 1536 |
break;
|
| 1537 |
}
|
| 1538 |
|
| 1539 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 1540 |
+
|
| 1541 |
+
device const T1 * y = (device const T1 *) (src1 + offset1);
|
| 1542 |
device const T14 * y4 = (device const T14 *) y;
|
| 1543 |
|
| 1544 |
float sumf = 0;
|
|
|
|
| 1566 |
constant uint64_t & nb00,
|
| 1567 |
constant uint64_t & nb01,
|
| 1568 |
constant uint64_t & nb02,
|
| 1569 |
+
constant uint64_t & nb03,
|
| 1570 |
constant int64_t & ne10,
|
| 1571 |
constant int64_t & ne11,
|
| 1572 |
constant int64_t & ne12,
|
| 1573 |
constant uint64_t & nb10,
|
| 1574 |
constant uint64_t & nb11,
|
| 1575 |
constant uint64_t & nb12,
|
| 1576 |
+
constant uint64_t & nb13,
|
| 1577 |
constant int64_t & ne0,
|
| 1578 |
constant int64_t & ne1,
|
| 1579 |
constant uint & r2,
|
|
|
|
| 1590 |
nb00,
|
| 1591 |
nb01,
|
| 1592 |
nb02,
|
| 1593 |
+
nb03,
|
| 1594 |
ne10,
|
| 1595 |
ne11,
|
| 1596 |
ne12,
|
| 1597 |
nb10,
|
| 1598 |
nb11,
|
| 1599 |
nb12,
|
| 1600 |
+
nb13,
|
| 1601 |
ne0,
|
| 1602 |
ne1,
|
| 1603 |
r2,
|
|
|
|
| 1623 |
constant uint64_t & nb00,
|
| 1624 |
constant uint64_t & nb01,
|
| 1625 |
constant uint64_t & nb02,
|
| 1626 |
+
constant uint64_t & nb03,
|
| 1627 |
constant int64_t & ne10,
|
| 1628 |
constant int64_t & ne11,
|
| 1629 |
constant int64_t & ne12,
|
| 1630 |
constant uint64_t & nb10,
|
| 1631 |
constant uint64_t & nb11,
|
| 1632 |
constant uint64_t & nb12,
|
| 1633 |
+
constant uint64_t & nb13,
|
| 1634 |
constant int64_t & ne0,
|
| 1635 |
constant int64_t & ne1,
|
| 1636 |
constant uint & r2,
|
|
|
|
| 1645 |
const uint i12 = im%ne12;
|
| 1646 |
const uint i13 = im/ne12;
|
| 1647 |
|
| 1648 |
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 1649 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 1650 |
|
| 1651 |
device const T * x = (device const T *) (src0 + offset0);
|
| 1652 |
+
device const float * y = (device const float *) (src1 + offset1);
|
| 1653 |
|
| 1654 |
float sumf = 0;
|
| 1655 |
if (ne00 < 128) {
|
|
|
|
| 1693 |
constant uint64_t & nb00,
|
| 1694 |
constant uint64_t & nb01,
|
| 1695 |
constant uint64_t & nb02,
|
| 1696 |
+
constant uint64_t & nb03,
|
| 1697 |
constant int64_t & ne10,
|
| 1698 |
constant int64_t & ne11,
|
| 1699 |
constant int64_t & ne12,
|
| 1700 |
constant uint64_t & nb10,
|
| 1701 |
constant uint64_t & nb11,
|
| 1702 |
constant uint64_t & nb12,
|
| 1703 |
+
constant uint64_t & nb13,
|
| 1704 |
constant int64_t & ne0,
|
| 1705 |
constant int64_t & ne1,
|
| 1706 |
constant uint & r2,
|
|
|
|
| 1715 |
const uint i12 = im%ne12;
|
| 1716 |
const uint i13 = im/ne12;
|
| 1717 |
|
| 1718 |
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 1719 |
|
| 1720 |
device const T4 * x4 = (device const T4 *) (src0 + offset0);
|
| 1721 |
|
| 1722 |
for (int r1 = 0; r1 < nrows; ++r1) {
|
| 1723 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 1724 |
+
|
| 1725 |
+
device const float4 * y4 = (device const float4 *) (src1 + offset1);
|
| 1726 |
|
| 1727 |
float sumf = 0;
|
| 1728 |
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
|
|
| 3482 |
int64_t ne00,
|
| 3483 |
int64_t ne01,
|
| 3484 |
int64_t ne02,
|
| 3485 |
+
uint64_t nb01,
|
| 3486 |
+
uint64_t nb02,
|
| 3487 |
+
uint64_t nb03,
|
| 3488 |
int64_t ne10,
|
| 3489 |
int64_t ne12,
|
| 3490 |
+
uint64_t nb11,
|
| 3491 |
+
uint64_t nb12,
|
| 3492 |
+
uint64_t nb13,
|
| 3493 |
int64_t ne0,
|
| 3494 |
int64_t ne1,
|
| 3495 |
uint r2,
|
|
|
|
| 3505 |
const int im = tgpig.z;
|
| 3506 |
|
| 3507 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
|
|
| 3508 |
|
| 3509 |
const uint i12 = im%ne12;
|
| 3510 |
const uint i13 = im/ne12;
|
| 3511 |
|
| 3512 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 3513 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 3514 |
|
| 3515 |
+
device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0);
|
| 3516 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 3517 |
|
| 3518 |
float yl[32];
|
| 3519 |
float sumf[N_DST]={0.f}, all_sum;
|
| 3520 |
|
|
|
|
|
|
|
| 3521 |
const int ix = tiisg/8; // 0...3
|
| 3522 |
const int it = tiisg%8; // 0...7
|
| 3523 |
const int iq = it/4; // 0 or 1
|
|
|
|
| 3562 |
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
|
| 3563 |
dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
|
| 3564 |
|
| 3565 |
+
qs += nb01/2;
|
| 3566 |
+
sc += nb01;
|
| 3567 |
+
dh += nb01/2;
|
| 3568 |
}
|
| 3569 |
|
| 3570 |
y4 += 4 * QK_K;
|
|
|
|
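In the K-quant kernels the per-row loop above steps its typed pointers with qs += nb01/2, sc += nb01 and dh += nb01/2 instead of with sizeof(block_q2_K): nb01 is the row stride in bytes, so dividing it by the element size of each pointer (2 for uint16_t/half, 1 for uint8_t) lands on the same field of the block one row below even when rows are permuted or padded. A tiny illustrative helper (plain C++, not kernel code):

    #include <cstdint>

    // advance a typed pointer to the same field in the next row of a tensor
    // whose rows are nb01 bytes apart
    template <typename T>
    inline const T * next_row(const T * p, uint64_t nb01) {
        return p + nb01/sizeof(T);   // nb01 is bytes; T* arithmetic moves in sizeof(T) units
    }

    // e.g. qs = next_row(qs, nb01) is equivalent to the kernel's qs += nb01/2 for uint16_t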
| 3589 |
constant uint64_t & nb00,
|
| 3590 |
constant uint64_t & nb01,
|
| 3591 |
constant uint64_t & nb02,
|
| 3592 |
+
constant uint64_t & nb03,
|
| 3593 |
constant int64_t & ne10,
|
| 3594 |
constant int64_t & ne11,
|
| 3595 |
constant int64_t & ne12,
|
| 3596 |
constant uint64_t & nb10,
|
| 3597 |
constant uint64_t & nb11,
|
| 3598 |
constant uint64_t & nb12,
|
| 3599 |
+
constant uint64_t & nb13,
|
| 3600 |
constant int64_t & ne0,
|
| 3601 |
constant int64_t & ne1,
|
| 3602 |
constant uint & r2,
|
|
|
|
| 3605 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3606 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3607 |
|
| 3608 |
+
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 3609 |
}
|
| 3610 |
|
| 3611 |
void kernel_mul_mv_q3_K_f32_impl(
|
|
|
|
| 3615 |
int64_t ne00,
|
| 3616 |
int64_t ne01,
|
| 3617 |
int64_t ne02,
|
| 3618 |
+
uint64_t nb01,
|
| 3619 |
+
uint64_t nb02,
|
| 3620 |
+
uint64_t nb03,
|
| 3621 |
int64_t ne10,
|
| 3622 |
int64_t ne12,
|
| 3623 |
+
uint64_t nb11,
|
| 3624 |
+
uint64_t nb12,
|
| 3625 |
+
uint64_t nb13,
|
| 3626 |
int64_t ne0,
|
| 3627 |
int64_t ne1,
|
| 3628 |
uint r2,
|
|
|
|
| 3643 |
const uint i12 = im%ne12;
|
| 3644 |
const uint i13 = im/ne12;
|
| 3645 |
|
| 3646 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 3647 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 3648 |
|
| 3649 |
+
device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0);
|
| 3650 |
+
device const float * yy = (device const float *) ((device char *) src1 + offset1);
|
| 3651 |
|
| 3652 |
float yl[32];
|
| 3653 |
|
|
|
|
| 3687 |
const int q_offset = 32*ip + l0;
|
| 3688 |
const int y_offset = 128*ip + 32*il + l0;
|
| 3689 |
|
|
|
|
|
|
|
| 3690 |
device const float * y1 = yy + ix*QK_K + y_offset;
|
| 3691 |
|
| 3692 |
uint32_t scales32, aux32;
|
|
|
|
| 3696 |
float sumf1[2] = {0.f};
|
| 3697 |
float sumf2[2] = {0.f};
|
| 3698 |
for (int i = ix; i < nb; i += 4) {
|
|
|
|
| 3699 |
for (int l = 0; l < 8; ++l) {
|
| 3700 |
yl[l+ 0] = y1[l+ 0];
|
| 3701 |
yl[l+ 8] = y1[l+16];
|
|
|
|
| 3709 |
device const half * dh = &x[i].d;
|
| 3710 |
|
| 3711 |
for (int row = 0; row < 2; ++row) {
|
|
|
|
| 3712 |
const float d_all = (float)dh[0];
|
| 3713 |
|
| 3714 |
scales16[0] = a[4];
|
|
|
|
| 3748 |
sumf1[row] += d1 * (scales[1] - 32);
|
| 3749 |
sumf2[row] += d2 * (scales[3] - 32);
|
| 3750 |
|
| 3751 |
+
q += nb01/2;
|
| 3752 |
+
h += nb01/2;
|
| 3753 |
+
a += nb01/2;
|
| 3754 |
+
dh += nb01/2;
|
|
|
|
| 3755 |
}
|
| 3756 |
|
| 3757 |
y1 += 4 * QK_K;
|
|
|
|
| 3758 |
}
|
| 3759 |
|
| 3760 |
for (int row = 0; row < 2; ++row) {
|
|
|
|
| 3779 |
constant uint64_t & nb00,
|
| 3780 |
constant uint64_t & nb01,
|
| 3781 |
constant uint64_t & nb02,
|
| 3782 |
+
constant uint64_t & nb03,
|
| 3783 |
constant int64_t & ne10,
|
| 3784 |
constant int64_t & ne11,
|
| 3785 |
constant int64_t & ne12,
|
| 3786 |
constant uint64_t & nb10,
|
| 3787 |
constant uint64_t & nb11,
|
| 3788 |
constant uint64_t & nb12,
|
| 3789 |
+
constant uint64_t & nb13,
|
| 3790 |
constant int64_t & ne0,
|
| 3791 |
constant int64_t & ne1,
|
| 3792 |
constant uint & r2,
|
|
|
|
| 3795 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3796 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3797 |
|
| 3798 |
+
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 3799 |
}
|
| 3800 |
|
| 3801 |
void kernel_mul_mv_q4_K_f32_impl(
|
|
|
|
| 3805 |
int64_t ne00,
|
| 3806 |
int64_t ne01,
|
| 3807 |
int64_t ne02,
|
| 3808 |
+
uint64_t nb01,
|
| 3809 |
+
uint64_t nb02,
|
| 3810 |
+
uint64_t nb03,
|
| 3811 |
int64_t ne10,
|
| 3812 |
int64_t ne12,
|
| 3813 |
+
uint64_t nb11,
|
| 3814 |
+
uint64_t nb12,
|
| 3815 |
+
uint64_t nb13,
|
| 3816 |
int64_t ne0,
|
| 3817 |
int64_t ne1,
|
| 3818 |
uint r2,
|
|
|
|
| 3837 |
const int im = tgpig.z;
|
| 3838 |
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
| 3839 |
const int first_row = r0 * N_DST;
|
|
|
|
| 3840 |
|
| 3841 |
const uint i12 = im%ne12;
|
| 3842 |
const uint i13 = im/ne12;
|
| 3843 |
|
| 3844 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 3845 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 3846 |
|
| 3847 |
+
device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0);
|
| 3848 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 3849 |
|
| 3850 |
float yl[16];
|
| 3851 |
float yh[16];
|
| 3852 |
float sumf[N_DST]={0.f}, all_sum;
|
| 3853 |
|
|
|
|
|
|
|
| 3854 |
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
| 3855 |
|
| 3856 |
uint16_t sc16[4];
|
| 3857 |
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
| 3858 |
|
| 3859 |
for (int ib = ix; ib < nb; ib += 4) {
|
|
|
|
| 3860 |
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
| 3861 |
for (int i = 0; i < 8; ++i) {
|
| 3862 |
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
|
|
|
|
| 3870 |
device const half * dh = &x[ib].d;
|
| 3871 |
|
| 3872 |
for (int row = 0; row < N_DST; row++) {
|
|
|
|
| 3873 |
sc16[0] = sc[0] & kmask1;
|
| 3874 |
sc16[1] = sc[2] & kmask1;
|
| 3875 |
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
|
|
|
| 3898 |
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
|
| 3899 |
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
| 3900 |
|
| 3901 |
+
q1 += nb01/2;
|
| 3902 |
+
sc += nb01/2;
|
| 3903 |
+
dh += nb01/2;
|
| 3904 |
}
|
| 3905 |
|
| 3906 |
y4 += 4 * QK_K;
|
|
|
|
| 3925 |
constant uint64_t & nb00,
|
| 3926 |
constant uint64_t & nb01,
|
| 3927 |
constant uint64_t & nb02,
|
| 3928 |
+
constant uint64_t & nb03,
|
| 3929 |
constant int64_t & ne10,
|
| 3930 |
constant int64_t & ne11,
|
| 3931 |
constant int64_t & ne12,
|
| 3932 |
constant uint64_t & nb10,
|
| 3933 |
constant uint64_t & nb11,
|
| 3934 |
constant uint64_t & nb12,
|
| 3935 |
+
constant uint64_t & nb13,
|
| 3936 |
constant int64_t & ne0,
|
| 3937 |
constant int64_t & ne1,
|
| 3938 |
constant uint & r2,
|
|
|
|
| 3941 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 3942 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3943 |
|
| 3944 |
+
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 3945 |
}
|
| 3946 |
|
| 3947 |
void kernel_mul_mv_q5_K_f32_impl(
|
|
|
|
| 3951 |
int64_t ne00,
|
| 3952 |
int64_t ne01,
|
| 3953 |
int64_t ne02,
|
| 3954 |
+
uint64_t nb01,
|
| 3955 |
+
uint64_t nb02,
|
| 3956 |
+
uint64_t nb03,
|
| 3957 |
int64_t ne10,
|
| 3958 |
int64_t ne12,
|
| 3959 |
+
uint64_t nb11,
|
| 3960 |
+
uint64_t nb12,
|
| 3961 |
+
uint64_t nb13,
|
| 3962 |
int64_t ne0,
|
| 3963 |
int64_t ne1,
|
| 3964 |
uint r2,
|
|
|
|
| 3979 |
const uint i12 = im%ne12;
|
| 3980 |
const uint i13 = im/ne12;
|
| 3981 |
|
| 3982 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 3983 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 3984 |
|
| 3985 |
+
device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0);
|
| 3986 |
+
device const float * yy = (device const float *) ((device char *) src1 + offset1);
|
| 3987 |
|
| 3988 |
float sumf[2]={0.f};
|
| 3989 |
|
|
|
|
|
|
|
| 3990 |
float yl[16], yh[16];
|
| 3991 |
|
| 3992 |
const uint16_t kmask1 = 0x3f3f;
|
|
|
|
| 4014 |
device const float * y1 = yy + ix*QK_K + y_offset;
|
| 4015 |
|
| 4016 |
for (int i = ix; i < nb; i += 4) {
|
|
|
|
| 4017 |
device const uint8_t * q1 = x[i].qs + q_offset;
|
| 4018 |
device const uint8_t * qh = x[i].qh + l0;
|
| 4019 |
device const half * dh = &x[i].d;
|
|
|
|
| 4029 |
}
|
| 4030 |
|
| 4031 |
for (int row = 0; row < 2; ++row) {
|
|
|
|
| 4032 |
device const uint8_t * q2 = q1 + 64;
|
| 4033 |
|
| 4034 |
sc16[0] = a[0] & kmask1;
|
|
|
|
| 4057 |
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
|
| 4058 |
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
| 4059 |
|
| 4060 |
+
q1 += nb01;
|
| 4061 |
+
qh += nb01;
|
| 4062 |
+
dh += nb01/2;
|
| 4063 |
+
a += nb01/2;
|
|
|
|
| 4064 |
}
|
| 4065 |
|
| 4066 |
y1 += 4 * QK_K;
|
|
|
|
| 4067 |
}
|
| 4068 |
|
| 4069 |
for (int row = 0; row < 2; ++row) {
|
|
|
|
| 4085 |
constant uint64_t & nb00,
|
| 4086 |
constant uint64_t & nb01,
|
| 4087 |
constant uint64_t & nb02,
|
| 4088 |
+
constant uint64_t & nb03,
|
| 4089 |
constant int64_t & ne10,
|
| 4090 |
constant int64_t & ne11,
|
| 4091 |
constant int64_t & ne12,
|
| 4092 |
constant uint64_t & nb10,
|
| 4093 |
constant uint64_t & nb11,
|
| 4094 |
constant uint64_t & nb12,
|
| 4095 |
+
constant uint64_t & nb13,
|
| 4096 |
constant int64_t & ne0,
|
| 4097 |
constant int64_t & ne1,
|
| 4098 |
constant uint & r2,
|
|
|
|
| 4101 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4102 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4103 |
|
| 4104 |
+
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 4105 |
}
|
| 4106 |
|
| 4107 |
void kernel_mul_mv_q6_K_f32_impl(
|
|
|
|
| 4111 |
int64_t ne00,
|
| 4112 |
int64_t ne01,
|
| 4113 |
int64_t ne02,
|
| 4114 |
+
uint64_t nb01,
|
| 4115 |
+
uint64_t nb02,
|
| 4116 |
+
uint64_t nb03,
|
| 4117 |
int64_t ne10,
|
| 4118 |
int64_t ne12,
|
| 4119 |
+
uint64_t nb11,
|
| 4120 |
+
uint64_t nb12,
|
| 4121 |
+
uint64_t nb13,
|
| 4122 |
int64_t ne0,
|
| 4123 |
int64_t ne1,
|
| 4124 |
uint r2,
|
|
|
|
| 4144 |
const uint i12 = im%ne12;
|
| 4145 |
const uint i13 = im/ne12;
|
| 4146 |
|
| 4147 |
+
const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 4148 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 4149 |
|
| 4150 |
+
device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0);
|
| 4151 |
+
device const float * yy = (device const float *) ((device char *) src1 + offset1);
|
| 4152 |
|
| 4153 |
float sumf = 0;
|
| 4154 |
|
|
|
|
| 4204 |
constant uint64_t & nb00,
|
| 4205 |
constant uint64_t & nb01,
|
| 4206 |
constant uint64_t & nb02,
|
| 4207 |
+
constant uint64_t & nb03,
|
| 4208 |
constant int64_t & ne10,
|
| 4209 |
constant int64_t & ne11,
|
| 4210 |
constant int64_t & ne12,
|
| 4211 |
constant uint64_t & nb10,
|
| 4212 |
constant uint64_t & nb11,
|
| 4213 |
constant uint64_t & nb12,
|
| 4214 |
+
constant uint64_t & nb13,
|
| 4215 |
constant int64_t & ne0,
|
| 4216 |
constant int64_t & ne1,
|
| 4217 |
constant uint & r2,
|
|
|
|
| 4220 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4221 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4222 |
|
| 4223 |
+
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 4224 |
}
|
| 4225 |
|
| 4226 |
// ======================= "True" 2-bit
|
|
|
|
| 4232 |
int64_t ne00,
|
| 4233 |
int64_t ne01,
|
| 4234 |
int64_t ne02,
|
| 4235 |
+
uint64_t nb01,
|
| 4236 |
+
uint64_t nb02,
|
| 4237 |
+
uint64_t nb03,
|
| 4238 |
int64_t ne10,
|
| 4239 |
int64_t ne12,
|
| 4240 |
+
uint64_t nb11,
|
| 4241 |
+
uint64_t nb12,
|
| 4242 |
+
uint64_t nb13,
|
| 4243 |
int64_t ne0,
|
| 4244 |
int64_t ne1,
|
| 4245 |
uint r2,
|
|
|
|
| 4255 |
const int im = tgpig.z;
|
| 4256 |
|
| 4257 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
|
|
| 4258 |
|
| 4259 |
const uint i12 = im%ne12;
|
| 4260 |
const uint i13 = im/ne12;
|
| 4261 |
|
| 4262 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 4263 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 4264 |
|
| 4265 |
+
device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0);
|
| 4266 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 4267 |
|
| 4268 |
float yl[32];
|
| 4269 |
float sumf[N_DST]={0.f}, all_sum;
|
|
|
|
| 4316 |
}
|
| 4317 |
sumf[row] += d * sum;
|
| 4318 |
|
| 4319 |
+
dh += nb01/2;
|
| 4320 |
+
q2 += nb01/2;
|
| 4321 |
}
|
| 4322 |
|
| 4323 |
y4 += 32 * 32;
|
|
|
|
| 4342 |
constant uint64_t & nb00,
|
| 4343 |
constant uint64_t & nb01,
|
| 4344 |
constant uint64_t & nb02,
|
| 4345 |
+
constant uint64_t & nb03,
|
| 4346 |
constant int64_t & ne10,
|
| 4347 |
constant int64_t & ne11,
|
| 4348 |
constant int64_t & ne12,
|
| 4349 |
constant uint64_t & nb10,
|
| 4350 |
constant uint64_t & nb11,
|
| 4351 |
constant uint64_t & nb12,
|
| 4352 |
+
constant uint64_t & nb13,
|
| 4353 |
constant int64_t & ne0,
|
| 4354 |
constant int64_t & ne1,
|
| 4355 |
constant uint & r2,
|
|
|
|
| 4359 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4360 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4361 |
|
| 4362 |
+
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4363 |
}
|
| 4364 |
|
| 4365 |
void kernel_mul_mv_iq2_xs_f32_impl(
|
|
|
|
| 4369 |
int64_t ne00,
|
| 4370 |
int64_t ne01,
|
| 4371 |
int64_t ne02,
|
| 4372 |
+
uint64_t nb01,
|
| 4373 |
+
uint64_t nb02,
|
| 4374 |
+
uint64_t nb03,
|
| 4375 |
int64_t ne10,
|
| 4376 |
int64_t ne12,
|
| 4377 |
+
uint64_t nb11,
|
| 4378 |
+
uint64_t nb12,
|
| 4379 |
+
uint64_t nb13,
|
| 4380 |
int64_t ne0,
|
| 4381 |
int64_t ne1,
|
| 4382 |
uint r2,
|
|
|
|
| 4392 |
const int im = tgpig.z;
|
| 4393 |
|
| 4394 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
|
|
| 4395 |
|
| 4396 |
const uint i12 = im%ne12;
|
| 4397 |
const uint i13 = im/ne12;
|
| 4398 |
|
| 4399 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 4400 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 4401 |
|
| 4402 |
+
device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0);
|
| 4403 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 4404 |
|
| 4405 |
float yl[32];
|
| 4406 |
float sumf[N_DST]={0.f}, all_sum;
|
|
|
|
| 4462 |
}
|
| 4463 |
sumf[row] += d1 * sum1 + d2 * sum2;
|
| 4464 |
|
| 4465 |
+
dh += nb01/2;
|
| 4466 |
+
q2 += nb01/2;
|
| 4467 |
+
sc += nb01;
|
| 4468 |
}
|
| 4469 |
|
| 4470 |
y4 += 32 * 32;
|
|
|
|
| 4489 |
constant uint64_t & nb00,
|
| 4490 |
constant uint64_t & nb01,
|
| 4491 |
constant uint64_t & nb02,
|
| 4492 |
+
constant uint64_t & nb03,
|
| 4493 |
constant int64_t & ne10,
|
| 4494 |
constant int64_t & ne11,
|
| 4495 |
constant int64_t & ne12,
|
| 4496 |
constant uint64_t & nb10,
|
| 4497 |
constant uint64_t & nb11,
|
| 4498 |
constant uint64_t & nb12,
|
| 4499 |
+
constant uint64_t & nb13,
|
| 4500 |
constant int64_t & ne0,
|
| 4501 |
constant int64_t & ne1,
|
| 4502 |
constant uint & r2,
|
|
|
|
| 4506 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4507 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4508 |
|
| 4509 |
+
kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4510 |
}
|
| 4511 |
|
| 4512 |
void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
|
|
| 4516 |
int64_t ne00,
|
| 4517 |
int64_t ne01,
|
| 4518 |
int64_t ne02,
|
| 4519 |
+
uint64_t nb01,
|
| 4520 |
+
uint64_t nb02,
|
| 4521 |
+
uint64_t nb03,
|
| 4522 |
int64_t ne10,
|
| 4523 |
int64_t ne12,
|
| 4524 |
+
uint64_t nb11,
|
| 4525 |
+
uint64_t nb12,
|
| 4526 |
+
uint64_t nb13,
|
| 4527 |
int64_t ne0,
|
| 4528 |
int64_t ne1,
|
| 4529 |
uint r2,
|
|
|
|
| 4539 |
const int im = tgpig.z;
|
| 4540 |
|
| 4541 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
|
|
| 4542 |
|
| 4543 |
const uint i12 = im%ne12;
|
| 4544 |
const uint i13 = im/ne12;
|
| 4545 |
|
| 4546 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 4547 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 4548 |
|
| 4549 |
+
device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0);
|
| 4550 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 4551 |
|
| 4552 |
float yl[32];
|
| 4553 |
float sumf[N_DST]={0.f}, all_sum;
|
|
|
|
| 4602 |
}
|
| 4603 |
sumf[row] += d * (sum[0] + sum[1]);
|
| 4604 |
|
| 4605 |
+
dh += nb01/2;
|
| 4606 |
+
q3 += nb01;
|
| 4607 |
+
gas += nb01/2;
|
| 4608 |
}
|
| 4609 |
|
| 4610 |
y4 += 32 * 32;
|
|
|
|
| 4629 |
constant uint64_t & nb00,
|
| 4630 |
constant uint64_t & nb01,
|
| 4631 |
constant uint64_t & nb02,
|
| 4632 |
+
constant uint64_t & nb03,
|
| 4633 |
constant int64_t & ne10,
|
| 4634 |
constant int64_t & ne11,
|
| 4635 |
constant int64_t & ne12,
|
| 4636 |
constant uint64_t & nb10,
|
| 4637 |
constant uint64_t & nb11,
|
| 4638 |
constant uint64_t & nb12,
|
| 4639 |
+
constant uint64_t & nb13,
|
| 4640 |
constant int64_t & ne0,
|
| 4641 |
constant int64_t & ne1,
|
| 4642 |
constant uint & r2,
|
|
|
|
| 4646 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4647 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4648 |
|
| 4649 |
+
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4650 |
}
|
| 4651 |
|
| 4652 |
void kernel_mul_mv_iq3_s_f32_impl(
|
|
|
|
| 4656 |
int64_t ne00,
|
| 4657 |
int64_t ne01,
|
| 4658 |
int64_t ne02,
|
| 4659 |
+
uint64_t nb01,
|
| 4660 |
+
uint64_t nb02,
|
| 4661 |
+
uint64_t nb03,
|
| 4662 |
int64_t ne10,
|
| 4663 |
int64_t ne12,
|
| 4664 |
+
uint64_t nb11,
|
| 4665 |
+
uint64_t nb12,
|
| 4666 |
+
uint64_t nb13,
|
| 4667 |
int64_t ne0,
|
| 4668 |
int64_t ne1,
|
| 4669 |
uint r2,
|
|
|
|
| 4679 |
const int im = tgpig.z;
|
| 4680 |
|
| 4681 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
|
|
| 4682 |
|
| 4683 |
const uint i12 = im%ne12;
|
| 4684 |
const uint i13 = im/ne12;
|
| 4685 |
|
| 4686 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 4687 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 4688 |
|
| 4689 |
+
device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0);
|
| 4690 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 4691 |
|
| 4692 |
float yl[32];
|
| 4693 |
float sumf[N_DST]={0.f}, all_sum;
|
|
|
|
| 4740 |
}
|
| 4741 |
sumf[row] += d * (sum[0] + sum[1]);
|
| 4742 |
|
| 4743 |
+
dh += nb01/2;
|
| 4744 |
+
qs += nb01;
|
| 4745 |
+
qh += nb01;
|
| 4746 |
+
sc += nb01;
|
| 4747 |
+
signs += nb01;
|
| 4748 |
}
|
| 4749 |
|
| 4750 |
y4 += 32 * 32;
|
|
|
|
| 4769 |
constant uint64_t & nb00,
|
| 4770 |
constant uint64_t & nb01,
|
| 4771 |
constant uint64_t & nb02,
|
| 4772 |
+
constant uint64_t & nb03,
|
| 4773 |
constant int64_t & ne10,
|
| 4774 |
constant int64_t & ne11,
|
| 4775 |
constant int64_t & ne12,
|
| 4776 |
constant uint64_t & nb10,
|
| 4777 |
constant uint64_t & nb11,
|
| 4778 |
constant uint64_t & nb12,
|
| 4779 |
+
constant uint64_t & nb13,
|
| 4780 |
constant int64_t & ne0,
|
| 4781 |
constant int64_t & ne1,
|
| 4782 |
constant uint & r2,
|
|
|
|
| 4786 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4787 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4788 |
|
| 4789 |
+
kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4790 |
}
|
| 4791 |
|
| 4792 |
void kernel_mul_mv_iq2_s_f32_impl(
|
|
|
|
| 4796 |
int64_t ne00,
|
| 4797 |
int64_t ne01,
|
| 4798 |
int64_t ne02,
|
| 4799 |
+
uint64_t nb01,
|
| 4800 |
+
uint64_t nb02,
|
| 4801 |
+
uint64_t nb03,
|
| 4802 |
int64_t ne10,
|
| 4803 |
int64_t ne12,
|
| 4804 |
+
uint64_t nb11,
|
| 4805 |
+
uint64_t nb12,
|
| 4806 |
+
uint64_t nb13,
|
| 4807 |
int64_t ne0,
|
| 4808 |
int64_t ne1,
|
| 4809 |
uint r2,
|
|
|
|
| 4819 |
const int im = tgpig.z;
|
| 4820 |
|
| 4821 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
|
|
| 4822 |
|
| 4823 |
const uint i12 = im%ne12;
|
| 4824 |
const uint i13 = im/ne12;
|
| 4825 |
|
| 4826 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 4827 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 4828 |
|
| 4829 |
+
device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0);
|
| 4830 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 4831 |
|
| 4832 |
float yl[32];
|
| 4833 |
float sumf[N_DST]={0.f}, all_sum;
|
|
|
|
| 4881 |
}
|
| 4882 |
sumf[row] += d1 * sum[0] + d2 * sum[1];
|
| 4883 |
|
| 4884 |
+
dh += nb01/2;
|
| 4885 |
+
qs += nb01;
|
| 4886 |
+
qh += nb01;
|
| 4887 |
+
sc += nb01;
|
| 4888 |
+
signs += nb01;
|
| 4889 |
}
|
| 4890 |
|
| 4891 |
y4 += 32 * 32;
|
|
|
|
| 4910 |
constant uint64_t & nb00,
|
| 4911 |
constant uint64_t & nb01,
|
| 4912 |
constant uint64_t & nb02,
|
| 4913 |
+
constant uint64_t & nb03,
|
| 4914 |
constant int64_t & ne10,
|
| 4915 |
constant int64_t & ne11,
|
| 4916 |
constant int64_t & ne12,
|
| 4917 |
constant uint64_t & nb10,
|
| 4918 |
constant uint64_t & nb11,
|
| 4919 |
constant uint64_t & nb12,
|
| 4920 |
+
constant uint64_t & nb13,
|
| 4921 |
constant int64_t & ne0,
|
| 4922 |
constant int64_t & ne1,
|
| 4923 |
constant uint & r2,
|
|
|
|
| 4927 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 4928 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 4929 |
|
| 4930 |
+
kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 4931 |
}
|
| 4932 |
|
| 4933 |
void kernel_mul_mv_iq1_s_f32_impl(
|
|
|
|
| 4937 |
int64_t ne00,
|
| 4938 |
int64_t ne01,
|
| 4939 |
int64_t ne02,
|
| 4940 |
+
uint64_t nb01,
|
| 4941 |
+
uint64_t nb02,
|
| 4942 |
+
uint64_t nb03,
|
| 4943 |
int64_t ne10,
|
| 4944 |
int64_t ne12,
|
| 4945 |
+
uint64_t nb11,
|
| 4946 |
+
uint64_t nb12,
|
| 4947 |
+
uint64_t nb13,
|
| 4948 |
int64_t ne0,
|
| 4949 |
int64_t ne1,
|
| 4950 |
uint r2,
|
|
|
|
| 4960 |
const int im = tgpig.z;
|
| 4961 |
|
| 4962 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
|
|
| 4963 |
|
| 4964 |
const uint i12 = im%ne12;
|
| 4965 |
const uint i13 = im/ne12;
|
| 4966 |
|
| 4967 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 4968 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 4969 |
+
|
| 4970 |
+
device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0);
|
| 4971 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 4972 |
|
| 4973 |
float yl[32];
|
| 4974 |
float sumf[N_DST]={0.f}, all_sum;
|
|
|
|
| 5011 |
}
|
| 5012 |
sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
|
| 5013 |
|
| 5014 |
+
dh += nb01/2;
|
| 5015 |
+
qs += nb01;
|
| 5016 |
+
qh += nb01/2;
|
| 5017 |
}
|
| 5018 |
|
| 5019 |
y4 += 32 * 32;
|
|
|
|
| 5034 |
int64_t ne00,
|
| 5035 |
int64_t ne01,
|
| 5036 |
int64_t ne02,
|
| 5037 |
+
uint64_t nb01,
|
| 5038 |
+
uint64_t nb02,
|
| 5039 |
+
uint64_t nb03,
|
| 5040 |
int64_t ne10,
|
| 5041 |
int64_t ne12,
|
| 5042 |
+
uint64_t nb11,
|
| 5043 |
+
uint64_t nb12,
|
| 5044 |
+
uint64_t nb13,
|
| 5045 |
int64_t ne0,
|
| 5046 |
int64_t ne1,
|
| 5047 |
uint r2,
|
|
|
|
| 5057 |
const int im = tgpig.z;
|
| 5058 |
|
| 5059 |
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
|
|
| 5060 |
|
| 5061 |
const uint i12 = im%ne12;
|
| 5062 |
const uint i13 = im/ne12;
|
| 5063 |
|
| 5064 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 5065 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 5066 |
+
|
| 5067 |
+
device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0);
|
| 5068 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 5069 |
|
| 5070 |
float yl[32];
|
| 5071 |
float sumf[N_DST]={0.f}, all_sum;
|
|
|
|
| 5117 |
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
|
| 5118 |
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
|
| 5119 |
|
| 5120 |
+
sc += nb01/2;
|
| 5121 |
+
qs += nb01;
|
| 5122 |
+
qh += nb01;
|
| 5123 |
}
|
| 5124 |
|
| 5125 |
y4 += 32 * 32;
|
|
|
|
| 5140 |
int64_t ne00,
|
| 5141 |
int64_t ne01,
|
| 5142 |
int64_t ne02,
|
| 5143 |
+
uint64_t nb01,
|
| 5144 |
+
uint64_t nb02,
|
| 5145 |
+
uint64_t nb03,
|
| 5146 |
int64_t ne10,
|
| 5147 |
int64_t ne12,
|
| 5148 |
+
uint64_t nb11,
|
| 5149 |
+
uint64_t nb12,
|
| 5150 |
+
uint64_t nb13,
|
| 5151 |
int64_t ne0,
|
| 5152 |
int64_t ne1,
|
| 5153 |
uint r2,
|
|
|
|
| 5163 |
const int r1 = tgpig.y;
|
| 5164 |
const int im = tgpig.z;
|
| 5165 |
const int first_row = (r0 * 2 + sgitg) * 2;
|
|
|
|
| 5166 |
|
| 5167 |
const uint i12 = im%ne12;
|
| 5168 |
const uint i13 = im/ne12;
|
| 5169 |
|
| 5170 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 5171 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 5172 |
+
|
| 5173 |
+
device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0);
|
| 5174 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 5175 |
|
| 5176 |
const int ix = tiisg/2; // 0...15
|
| 5177 |
const int it = tiisg%2; // 0 or 1
|
|
|
|
| 5241 |
int64_t ne00,
|
| 5242 |
int64_t ne01,
|
| 5243 |
int64_t ne02,
|
| 5244 |
+
uint64_t nb01,
|
| 5245 |
+
uint64_t nb02,
|
| 5246 |
+
uint64_t nb03,
|
| 5247 |
int64_t ne10,
|
| 5248 |
int64_t ne12,
|
| 5249 |
+
uint64_t nb11,
|
| 5250 |
+
uint64_t nb12,
|
| 5251 |
+
uint64_t nb13,
|
| 5252 |
int64_t ne0,
|
| 5253 |
int64_t ne1,
|
| 5254 |
uint r2,
|
|
|
|
| 5264 |
const int r1 = tgpig.y;
|
| 5265 |
const int im = tgpig.z;
|
| 5266 |
const int first_row = (r0 * 2 + sgitg) * 2;
|
|
|
|
| 5267 |
|
| 5268 |
const uint i12 = im%ne12;
|
| 5269 |
const uint i13 = im/ne12;
|
| 5270 |
|
| 5271 |
+
const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 5272 |
+
const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
| 5273 |
+
|
| 5274 |
+
device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0);
|
| 5275 |
+
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
| 5276 |
|
| 5277 |
const int ix = tiisg/16; // 0 or 1
|
| 5278 |
const int it = tiisg%16; // 0...15
|
|
|
|
| 5347 |
constant uint64_t & nb00,
|
| 5348 |
constant uint64_t & nb01,
|
| 5349 |
constant uint64_t & nb02,
|
| 5350 |
+
constant uint64_t & nb03,
|
| 5351 |
constant int64_t & ne10,
|
| 5352 |
constant int64_t & ne11,
|
| 5353 |
constant int64_t & ne12,
|
| 5354 |
constant uint64_t & nb10,
|
| 5355 |
constant uint64_t & nb11,
|
| 5356 |
constant uint64_t & nb12,
|
| 5357 |
+
constant uint64_t & nb13,
|
| 5358 |
constant int64_t & ne0,
|
| 5359 |
constant int64_t & ne1,
|
| 5360 |
constant uint & r2,
|
|
|
|
| 5363 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 5364 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5365 |
|
| 5366 |
+
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 5367 |
}
|
| 5368 |
|
| 5369 |
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
|
|
|
| 5377 |
constant uint64_t & nb00,
|
| 5378 |
constant uint64_t & nb01,
|
| 5379 |
constant uint64_t & nb02,
|
| 5380 |
+
constant uint64_t & nb03,
|
| 5381 |
constant int64_t & ne10,
|
| 5382 |
constant int64_t & ne11,
|
| 5383 |
constant int64_t & ne12,
|
| 5384 |
constant uint64_t & nb10,
|
| 5385 |
constant uint64_t & nb11,
|
| 5386 |
constant uint64_t & nb12,
|
| 5387 |
+
constant uint64_t & nb13,
|
| 5388 |
constant int64_t & ne0,
|
| 5389 |
constant int64_t & ne1,
|
| 5390 |
constant uint & r2,
|
|
|
|
| 5393 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 5394 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5395 |
|
| 5396 |
+
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
| 5397 |
}
|
| 5398 |
|
| 5399 |
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
|
|
|
| 5407 |
constant uint64_t & nb00,
|
| 5408 |
constant uint64_t & nb01,
|
| 5409 |
constant uint64_t & nb02,
|
| 5410 |
+
constant uint64_t & nb03,
|
| 5411 |
constant int64_t & ne10,
|
| 5412 |
constant int64_t & ne11,
|
| 5413 |
constant int64_t & ne12,
|
| 5414 |
constant uint64_t & nb10,
|
| 5415 |
constant uint64_t & nb11,
|
| 5416 |
constant uint64_t & nb12,
|
| 5417 |
+
constant uint64_t & nb13,
|
| 5418 |
constant int64_t & ne0,
|
| 5419 |
constant int64_t & ne1,
|
| 5420 |
constant uint & r2,
|
|
|
|
| 5424 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 5425 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5426 |
|
| 5427 |
+
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 5428 |
}
|
| 5429 |
|
| 5430 |
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
|
|
|
|
| 5438 |
constant uint64_t & nb00,
|
| 5439 |
constant uint64_t & nb01,
|
| 5440 |
constant uint64_t & nb02,
|
| 5441 |
+
constant uint64_t & nb03,
|
| 5442 |
constant int64_t & ne10,
|
| 5443 |
constant int64_t & ne11,
|
| 5444 |
constant int64_t & ne12,
|
| 5445 |
constant uint64_t & nb10,
|
| 5446 |
constant uint64_t & nb11,
|
| 5447 |
constant uint64_t & nb12,
|
| 5448 |
+
constant uint64_t & nb13,
|
| 5449 |
constant int64_t & ne0,
|
| 5450 |
constant int64_t & ne1,
|
| 5451 |
constant uint & r2,
|
|
|
|
| 5455 |
uint tiisg[[thread_index_in_simdgroup]],
|
| 5456 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5457 |
|
| 5458 |
+
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
| 5459 |
}
|
| 5460 |
|
| 5461 |
//============================= templates and their specializations =============================
|
|
|
|
| 6000 |
constant int64_t & ne02,
|
| 6001 |
constant uint64_t & nb01,
|
| 6002 |
constant uint64_t & nb02,
|
| 6003 |
+
constant uint64_t & nb03,
|
| 6004 |
constant int64_t & ne12,
|
| 6005 |
constant uint64_t & nb10,
|
| 6006 |
constant uint64_t & nb11,
|
| 6007 |
constant uint64_t & nb12,
|
| 6008 |
+
constant uint64_t & nb13,
|
| 6009 |
constant int64_t & ne0,
|
| 6010 |
constant int64_t & ne1,
|
| 6011 |
constant uint & r2,
|
|
|
|
| 6042 |
const uint i12 = im%ne12;
|
| 6043 |
const uint i13 = im/ne12;
|
| 6044 |
|
| 6045 |
+
uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
|
| 6046 |
ushort offset1 = il/nl;
|
| 6047 |
|
| 6048 |
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
| 6049 |
device const float * y = (device const float *)(src1
|
| 6050 |
+
+ nb13 * i13
|
| 6051 |
+
+ nb12 * i12
|
| 6052 |
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
|
| 6053 |
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
| 6054 |
|
|
|
|
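The simdgroup matrix-matrix path (kernel_mul_mm) gets the matching change in the hunk above: the per-thread src1 pointer is now assembled from the full set of byte strides nb10..nb13 rather than from a densely packed [ne10, ne11, ne12, ne13] assumption, which is what lets a permuted src1 feed the tiled kernel directly. The address is base + i13*nb13 + i12*nb12 + row*nb11 + col*nb10, as in this stand-alone sketch (hypothetical helper, plain C++):

    #include <cstdint>

    // byte-stride addressing of one src1 element -- a sketch of the expression
    // used in the kernel, not the kernel itself
    inline const float * src1_at(const char * src1,
                                 uint64_t i10, uint64_t i11, uint64_t i12, uint64_t i13,
                                 uint64_t nb10, uint64_t nb11, uint64_t nb12, uint64_t nb13) {
        return reinterpret_cast<const float *>(src1
            + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
    }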
| 6427 |
uint64_t nb00,
|
| 6428 |
uint64_t nb01,
|
| 6429 |
uint64_t nb02,
|
| 6430 |
+
uint64_t nb03,
|
| 6431 |
int64_t ne10,
|
| 6432 |
int64_t ne11,
|
| 6433 |
int64_t ne12,
|
| 6434 |
uint64_t nb10,
|
| 6435 |
uint64_t nb11,
|
| 6436 |
uint64_t nb12,
|
| 6437 |
+
uint64_t nb13,
|
| 6438 |
int64_t ne0,
|
| 6439 |
int64_t ne1,
|
| 6440 |
uint r2,
|
|
|
|
| 6449 |
int64_t ne00,
|
| 6450 |
int64_t ne01,
|
| 6451 |
int64_t ne02,
|
| 6452 |
+
uint64_t nb01,
|
| 6453 |
+
uint64_t nb02,
|
| 6454 |
+
uint64_t nb03,
|
| 6455 |
int64_t ne10,
|
| 6456 |
int64_t ne12,
|
| 6457 |
+
uint64_t nb11,
|
| 6458 |
+
uint64_t nb12,
|
| 6459 |
+
uint64_t nb13,
|
| 6460 |
int64_t ne0,
|
| 6461 |
int64_t ne1,
|
| 6462 |
uint r2,
|
|
|
|
| 6477 |
uint64_t nb00,
|
| 6478 |
uint64_t nb01,
|
| 6479 |
uint64_t nb02,
|
| 6480 |
+
uint64_t nb03,
|
| 6481 |
int64_t ne10,
|
| 6482 |
int64_t ne11,
|
| 6483 |
int64_t ne12,
|
|
|
|
| 6485 |
uint64_t nb10,
|
| 6486 |
uint64_t nb11,
|
| 6487 |
uint64_t nb12,
|
| 6488 |
+
uint64_t nb13,
|
| 6489 |
int64_t ne0,
|
| 6490 |
int64_t ne1,
|
| 6491 |
uint64_t nb1,
|
|
|
|
| 6496 |
uint tiitg,
|
| 6497 |
uint tiisg,
|
| 6498 |
uint sgitg) {
|
| 6499 |
+
impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg);
|
| 6500 |
}
|
| 6501 |
|
| 6502 |
template<kernel_mul_mv2_impl_t impl_fn>
|
|
|
|
| 6510 |
uint64_t nb00,
|
| 6511 |
uint64_t nb01,
|
| 6512 |
uint64_t nb02,
|
| 6513 |
+
uint64_t nb03,
|
| 6514 |
int64_t ne10,
|
| 6515 |
int64_t ne11,
|
| 6516 |
int64_t ne12,
|
|
|
|
| 6518 |
uint64_t nb10,
|
| 6519 |
uint64_t nb11,
|
| 6520 |
uint64_t nb12,
|
| 6521 |
+
uint64_t nb13,
|
| 6522 |
int64_t ne0,
|
| 6523 |
int64_t ne1,
|
| 6524 |
uint64_t nb1,
|
|
|
|
| 6529 |
uint tiitg,
|
| 6530 |
uint tiisg,
|
| 6531 |
uint sgitg) {
|
| 6532 |
+
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
|
| 6533 |
}
|
| 6534 |
|
| 6535 |
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
|
|
|
|
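The two mmv_fn templates above exist only to give both mul-mv implementation signatures (with and without shared_values) one common calling convention that kernel_mul_mv_id can store behind the mul_mv_impl_fn_t typedef; the diff simply threads the new nb03/nb13 strides through those wrappers. A reduced C++ analogue of the dispatch pattern, with invented names:

    #include <cstdio>

    // two "implementations" with different signatures
    void impl_plain (int n, float scale)            { printf("plain:  %d %.1f\n", n, scale); }
    void impl_shared(int n, float scale, int shmem) { printf("shared: %d %.1f %d\n", n, scale, shmem); }

    // wrappers that normalize both to one signature, chosen at compile time
    template <void (*F)(int, float)>
    void wrap_plain(int n, float scale, int /*shmem*/) { F(n, scale); }

    template <void (*F)(int, float, int)>
    void wrap_shared(int n, float scale, int shmem)    { F(n, scale, shmem); }

    typedef void (*impl_fn_t)(int, float, int);   // plays the role of mul_mv_impl_fn_t

    int main() {
        impl_fn_t table[2] = { wrap_plain<impl_plain>, wrap_shared<impl_shared> };
        table[0](4, 0.5f, 7);
        table[1](4, 0.5f, 7);
        return 0;
    }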
| 6578 |
const int64_t i2 = i12;
|
| 6579 |
|
| 6580 |
device const char * src0_cur = src0s + i02*nb02;
|
| 6581 |
+
device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
|
| 6582 |
+
device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
|
| 6583 |
|
| 6584 |
impl_fn(
|
| 6585 |
/* src0 */ src0_cur,
|
|
|
|
| 6587 |
/* dst */ dst_cur,
|
| 6588 |
/* ne00 */ ne00,
|
| 6589 |
/* ne01 */ ne01,
|
| 6590 |
+
/* ne02 */ 1, // ne02,
|
| 6591 |
/* nb00 */ nb00,
|
| 6592 |
/* nb01 */ nb01,
|
| 6593 |
/* nb02 */ nb02,
|
| 6594 |
+
/* nb03 */ nb02, // ne02 == 1
|
| 6595 |
/* ne10 */ ne10,
|
| 6596 |
+
/* ne11 */ 1, // ne11,
|
| 6597 |
+
/* ne12 */ 1, // ne12,
|
| 6598 |
+
/* ne13 */ 1, // ne13,
|
| 6599 |
/* nb10 */ nb10,
|
| 6600 |
/* nb11 */ nb11,
|
| 6601 |
/* nb12 */ nb12,
|
| 6602 |
+
/* ne13 */ nb12, // ne12 == 1
|
| 6603 |
/* ne0 */ ne0,
|
| 6604 |
+
/* ne1 */ 1, // ne1,
|
| 6605 |
/* nb1 */ nb1,
|
| 6606 |
/* r2 */ 1,
|
| 6607 |
/* r3 */ 1,
|