ggerganov commited on
Commit
efb86a3
·
1 Parent(s): a41f94c

metal : support permuted matrix multiplications (llama/10033)

Browse files

* metal : support permuted matrix multiplications

ggml-ci

* cont : use nb01 directly for row steps

ggml-ci

* cont : add comments [no ci]

* metal : minor refactor

* metal : minor

Files changed (2) hide show
  1. ggml/src/ggml-metal.m +42 -33
  2. ggml/src/ggml-metal.metal +380 -196
ggml/src/ggml-metal.m CHANGED
@@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node(
1015
  id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
1016
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
1017
 
1018
- //GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1019
- //if (src0) {
1020
- // GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
1021
- // ggml_is_contiguous(src0), src0->name);
1022
- //}
1023
- //if (src1) {
1024
- // GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
1025
- // ggml_is_contiguous(src1), src1->name);
1026
- //}
1027
- //if (dst) {
1028
- // GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
1029
- // dst->name);
1030
- //}
 
 
1031
 
1032
  id<MTLDevice> device = ctx_dev->mtl_device;
1033
 
@@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node(
1810
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1811
  [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1812
  [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1813
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1814
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1815
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1816
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1817
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1818
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1819
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1820
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
 
 
1821
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1822
  [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1823
  } else {
@@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node(
1986
  [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1987
  [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1988
  [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1989
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1990
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1991
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1992
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1993
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1994
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1995
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1996
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1997
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1998
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
 
 
1999
 
2000
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2001
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2002
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2003
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2004
  }
2005
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node(
2048
 
2049
  GGML_ASSERT(src1t == GGML_TYPE_F32);
2050
 
 
 
 
2051
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
2052
  // to the matrix-vector kernel
2053
  // ne20 = n_used_experts
 
1015
  id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
1016
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
1017
 
1018
+ #if 0
1019
+ GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1020
+ if (src0) {
1021
+ GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
1022
+ ggml_is_contiguous(src0), src0->name);
1023
+ }
1024
+ if (src1) {
1025
+ GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
1026
+ ggml_is_contiguous(src1), src1->name);
1027
+ }
1028
+ if (dst) {
1029
+ GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
1030
+ dst->name);
1031
+ }
1032
+ #endif
1033
 
1034
  id<MTLDevice> device = ctx_dev->mtl_device;
1035
 
 
1812
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1813
  [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1814
  [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1815
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
1816
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1817
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
1818
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
1819
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
1820
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
1821
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1822
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1823
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
1824
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
1825
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1826
  [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1827
  } else {
 
1990
  [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1991
  [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1992
  [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1993
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1994
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1995
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1996
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1997
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
1998
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
1999
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
2000
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
2001
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
2002
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
2003
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
2004
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
2005
 
2006
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2007
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2008
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2009
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2010
  }
2011
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
 
2054
 
2055
  GGML_ASSERT(src1t == GGML_TYPE_F32);
2056
 
2057
+ GGML_ASSERT(ne03 == 1);
2058
+ GGML_ASSERT(ne13 == 1);
2059
+
2060
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
2061
  // to the matrix-vector kernel
2062
  // ne20 = n_used_experts
ggml/src/ggml-metal.metal CHANGED
@@ -777,10 +777,10 @@ kernel void kernel_ssm_conv_f32(
777
  const int64_t i3 = tgpig.z;
778
 
779
  const int64_t nc = ne10;
780
- const int64_t ncs = ne00;
781
- const int64_t nr = ne01;
782
- const int64_t n_t = ne1;
783
- const int64_t n_s = ne2;
784
 
785
  device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
786
  device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
@@ -834,9 +834,9 @@ kernel void kernel_ssm_scan_f32(
834
  const int64_t i3 = tgpig.y;
835
 
836
  const int64_t nc = d_state;
837
- const int64_t nr = d_inner;
838
  const int64_t n_t = n_seq_tokens;
839
- const int64_t n_s = n_seqs;
840
 
841
  for (int64_t i2 = 0; i2 < n_t; ++i2) {
842
  device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
@@ -1064,17 +1064,18 @@ kernel void kernel_group_norm(
1064
  inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
1065
  float d = qb_curr->d;
1066
 
1067
- float2 acc = 0.f;
1068
 
1069
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
1070
 
1071
- for (int i = 0; i < 8; i+=2) {
1072
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
1073
- + yl[i + 1] * (qs[i / 2] & 0x0F00);
1074
- acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
1075
- + yl[i + 9] * (qs[i / 2] & 0xF000);
1076
  }
1077
- return d * (sumy * -8.f + acc[0] + acc[1]);
 
1078
  }
1079
 
1080
  // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1085,17 +1086,18 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
1085
  float d = qb_curr->d;
1086
  float m = qb_curr->m;
1087
 
1088
- float2 acc = 0.f;
1089
 
1090
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
1091
 
1092
  for (int i = 0; i < 8; i+=2) {
1093
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
1094
- + yl[i + 1] * (qs[i / 2] & 0x0F00);
1095
- acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
1096
- + yl[i + 9] * (qs[i / 2] & 0xF000);
1097
  }
1098
- return d * (acc[0] + acc[1]) + sumy * m;
 
1099
  }
1100
 
1101
  // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1105,18 +1107,19 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
1105
  inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
1106
  float d = qb_curr->d;
1107
 
1108
- float2 acc = 0.f;
1109
 
1110
  device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
1111
  const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
1112
 
1113
  for (int i = 0; i < 8; i+=2) {
1114
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
1115
- + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
1116
- acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
1117
- + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
1118
  }
1119
- return d * (sumy * -16.f + acc[0] + acc[1]);
 
1120
  }
1121
 
1122
  // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1127,18 +1130,19 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
1127
  float d = qb_curr->d;
1128
  float m = qb_curr->m;
1129
 
1130
- float2 acc = 0.f;
1131
 
1132
  device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
1133
  const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
1134
 
1135
  for (int i = 0; i < 8; i+=2) {
1136
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
1137
- + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
1138
- acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
1139
- + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
1140
  }
1141
- return d * (acc[0] + acc[1]) + sumy * m;
 
1142
  }
1143
 
1144
  // putting them in the kernel cause a significant performance penalty
@@ -1156,14 +1160,22 @@ void mul_vec_q_n_f32_impl(
1156
  int64_t ne00,
1157
  int64_t ne01,
1158
  int64_t ne02,
 
 
 
1159
  int64_t ne10,
1160
  int64_t ne12,
 
 
 
1161
  int64_t ne0,
1162
  int64_t ne1,
1163
  uint r2,
1164
  uint r3,
1165
  threadgroup int8_t * shared_values,
1166
- uint3 tgpig, uint tiisg, uint sgitg) {
 
 
1167
  const int nb = ne00/QK4_0;
1168
 
1169
  const int r0 = tgpig.x;
@@ -1175,10 +1187,19 @@ void mul_vec_q_n_f32_impl(
1175
  const uint i12 = im%ne12;
1176
  const uint i13 = im/ne12;
1177
 
1178
- const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
1179
 
1180
- device const block_q_type * x = (device const block_q_type *) src0 + offset0;
1181
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
 
 
 
 
 
 
 
1182
 
1183
  float yl[16]; // src1 vector cache
1184
  float sumf[nr] = {0.f};
@@ -1190,19 +1211,22 @@ void mul_vec_q_n_f32_impl(
1190
 
1191
  // each thread in a SIMD group deals with half a block.
1192
  for (int ib = ix; ib < nb; ib += nw/2) {
1193
- float sumy = 0;
 
 
1194
  for (int i = 0; i < 8; i += 2) {
1195
- sumy += yb[i] + yb[i+1];
1196
- yl[i+0] = yb[i+ 0];
1197
- yl[i+1] = yb[i+ 1]/256.f;
1198
 
1199
- sumy += yb[i+16] + yb[i+17];
1200
- yl[i+8] = yb[i+16]/16.f;
1201
- yl[i+9] = yb[i+17]/4096.f;
1202
  }
1203
 
 
1204
  for (int row = 0; row < nr; row++) {
1205
- sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
1206
  }
1207
 
1208
  yb += QK4_0 * 16;
@@ -1226,12 +1250,14 @@ kernel void kernel_mul_mv_q4_0_f32(
1226
  constant uint64_t & nb00,
1227
  constant uint64_t & nb01,
1228
  constant uint64_t & nb02,
 
1229
  constant int64_t & ne10,
1230
  constant int64_t & ne11,
1231
  constant int64_t & ne12,
1232
  constant uint64_t & nb10,
1233
  constant uint64_t & nb11,
1234
  constant uint64_t & nb12,
 
1235
  constant int64_t & ne0,
1236
  constant int64_t & ne1,
1237
  constant uint & r2,
@@ -1239,7 +1265,7 @@ kernel void kernel_mul_mv_q4_0_f32(
1239
  uint3 tgpig[[threadgroup_position_in_grid]],
1240
  uint tiisg[[thread_index_in_simdgroup]],
1241
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1242
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1243
  }
1244
 
1245
  kernel void kernel_mul_mv_q4_1_f32(
@@ -1252,12 +1278,14 @@ kernel void kernel_mul_mv_q4_1_f32(
1252
  constant uint64_t & nb00,
1253
  constant uint64_t & nb01,
1254
  constant uint64_t & nb02,
 
1255
  constant int64_t & ne10,
1256
  constant int64_t & ne11,
1257
  constant int64_t & ne12,
1258
  constant uint64_t & nb10,
1259
  constant uint64_t & nb11,
1260
  constant uint64_t & nb12,
 
1261
  constant int64_t & ne0,
1262
  constant int64_t & ne1,
1263
  constant uint & r2,
@@ -1265,7 +1293,7 @@ kernel void kernel_mul_mv_q4_1_f32(
1265
  uint3 tgpig[[threadgroup_position_in_grid]],
1266
  uint tiisg[[thread_index_in_simdgroup]],
1267
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1268
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1269
  }
1270
 
1271
  kernel void kernel_mul_mv_q5_0_f32(
@@ -1278,12 +1306,14 @@ kernel void kernel_mul_mv_q5_0_f32(
1278
  constant uint64_t & nb00,
1279
  constant uint64_t & nb01,
1280
  constant uint64_t & nb02,
 
1281
  constant int64_t & ne10,
1282
  constant int64_t & ne11,
1283
  constant int64_t & ne12,
1284
  constant uint64_t & nb10,
1285
  constant uint64_t & nb11,
1286
  constant uint64_t & nb12,
 
1287
  constant int64_t & ne0,
1288
  constant int64_t & ne1,
1289
  constant uint & r2,
@@ -1291,7 +1321,7 @@ kernel void kernel_mul_mv_q5_0_f32(
1291
  uint3 tgpig[[threadgroup_position_in_grid]],
1292
  uint tiisg[[thread_index_in_simdgroup]],
1293
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1294
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1295
  }
1296
 
1297
  kernel void kernel_mul_mv_q5_1_f32(
@@ -1304,12 +1334,14 @@ kernel void kernel_mul_mv_q5_1_f32(
1304
  constant uint64_t & nb00,
1305
  constant uint64_t & nb01,
1306
  constant uint64_t & nb02,
 
1307
  constant int64_t & ne10,
1308
  constant int64_t & ne11,
1309
  constant int64_t & ne12,
1310
  constant uint64_t & nb10,
1311
  constant uint64_t & nb11,
1312
  constant uint64_t & nb12,
 
1313
  constant int64_t & ne0,
1314
  constant int64_t & ne1,
1315
  constant uint & r2,
@@ -1317,7 +1349,7 @@ kernel void kernel_mul_mv_q5_1_f32(
1317
  uint3 tgpig[[threadgroup_position_in_grid]],
1318
  uint tiisg[[thread_index_in_simdgroup]],
1319
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1320
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1321
  }
1322
 
1323
 
@@ -1330,8 +1362,14 @@ void kernel_mul_mv_q8_0_f32_impl(
1330
  int64_t ne00,
1331
  int64_t ne01,
1332
  int64_t ne02,
 
 
 
1333
  int64_t ne10,
1334
  int64_t ne12,
 
 
 
1335
  int64_t ne0,
1336
  int64_t ne1,
1337
  uint r2,
@@ -1354,10 +1392,19 @@ void kernel_mul_mv_q8_0_f32_impl(
1354
  const uint i12 = im%ne12;
1355
  const uint i13 = im/ne12;
1356
 
1357
- const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
1358
 
1359
- device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
1360
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
 
 
 
 
 
 
 
1361
 
1362
  float yl[NB_Q8_0];
1363
  float sumf[nr]={0.f};
@@ -1374,12 +1421,12 @@ void kernel_mul_mv_q8_0_f32_impl(
1374
  }
1375
 
1376
  for (int row = 0; row < nr; row++) {
1377
- device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
1378
  float sumq = 0.f;
1379
  for (int iq = 0; iq < NB_Q8_0; ++iq) {
1380
  sumq += qs[iq] * yl[iq];
1381
  }
1382
- sumf[row] += sumq*x[ib+row*nb].d;
1383
  }
1384
 
1385
  yb += NB_Q8_0 * nw;
@@ -1404,12 +1451,14 @@ kernel void kernel_mul_mv_q8_0_f32(
1404
  constant uint64_t & nb00,
1405
  constant uint64_t & nb01,
1406
  constant uint64_t & nb02,
 
1407
  constant int64_t & ne10,
1408
  constant int64_t & ne11,
1409
  constant int64_t & ne12,
1410
  constant uint64_t & nb10,
1411
  constant uint64_t & nb11,
1412
  constant uint64_t & nb12,
 
1413
  constant int64_t & ne0,
1414
  constant int64_t & ne1,
1415
  constant uint & r2,
@@ -1417,7 +1466,7 @@ kernel void kernel_mul_mv_q8_0_f32(
1417
  uint3 tgpig[[threadgroup_position_in_grid]],
1418
  uint tiisg[[thread_index_in_simdgroup]],
1419
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1420
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1421
  }
1422
 
1423
  #define N_MV_T_T 4
@@ -1433,12 +1482,14 @@ void kernel_mul_mv_impl(
1433
  uint64_t nb00,
1434
  uint64_t nb01,
1435
  uint64_t nb02,
 
1436
  int64_t ne10,
1437
  int64_t ne11,
1438
  int64_t ne12,
1439
  uint64_t nb10,
1440
  uint64_t nb11,
1441
  uint64_t nb12,
 
1442
  int64_t ne0,
1443
  int64_t ne1,
1444
  uint r2,
@@ -1452,7 +1503,7 @@ void kernel_mul_mv_impl(
1452
  const uint i12 = im%ne12;
1453
  const uint i13 = im/ne12;
1454
 
1455
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1456
 
1457
  device const T0 * x = (device const T0 *) (src0 + offset0);
1458
 
@@ -1463,7 +1514,9 @@ void kernel_mul_mv_impl(
1463
  break;
1464
  }
1465
 
1466
- device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
 
 
1467
 
1468
  float sumf = 0;
1469
  for (int i = tiisg; i < ne00; i += 32) {
@@ -1483,7 +1536,9 @@ void kernel_mul_mv_impl(
1483
  break;
1484
  }
1485
 
1486
- device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
 
 
1487
  device const T14 * y4 = (device const T14 *) y;
1488
 
1489
  float sumf = 0;
@@ -1511,12 +1566,14 @@ kernel void kernel_mul_mv(
1511
  constant uint64_t & nb00,
1512
  constant uint64_t & nb01,
1513
  constant uint64_t & nb02,
 
1514
  constant int64_t & ne10,
1515
  constant int64_t & ne11,
1516
  constant int64_t & ne12,
1517
  constant uint64_t & nb10,
1518
  constant uint64_t & nb11,
1519
  constant uint64_t & nb12,
 
1520
  constant int64_t & ne0,
1521
  constant int64_t & ne1,
1522
  constant uint & r2,
@@ -1533,12 +1590,14 @@ kernel void kernel_mul_mv(
1533
  nb00,
1534
  nb01,
1535
  nb02,
 
1536
  ne10,
1537
  ne11,
1538
  ne12,
1539
  nb10,
1540
  nb11,
1541
  nb12,
 
1542
  ne0,
1543
  ne1,
1544
  r2,
@@ -1564,12 +1623,14 @@ kernel void kernel_mul_mv_1row(
1564
  constant uint64_t & nb00,
1565
  constant uint64_t & nb01,
1566
  constant uint64_t & nb02,
 
1567
  constant int64_t & ne10,
1568
  constant int64_t & ne11,
1569
  constant int64_t & ne12,
1570
  constant uint64_t & nb10,
1571
  constant uint64_t & nb11,
1572
  constant uint64_t & nb12,
 
1573
  constant int64_t & ne0,
1574
  constant int64_t & ne1,
1575
  constant uint & r2,
@@ -1584,10 +1645,11 @@ kernel void kernel_mul_mv_1row(
1584
  const uint i12 = im%ne12;
1585
  const uint i13 = im/ne12;
1586
 
1587
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
 
1588
 
1589
  device const T * x = (device const T *) (src0 + offset0);
1590
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
1591
 
1592
  float sumf = 0;
1593
  if (ne00 < 128) {
@@ -1631,12 +1693,14 @@ kernel void kernel_mul_mv_l4(
1631
  constant uint64_t & nb00,
1632
  constant uint64_t & nb01,
1633
  constant uint64_t & nb02,
 
1634
  constant int64_t & ne10,
1635
  constant int64_t & ne11,
1636
  constant int64_t & ne12,
1637
  constant uint64_t & nb10,
1638
  constant uint64_t & nb11,
1639
  constant uint64_t & nb12,
 
1640
  constant int64_t & ne0,
1641
  constant int64_t & ne1,
1642
  constant uint & r2,
@@ -1651,12 +1715,14 @@ kernel void kernel_mul_mv_l4(
1651
  const uint i12 = im%ne12;
1652
  const uint i13 = im/ne12;
1653
 
1654
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1655
 
1656
  device const T4 * x4 = (device const T4 *) (src0 + offset0);
1657
 
1658
  for (int r1 = 0; r1 < nrows; ++r1) {
1659
- device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
 
 
1660
 
1661
  float sumf = 0;
1662
  for (int i = tiisg; i < ne00/4; i += 32) {
@@ -3416,8 +3482,14 @@ void kernel_mul_mv_q2_K_f32_impl(
3416
  int64_t ne00,
3417
  int64_t ne01,
3418
  int64_t ne02,
 
 
 
3419
  int64_t ne10,
3420
  int64_t ne12,
 
 
 
3421
  int64_t ne0,
3422
  int64_t ne1,
3423
  uint r2,
@@ -3433,21 +3505,19 @@ void kernel_mul_mv_q2_K_f32_impl(
3433
  const int im = tgpig.z;
3434
 
3435
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3436
- const int ib_row = first_row * nb;
3437
 
3438
  const uint i12 = im%ne12;
3439
  const uint i13 = im/ne12;
3440
 
3441
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
3442
 
3443
- device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
3444
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3445
 
3446
  float yl[32];
3447
  float sumf[N_DST]={0.f}, all_sum;
3448
 
3449
- const int step = sizeof(block_q2_K) * nb;
3450
-
3451
  const int ix = tiisg/8; // 0...3
3452
  const int it = tiisg%8; // 0...7
3453
  const int iq = it/4; // 0 or 1
@@ -3492,9 +3562,9 @@ void kernel_mul_mv_q2_K_f32_impl(
3492
  (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
3493
  dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
3494
 
3495
- qs += step/2;
3496
- sc += step;
3497
- dh += step/2;
3498
  }
3499
 
3500
  y4 += 4 * QK_K;
@@ -3519,12 +3589,14 @@ kernel void kernel_mul_mv_q2_K_f32(
3519
  constant uint64_t & nb00,
3520
  constant uint64_t & nb01,
3521
  constant uint64_t & nb02,
 
3522
  constant int64_t & ne10,
3523
  constant int64_t & ne11,
3524
  constant int64_t & ne12,
3525
  constant uint64_t & nb10,
3526
  constant uint64_t & nb11,
3527
  constant uint64_t & nb12,
 
3528
  constant int64_t & ne0,
3529
  constant int64_t & ne1,
3530
  constant uint & r2,
@@ -3533,7 +3605,7 @@ kernel void kernel_mul_mv_q2_K_f32(
3533
  uint tiisg[[thread_index_in_simdgroup]],
3534
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3535
 
3536
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3537
  }
3538
 
3539
  void kernel_mul_mv_q3_K_f32_impl(
@@ -3543,8 +3615,14 @@ void kernel_mul_mv_q3_K_f32_impl(
3543
  int64_t ne00,
3544
  int64_t ne01,
3545
  int64_t ne02,
 
 
 
3546
  int64_t ne10,
3547
  int64_t ne12,
 
 
 
3548
  int64_t ne0,
3549
  int64_t ne1,
3550
  uint r2,
@@ -3565,10 +3643,11 @@ void kernel_mul_mv_q3_K_f32_impl(
3565
  const uint i12 = im%ne12;
3566
  const uint i13 = im/ne12;
3567
 
3568
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
3569
 
3570
- device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
3571
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3572
 
3573
  float yl[32];
3574
 
@@ -3608,8 +3687,6 @@ void kernel_mul_mv_q3_K_f32_impl(
3608
  const int q_offset = 32*ip + l0;
3609
  const int y_offset = 128*ip + 32*il + l0;
3610
 
3611
- const int step = sizeof(block_q3_K) * nb / 2;
3612
-
3613
  device const float * y1 = yy + ix*QK_K + y_offset;
3614
 
3615
  uint32_t scales32, aux32;
@@ -3619,7 +3696,6 @@ void kernel_mul_mv_q3_K_f32_impl(
3619
  float sumf1[2] = {0.f};
3620
  float sumf2[2] = {0.f};
3621
  for (int i = ix; i < nb; i += 4) {
3622
-
3623
  for (int l = 0; l < 8; ++l) {
3624
  yl[l+ 0] = y1[l+ 0];
3625
  yl[l+ 8] = y1[l+16];
@@ -3633,7 +3709,6 @@ void kernel_mul_mv_q3_K_f32_impl(
3633
  device const half * dh = &x[i].d;
3634
 
3635
  for (int row = 0; row < 2; ++row) {
3636
-
3637
  const float d_all = (float)dh[0];
3638
 
3639
  scales16[0] = a[4];
@@ -3673,15 +3748,13 @@ void kernel_mul_mv_q3_K_f32_impl(
3673
  sumf1[row] += d1 * (scales[1] - 32);
3674
  sumf2[row] += d2 * (scales[3] - 32);
3675
 
3676
- q += step;
3677
- h += step;
3678
- a += step;
3679
- dh += step;
3680
-
3681
  }
3682
 
3683
  y1 += 4 * QK_K;
3684
-
3685
  }
3686
 
3687
  for (int row = 0; row < 2; ++row) {
@@ -3706,12 +3779,14 @@ kernel void kernel_mul_mv_q3_K_f32(
3706
  constant uint64_t & nb00,
3707
  constant uint64_t & nb01,
3708
  constant uint64_t & nb02,
 
3709
  constant int64_t & ne10,
3710
  constant int64_t & ne11,
3711
  constant int64_t & ne12,
3712
  constant uint64_t & nb10,
3713
  constant uint64_t & nb11,
3714
  constant uint64_t & nb12,
 
3715
  constant int64_t & ne0,
3716
  constant int64_t & ne1,
3717
  constant uint & r2,
@@ -3720,7 +3795,7 @@ kernel void kernel_mul_mv_q3_K_f32(
3720
  uint tiisg[[thread_index_in_simdgroup]],
3721
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3722
 
3723
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3724
  }
3725
 
3726
  void kernel_mul_mv_q4_K_f32_impl(
@@ -3730,8 +3805,14 @@ void kernel_mul_mv_q4_K_f32_impl(
3730
  int64_t ne00,
3731
  int64_t ne01,
3732
  int64_t ne02,
 
 
 
3733
  int64_t ne10,
3734
  int64_t ne12,
 
 
 
3735
  int64_t ne0,
3736
  int64_t ne1,
3737
  uint r2,
@@ -3756,29 +3837,26 @@ void kernel_mul_mv_q4_K_f32_impl(
3756
  const int im = tgpig.z;
3757
  //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3758
  const int first_row = r0 * N_DST;
3759
- const int ib_row = first_row * nb;
3760
 
3761
  const uint i12 = im%ne12;
3762
  const uint i13 = im/ne12;
3763
 
3764
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
3765
 
3766
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
3767
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3768
 
3769
  float yl[16];
3770
  float yh[16];
3771
  float sumf[N_DST]={0.f}, all_sum;
3772
 
3773
- const int step = sizeof(block_q4_K) * nb / 2;
3774
-
3775
  device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
3776
 
3777
  uint16_t sc16[4];
3778
  thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
3779
 
3780
  for (int ib = ix; ib < nb; ib += 4) {
3781
-
3782
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
3783
  for (int i = 0; i < 8; ++i) {
3784
  yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
@@ -3792,7 +3870,6 @@ void kernel_mul_mv_q4_K_f32_impl(
3792
  device const half * dh = &x[ib].d;
3793
 
3794
  for (int row = 0; row < N_DST; row++) {
3795
-
3796
  sc16[0] = sc[0] & kmask1;
3797
  sc16[1] = sc[2] & kmask1;
3798
  sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
@@ -3821,9 +3898,9 @@ void kernel_mul_mv_q4_K_f32_impl(
3821
  (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
3822
  dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
3823
 
3824
- q1 += step;
3825
- sc += step;
3826
- dh += step;
3827
  }
3828
 
3829
  y4 += 4 * QK_K;
@@ -3848,12 +3925,14 @@ kernel void kernel_mul_mv_q4_K_f32(
3848
  constant uint64_t & nb00,
3849
  constant uint64_t & nb01,
3850
  constant uint64_t & nb02,
 
3851
  constant int64_t & ne10,
3852
  constant int64_t & ne11,
3853
  constant int64_t & ne12,
3854
  constant uint64_t & nb10,
3855
  constant uint64_t & nb11,
3856
  constant uint64_t & nb12,
 
3857
  constant int64_t & ne0,
3858
  constant int64_t & ne1,
3859
  constant uint & r2,
@@ -3862,7 +3941,7 @@ kernel void kernel_mul_mv_q4_K_f32(
3862
  uint tiisg[[thread_index_in_simdgroup]],
3863
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3864
 
3865
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3866
  }
3867
 
3868
  void kernel_mul_mv_q5_K_f32_impl(
@@ -3872,8 +3951,14 @@ void kernel_mul_mv_q5_K_f32_impl(
3872
  int64_t ne00,
3873
  int64_t ne01,
3874
  int64_t ne02,
 
 
 
3875
  int64_t ne10,
3876
  int64_t ne12,
 
 
 
3877
  int64_t ne0,
3878
  int64_t ne1,
3879
  uint r2,
@@ -3894,15 +3979,14 @@ void kernel_mul_mv_q5_K_f32_impl(
3894
  const uint i12 = im%ne12;
3895
  const uint i13 = im/ne12;
3896
 
3897
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
3898
 
3899
- device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
3900
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3901
 
3902
  float sumf[2]={0.f};
3903
 
3904
- const int step = sizeof(block_q5_K) * nb;
3905
-
3906
  float yl[16], yh[16];
3907
 
3908
  const uint16_t kmask1 = 0x3f3f;
@@ -3930,7 +4014,6 @@ void kernel_mul_mv_q5_K_f32_impl(
3930
  device const float * y1 = yy + ix*QK_K + y_offset;
3931
 
3932
  for (int i = ix; i < nb; i += 4) {
3933
-
3934
  device const uint8_t * q1 = x[i].qs + q_offset;
3935
  device const uint8_t * qh = x[i].qh + l0;
3936
  device const half * dh = &x[i].d;
@@ -3946,7 +4029,6 @@ void kernel_mul_mv_q5_K_f32_impl(
3946
  }
3947
 
3948
  for (int row = 0; row < 2; ++row) {
3949
-
3950
  device const uint8_t * q2 = q1 + 64;
3951
 
3952
  sc16[0] = a[0] & kmask1;
@@ -3975,15 +4057,13 @@ void kernel_mul_mv_q5_K_f32_impl(
3975
  sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
3976
  dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
3977
 
3978
- q1 += step;
3979
- qh += step;
3980
- dh += step/2;
3981
- a += step/2;
3982
-
3983
  }
3984
 
3985
  y1 += 4 * QK_K;
3986
-
3987
  }
3988
 
3989
  for (int row = 0; row < 2; ++row) {
@@ -4005,12 +4085,14 @@ kernel void kernel_mul_mv_q5_K_f32(
4005
  constant uint64_t & nb00,
4006
  constant uint64_t & nb01,
4007
  constant uint64_t & nb02,
 
4008
  constant int64_t & ne10,
4009
  constant int64_t & ne11,
4010
  constant int64_t & ne12,
4011
  constant uint64_t & nb10,
4012
  constant uint64_t & nb11,
4013
  constant uint64_t & nb12,
 
4014
  constant int64_t & ne0,
4015
  constant int64_t & ne1,
4016
  constant uint & r2,
@@ -4019,7 +4101,7 @@ kernel void kernel_mul_mv_q5_K_f32(
4019
  uint tiisg[[thread_index_in_simdgroup]],
4020
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4021
 
4022
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
4023
  }
4024
 
4025
  void kernel_mul_mv_q6_K_f32_impl(
@@ -4029,8 +4111,14 @@ void kernel_mul_mv_q6_K_f32_impl(
4029
  int64_t ne00,
4030
  int64_t ne01,
4031
  int64_t ne02,
 
 
 
4032
  int64_t ne10,
4033
  int64_t ne12,
 
 
 
4034
  int64_t ne0,
4035
  int64_t ne1,
4036
  uint r2,
@@ -4056,10 +4144,11 @@ void kernel_mul_mv_q6_K_f32_impl(
4056
  const uint i12 = im%ne12;
4057
  const uint i13 = im/ne12;
4058
 
4059
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
4060
 
4061
- device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
4062
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4063
 
4064
  float sumf = 0;
4065
 
@@ -4115,12 +4204,14 @@ kernel void kernel_mul_mv_q6_K_f32(
4115
  constant uint64_t & nb00,
4116
  constant uint64_t & nb01,
4117
  constant uint64_t & nb02,
 
4118
  constant int64_t & ne10,
4119
  constant int64_t & ne11,
4120
  constant int64_t & ne12,
4121
  constant uint64_t & nb10,
4122
  constant uint64_t & nb11,
4123
  constant uint64_t & nb12,
 
4124
  constant int64_t & ne0,
4125
  constant int64_t & ne1,
4126
  constant uint & r2,
@@ -4129,7 +4220,7 @@ kernel void kernel_mul_mv_q6_K_f32(
4129
  uint tiisg[[thread_index_in_simdgroup]],
4130
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4131
 
4132
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
4133
  }
4134
 
4135
  // ======================= "True" 2-bit
@@ -4141,8 +4232,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
4141
  int64_t ne00,
4142
  int64_t ne01,
4143
  int64_t ne02,
 
 
 
4144
  int64_t ne10,
4145
  int64_t ne12,
 
 
 
4146
  int64_t ne0,
4147
  int64_t ne1,
4148
  uint r2,
@@ -4158,15 +4255,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
4158
  const int im = tgpig.z;
4159
 
4160
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4161
- const int ib_row = first_row * nb;
4162
 
4163
  const uint i12 = im%ne12;
4164
  const uint i13 = im/ne12;
4165
 
4166
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
4167
 
4168
- device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
4169
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4170
 
4171
  float yl[32];
4172
  float sumf[N_DST]={0.f}, all_sum;
@@ -4219,8 +4316,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
4219
  }
4220
  sumf[row] += d * sum;
4221
 
4222
- dh += nb*sizeof(block_iq2_xxs)/2;
4223
- q2 += nb*sizeof(block_iq2_xxs)/2;
4224
  }
4225
 
4226
  y4 += 32 * 32;
@@ -4245,12 +4342,14 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
4245
  constant uint64_t & nb00,
4246
  constant uint64_t & nb01,
4247
  constant uint64_t & nb02,
 
4248
  constant int64_t & ne10,
4249
  constant int64_t & ne11,
4250
  constant int64_t & ne12,
4251
  constant uint64_t & nb10,
4252
  constant uint64_t & nb11,
4253
  constant uint64_t & nb12,
 
4254
  constant int64_t & ne0,
4255
  constant int64_t & ne1,
4256
  constant uint & r2,
@@ -4260,7 +4359,7 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
4260
  uint tiisg[[thread_index_in_simdgroup]],
4261
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4262
 
4263
- kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4264
  }
4265
 
4266
  void kernel_mul_mv_iq2_xs_f32_impl(
@@ -4270,8 +4369,14 @@ void kernel_mul_mv_iq2_xs_f32_impl(
4270
  int64_t ne00,
4271
  int64_t ne01,
4272
  int64_t ne02,
 
 
 
4273
  int64_t ne10,
4274
  int64_t ne12,
 
 
 
4275
  int64_t ne0,
4276
  int64_t ne1,
4277
  uint r2,
@@ -4287,15 +4392,15 @@ void kernel_mul_mv_iq2_xs_f32_impl(
4287
  const int im = tgpig.z;
4288
 
4289
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4290
- const int ib_row = first_row * nb;
4291
 
4292
  const uint i12 = im%ne12;
4293
  const uint i13 = im/ne12;
4294
 
4295
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
4296
 
4297
- device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
4298
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4299
 
4300
  float yl[32];
4301
  float sumf[N_DST]={0.f}, all_sum;
@@ -4357,9 +4462,9 @@ void kernel_mul_mv_iq2_xs_f32_impl(
4357
  }
4358
  sumf[row] += d1 * sum1 + d2 * sum2;
4359
 
4360
- dh += nb*sizeof(block_iq2_xs)/2;
4361
- q2 += nb*sizeof(block_iq2_xs)/2;
4362
- sc += nb*sizeof(block_iq2_xs);
4363
  }
4364
 
4365
  y4 += 32 * 32;
@@ -4384,12 +4489,14 @@ kernel void kernel_mul_mv_iq2_xs_f32(
4384
  constant uint64_t & nb00,
4385
  constant uint64_t & nb01,
4386
  constant uint64_t & nb02,
 
4387
  constant int64_t & ne10,
4388
  constant int64_t & ne11,
4389
  constant int64_t & ne12,
4390
  constant uint64_t & nb10,
4391
  constant uint64_t & nb11,
4392
  constant uint64_t & nb12,
 
4393
  constant int64_t & ne0,
4394
  constant int64_t & ne1,
4395
  constant uint & r2,
@@ -4399,7 +4506,7 @@ kernel void kernel_mul_mv_iq2_xs_f32(
4399
  uint tiisg[[thread_index_in_simdgroup]],
4400
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4401
 
4402
- kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4403
  }
4404
 
4405
  void kernel_mul_mv_iq3_xxs_f32_impl(
@@ -4409,8 +4516,14 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4409
  int64_t ne00,
4410
  int64_t ne01,
4411
  int64_t ne02,
 
 
 
4412
  int64_t ne10,
4413
  int64_t ne12,
 
 
 
4414
  int64_t ne0,
4415
  int64_t ne1,
4416
  uint r2,
@@ -4426,15 +4539,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4426
  const int im = tgpig.z;
4427
 
4428
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4429
- const int ib_row = first_row * nb;
4430
 
4431
  const uint i12 = im%ne12;
4432
  const uint i13 = im/ne12;
4433
 
4434
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
4435
 
4436
- device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
4437
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4438
 
4439
  float yl[32];
4440
  float sumf[N_DST]={0.f}, all_sum;
@@ -4489,9 +4602,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4489
  }
4490
  sumf[row] += d * (sum[0] + sum[1]);
4491
 
4492
- dh += nb*sizeof(block_iq3_xxs)/2;
4493
- q3 += nb*sizeof(block_iq3_xxs);
4494
- gas += nb*sizeof(block_iq3_xxs)/2;
4495
  }
4496
 
4497
  y4 += 32 * 32;
@@ -4516,12 +4629,14 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
4516
  constant uint64_t & nb00,
4517
  constant uint64_t & nb01,
4518
  constant uint64_t & nb02,
 
4519
  constant int64_t & ne10,
4520
  constant int64_t & ne11,
4521
  constant int64_t & ne12,
4522
  constant uint64_t & nb10,
4523
  constant uint64_t & nb11,
4524
  constant uint64_t & nb12,
 
4525
  constant int64_t & ne0,
4526
  constant int64_t & ne1,
4527
  constant uint & r2,
@@ -4531,7 +4646,7 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
4531
  uint tiisg[[thread_index_in_simdgroup]],
4532
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4533
 
4534
- kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4535
  }
4536
 
4537
  void kernel_mul_mv_iq3_s_f32_impl(
@@ -4541,8 +4656,14 @@ void kernel_mul_mv_iq3_s_f32_impl(
4541
  int64_t ne00,
4542
  int64_t ne01,
4543
  int64_t ne02,
 
 
 
4544
  int64_t ne10,
4545
  int64_t ne12,
 
 
 
4546
  int64_t ne0,
4547
  int64_t ne1,
4548
  uint r2,
@@ -4558,15 +4679,15 @@ void kernel_mul_mv_iq3_s_f32_impl(
4558
  const int im = tgpig.z;
4559
 
4560
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4561
- const int ib_row = first_row * nb;
4562
 
4563
  const uint i12 = im%ne12;
4564
  const uint i13 = im/ne12;
4565
 
4566
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
4567
 
4568
- device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
4569
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4570
 
4571
  float yl[32];
4572
  float sumf[N_DST]={0.f}, all_sum;
@@ -4619,11 +4740,11 @@ void kernel_mul_mv_iq3_s_f32_impl(
4619
  }
4620
  sumf[row] += d * (sum[0] + sum[1]);
4621
 
4622
- dh += nb*sizeof(block_iq3_s)/2;
4623
- qs += nb*sizeof(block_iq3_s);
4624
- qh += nb*sizeof(block_iq3_s);
4625
- sc += nb*sizeof(block_iq3_s);
4626
- signs += nb*sizeof(block_iq3_s);
4627
  }
4628
 
4629
  y4 += 32 * 32;
@@ -4648,12 +4769,14 @@ kernel void kernel_mul_mv_iq3_s_f32(
4648
  constant uint64_t & nb00,
4649
  constant uint64_t & nb01,
4650
  constant uint64_t & nb02,
 
4651
  constant int64_t & ne10,
4652
  constant int64_t & ne11,
4653
  constant int64_t & ne12,
4654
  constant uint64_t & nb10,
4655
  constant uint64_t & nb11,
4656
  constant uint64_t & nb12,
 
4657
  constant int64_t & ne0,
4658
  constant int64_t & ne1,
4659
  constant uint & r2,
@@ -4663,7 +4786,7 @@ kernel void kernel_mul_mv_iq3_s_f32(
4663
  uint tiisg[[thread_index_in_simdgroup]],
4664
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4665
 
4666
- kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4667
  }
4668
 
4669
  void kernel_mul_mv_iq2_s_f32_impl(
@@ -4673,8 +4796,14 @@ void kernel_mul_mv_iq2_s_f32_impl(
4673
  int64_t ne00,
4674
  int64_t ne01,
4675
  int64_t ne02,
 
 
 
4676
  int64_t ne10,
4677
  int64_t ne12,
 
 
 
4678
  int64_t ne0,
4679
  int64_t ne1,
4680
  uint r2,
@@ -4690,15 +4819,15 @@ void kernel_mul_mv_iq2_s_f32_impl(
4690
  const int im = tgpig.z;
4691
 
4692
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4693
- const int ib_row = first_row * nb;
4694
 
4695
  const uint i12 = im%ne12;
4696
  const uint i13 = im/ne12;
4697
 
4698
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
4699
 
4700
- device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
4701
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4702
 
4703
  float yl[32];
4704
  float sumf[N_DST]={0.f}, all_sum;
@@ -4752,11 +4881,11 @@ void kernel_mul_mv_iq2_s_f32_impl(
4752
  }
4753
  sumf[row] += d1 * sum[0] + d2 * sum[1];
4754
 
4755
- dh += nb*sizeof(block_iq2_s)/2;
4756
- qs += nb*sizeof(block_iq2_s);
4757
- qh += nb*sizeof(block_iq2_s);
4758
- sc += nb*sizeof(block_iq2_s);
4759
- signs += nb*sizeof(block_iq2_s);
4760
  }
4761
 
4762
  y4 += 32 * 32;
@@ -4781,12 +4910,14 @@ kernel void kernel_mul_mv_iq2_s_f32(
4781
  constant uint64_t & nb00,
4782
  constant uint64_t & nb01,
4783
  constant uint64_t & nb02,
 
4784
  constant int64_t & ne10,
4785
  constant int64_t & ne11,
4786
  constant int64_t & ne12,
4787
  constant uint64_t & nb10,
4788
  constant uint64_t & nb11,
4789
  constant uint64_t & nb12,
 
4790
  constant int64_t & ne0,
4791
  constant int64_t & ne1,
4792
  constant uint & r2,
@@ -4796,7 +4927,7 @@ kernel void kernel_mul_mv_iq2_s_f32(
4796
  uint tiisg[[thread_index_in_simdgroup]],
4797
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4798
 
4799
- kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4800
  }
4801
 
4802
  void kernel_mul_mv_iq1_s_f32_impl(
@@ -4806,8 +4937,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
4806
  int64_t ne00,
4807
  int64_t ne01,
4808
  int64_t ne02,
 
 
 
4809
  int64_t ne10,
4810
  int64_t ne12,
 
 
 
4811
  int64_t ne0,
4812
  int64_t ne1,
4813
  uint r2,
@@ -4823,14 +4960,15 @@ void kernel_mul_mv_iq1_s_f32_impl(
4823
  const int im = tgpig.z;
4824
 
4825
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4826
- const int ib_row = first_row * nb;
4827
 
4828
  const uint i12 = im%ne12;
4829
  const uint i13 = im/ne12;
4830
 
4831
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4832
- device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
4833
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
 
4834
 
4835
  float yl[32];
4836
  float sumf[N_DST]={0.f}, all_sum;
@@ -4873,9 +5011,9 @@ void kernel_mul_mv_iq1_s_f32_impl(
4873
  }
4874
  sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
4875
 
4876
- dh += nb*sizeof(block_iq1_s)/2;
4877
- qs += nb*sizeof(block_iq1_s);
4878
- qh += nb*sizeof(block_iq1_s)/2;
4879
  }
4880
 
4881
  y4 += 32 * 32;
@@ -4896,8 +5034,14 @@ void kernel_mul_mv_iq1_m_f32_impl(
4896
  int64_t ne00,
4897
  int64_t ne01,
4898
  int64_t ne02,
 
 
 
4899
  int64_t ne10,
4900
  int64_t ne12,
 
 
 
4901
  int64_t ne0,
4902
  int64_t ne1,
4903
  uint r2,
@@ -4913,14 +5057,15 @@ void kernel_mul_mv_iq1_m_f32_impl(
4913
  const int im = tgpig.z;
4914
 
4915
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4916
- const int ib_row = first_row * nb;
4917
 
4918
  const uint i12 = im%ne12;
4919
  const uint i13 = im/ne12;
4920
 
4921
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4922
- device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
4923
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
 
4924
 
4925
  float yl[32];
4926
  float sumf[N_DST]={0.f}, all_sum;
@@ -4972,9 +5117,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
4972
  sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
4973
  (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
4974
 
4975
- sc += nb*sizeof(block_iq1_m)/2;
4976
- qs += nb*sizeof(block_iq1_m);
4977
- qh += nb*sizeof(block_iq1_m);
4978
  }
4979
 
4980
  y4 += 32 * 32;
@@ -4995,8 +5140,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
4995
  int64_t ne00,
4996
  int64_t ne01,
4997
  int64_t ne02,
 
 
 
4998
  int64_t ne10,
4999
  int64_t ne12,
 
 
 
5000
  int64_t ne0,
5001
  int64_t ne1,
5002
  uint r2,
@@ -5012,14 +5163,15 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5012
  const int r1 = tgpig.y;
5013
  const int im = tgpig.z;
5014
  const int first_row = (r0 * 2 + sgitg) * 2;
5015
- const int ib_row = first_row * nb;
5016
 
5017
  const uint i12 = im%ne12;
5018
  const uint i13 = im/ne12;
5019
 
5020
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
5021
- device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
5022
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
 
5023
 
5024
  const int ix = tiisg/2; // 0...15
5025
  const int it = tiisg%2; // 0 or 1
@@ -5089,8 +5241,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5089
  int64_t ne00,
5090
  int64_t ne01,
5091
  int64_t ne02,
 
 
 
5092
  int64_t ne10,
5093
  int64_t ne12,
 
 
 
5094
  int64_t ne0,
5095
  int64_t ne1,
5096
  uint r2,
@@ -5106,14 +5264,15 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5106
  const int r1 = tgpig.y;
5107
  const int im = tgpig.z;
5108
  const int first_row = (r0 * 2 + sgitg) * 2;
5109
- const int ib_row = first_row * nb;
5110
 
5111
  const uint i12 = im%ne12;
5112
  const uint i13 = im/ne12;
5113
 
5114
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
5115
- device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
5116
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
 
5117
 
5118
  const int ix = tiisg/16; // 0 or 1
5119
  const int it = tiisg%16; // 0...15
@@ -5188,12 +5347,14 @@ kernel void kernel_mul_mv_iq1_s_f32(
5188
  constant uint64_t & nb00,
5189
  constant uint64_t & nb01,
5190
  constant uint64_t & nb02,
 
5191
  constant int64_t & ne10,
5192
  constant int64_t & ne11,
5193
  constant int64_t & ne12,
5194
  constant uint64_t & nb10,
5195
  constant uint64_t & nb11,
5196
  constant uint64_t & nb12,
 
5197
  constant int64_t & ne0,
5198
  constant int64_t & ne1,
5199
  constant uint & r2,
@@ -5202,7 +5363,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
5202
  uint tiisg[[thread_index_in_simdgroup]],
5203
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5204
 
5205
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
5206
  }
5207
 
5208
  [[host_name("kernel_mul_mv_iq1_m_f32")]]
@@ -5216,12 +5377,14 @@ kernel void kernel_mul_mv_iq1_m_f32(
5216
  constant uint64_t & nb00,
5217
  constant uint64_t & nb01,
5218
  constant uint64_t & nb02,
 
5219
  constant int64_t & ne10,
5220
  constant int64_t & ne11,
5221
  constant int64_t & ne12,
5222
  constant uint64_t & nb10,
5223
  constant uint64_t & nb11,
5224
  constant uint64_t & nb12,
 
5225
  constant int64_t & ne0,
5226
  constant int64_t & ne1,
5227
  constant uint & r2,
@@ -5230,7 +5393,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
5230
  uint tiisg[[thread_index_in_simdgroup]],
5231
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5232
 
5233
- kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
5234
  }
5235
 
5236
  [[host_name("kernel_mul_mv_iq4_nl_f32")]]
@@ -5244,12 +5407,14 @@ kernel void kernel_mul_mv_iq4_nl_f32(
5244
  constant uint64_t & nb00,
5245
  constant uint64_t & nb01,
5246
  constant uint64_t & nb02,
 
5247
  constant int64_t & ne10,
5248
  constant int64_t & ne11,
5249
  constant int64_t & ne12,
5250
  constant uint64_t & nb10,
5251
  constant uint64_t & nb11,
5252
  constant uint64_t & nb12,
 
5253
  constant int64_t & ne0,
5254
  constant int64_t & ne1,
5255
  constant uint & r2,
@@ -5259,7 +5424,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
5259
  uint tiisg[[thread_index_in_simdgroup]],
5260
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5261
 
5262
- kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5263
  }
5264
 
5265
  [[host_name("kernel_mul_mv_iq4_xs_f32")]]
@@ -5273,12 +5438,14 @@ kernel void kernel_mul_mv_iq4_xs_f32(
5273
  constant uint64_t & nb00,
5274
  constant uint64_t & nb01,
5275
  constant uint64_t & nb02,
 
5276
  constant int64_t & ne10,
5277
  constant int64_t & ne11,
5278
  constant int64_t & ne12,
5279
  constant uint64_t & nb10,
5280
  constant uint64_t & nb11,
5281
  constant uint64_t & nb12,
 
5282
  constant int64_t & ne0,
5283
  constant int64_t & ne1,
5284
  constant uint & r2,
@@ -5288,7 +5455,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
5288
  uint tiisg[[thread_index_in_simdgroup]],
5289
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5290
 
5291
- kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5292
  }
5293
 
5294
  //============================= templates and their specializations =============================
@@ -5833,10 +6000,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
5833
  constant int64_t & ne02,
5834
  constant uint64_t & nb01,
5835
  constant uint64_t & nb02,
 
5836
  constant int64_t & ne12,
5837
  constant uint64_t & nb10,
5838
  constant uint64_t & nb11,
5839
  constant uint64_t & nb12,
 
5840
  constant int64_t & ne0,
5841
  constant int64_t & ne1,
5842
  constant uint & r2,
@@ -5873,12 +6042,13 @@ kernel void kernel_mul_mm(device const uchar * src0,
5873
  const uint i12 = im%ne12;
5874
  const uint i13 = im/ne12;
5875
 
5876
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
5877
  ushort offset1 = il/nl;
5878
 
5879
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
5880
  device const float * y = (device const float *)(src1
5881
- + nb12 * im
 
5882
  + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
5883
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
5884
 
@@ -6257,12 +6427,14 @@ typedef void (kernel_mul_mv_impl_t)(
6257
  uint64_t nb00,
6258
  uint64_t nb01,
6259
  uint64_t nb02,
 
6260
  int64_t ne10,
6261
  int64_t ne11,
6262
  int64_t ne12,
6263
  uint64_t nb10,
6264
  uint64_t nb11,
6265
  uint64_t nb12,
 
6266
  int64_t ne0,
6267
  int64_t ne1,
6268
  uint r2,
@@ -6277,8 +6449,14 @@ typedef void (kernel_mul_mv2_impl_t)(
6277
  int64_t ne00,
6278
  int64_t ne01,
6279
  int64_t ne02,
 
 
 
6280
  int64_t ne10,
6281
  int64_t ne12,
 
 
 
6282
  int64_t ne0,
6283
  int64_t ne1,
6284
  uint r2,
@@ -6299,6 +6477,7 @@ void mmv_fn(
6299
  uint64_t nb00,
6300
  uint64_t nb01,
6301
  uint64_t nb02,
 
6302
  int64_t ne10,
6303
  int64_t ne11,
6304
  int64_t ne12,
@@ -6306,6 +6485,7 @@ void mmv_fn(
6306
  uint64_t nb10,
6307
  uint64_t nb11,
6308
  uint64_t nb12,
 
6309
  int64_t ne0,
6310
  int64_t ne1,
6311
  uint64_t nb1,
@@ -6316,7 +6496,7 @@ void mmv_fn(
6316
  uint tiitg,
6317
  uint tiisg,
6318
  uint sgitg) {
6319
- impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
6320
  }
6321
 
6322
  template<kernel_mul_mv2_impl_t impl_fn>
@@ -6330,6 +6510,7 @@ void mmv_fn(
6330
  uint64_t nb00,
6331
  uint64_t nb01,
6332
  uint64_t nb02,
 
6333
  int64_t ne10,
6334
  int64_t ne11,
6335
  int64_t ne12,
@@ -6337,6 +6518,7 @@ void mmv_fn(
6337
  uint64_t nb10,
6338
  uint64_t nb11,
6339
  uint64_t nb12,
 
6340
  int64_t ne0,
6341
  int64_t ne1,
6342
  uint64_t nb1,
@@ -6347,7 +6529,7 @@ void mmv_fn(
6347
  uint tiitg,
6348
  uint tiisg,
6349
  uint sgitg) {
6350
- impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
6351
  }
6352
 
6353
  typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
@@ -6396,8 +6578,8 @@ kernel void kernel_mul_mv_id(
6396
  const int64_t i2 = i12;
6397
 
6398
  device const char * src0_cur = src0s + i02*nb02;
6399
- device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
6400
- device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
6401
 
6402
  impl_fn(
6403
  /* src0 */ src0_cur,
@@ -6405,19 +6587,21 @@ kernel void kernel_mul_mv_id(
6405
  /* dst */ dst_cur,
6406
  /* ne00 */ ne00,
6407
  /* ne01 */ ne01,
6408
- /* ne02 */ 1,//ne02,
6409
  /* nb00 */ nb00,
6410
  /* nb01 */ nb01,
6411
  /* nb02 */ nb02,
 
6412
  /* ne10 */ ne10,
6413
- /* ne11 */ 1,//ne11,
6414
- /* ne12 */ 1,//ne12,
6415
- /* ne13 */ 1,//ne13,
6416
  /* nb10 */ nb10,
6417
  /* nb11 */ nb11,
6418
  /* nb12 */ nb12,
 
6419
  /* ne0 */ ne0,
6420
- /* ne1 */ 1,//ne1,
6421
  /* nb1 */ nb1,
6422
  /* r2 */ 1,
6423
  /* r3 */ 1,
 
777
  const int64_t i3 = tgpig.z;
778
 
779
  const int64_t nc = ne10;
780
+ //const int64_t ncs = ne00;
781
+ //const int64_t nr = ne01;
782
+ //const int64_t n_t = ne1;
783
+ //const int64_t n_s = ne2;
784
 
785
  device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
786
  device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
 
834
  const int64_t i3 = tgpig.y;
835
 
836
  const int64_t nc = d_state;
837
+ //const int64_t nr = d_inner;
838
  const int64_t n_t = n_seq_tokens;
839
+ //const int64_t n_s = n_seqs;
840
 
841
  for (int64_t i2 = 0; i2 < n_t; ++i2) {
842
  device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
 
1064
  inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
1065
  float d = qb_curr->d;
1066
 
1067
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
1068
 
1069
+ device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
1070
 
1071
+ for (int i = 0; i < 8; i += 2) {
1072
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
1073
+ acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
1074
+ acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
1075
+ acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
1076
  }
1077
+
1078
+ return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
1079
  }
1080
 
1081
  // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
 
1086
  float d = qb_curr->d;
1087
  float m = qb_curr->m;
1088
 
1089
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
1090
 
1091
+ device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
1092
 
1093
  for (int i = 0; i < 8; i+=2) {
1094
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
1095
+ acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
1096
+ acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
1097
+ acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
1098
  }
1099
+
1100
+ return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
1101
  }
1102
 
1103
  // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
 
1107
  inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
1108
  float d = qb_curr->d;
1109
 
1110
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
1111
 
1112
  device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
1113
  const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
1114
 
1115
  for (int i = 0; i < 8; i+=2) {
1116
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
1117
+ acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
1118
+ acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
1119
+ acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
1120
  }
1121
+
1122
+ return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
1123
  }
1124
 
1125
  // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
 
1130
  float d = qb_curr->d;
1131
  float m = qb_curr->m;
1132
 
1133
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
1134
 
1135
  device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
1136
  const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
1137
 
1138
  for (int i = 0; i < 8; i+=2) {
1139
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
1140
+ acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
1141
+ acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
1142
+ acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
1143
  }
1144
+
1145
+ return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
1146
  }
1147
 
1148
  // putting them in the kernel cause a significant performance penalty
 
1160
  int64_t ne00,
1161
  int64_t ne01,
1162
  int64_t ne02,
1163
+ uint64_t nb01,
1164
+ uint64_t nb02,
1165
+ uint64_t nb03,
1166
  int64_t ne10,
1167
  int64_t ne12,
1168
+ uint64_t nb11,
1169
+ uint64_t nb12,
1170
+ uint64_t nb13,
1171
  int64_t ne0,
1172
  int64_t ne1,
1173
  uint r2,
1174
  uint r3,
1175
  threadgroup int8_t * shared_values,
1176
+ uint3 tgpig,
1177
+ uint tiisg,
1178
+ uint sgitg) {
1179
  const int nb = ne00/QK4_0;
1180
 
1181
  const int r0 = tgpig.x;
 
1187
  const uint i12 = im%ne12;
1188
  const uint i13 = im/ne12;
1189
 
1190
+ //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1191
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1192
 
1193
+ //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0);
1194
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
1195
+
1196
+ // pointers to src0 rows
1197
+ device const block_q_type * ax[nr];
1198
+ for (int row = 0; row < nr; ++row) {
1199
+ const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1200
+
1201
+ ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
1202
+ }
1203
 
1204
  float yl[16]; // src1 vector cache
1205
  float sumf[nr] = {0.f};
 
1211
 
1212
  // each thread in a SIMD group deals with half a block.
1213
  for (int ib = ix; ib < nb; ib += nw/2) {
1214
+ float sumy[2] = { 0.f, 0.f };
1215
+
1216
+ #pragma unroll
1217
  for (int i = 0; i < 8; i += 2) {
1218
+ sumy[0] += yb[i + 0] + yb[i + 1];
1219
+ yl[i + 0] = yb[i + 0];
1220
+ yl[i + 1] = yb[i + 1]/256.f;
1221
 
1222
+ sumy[1] += yb[i + 16] + yb[i + 17];
1223
+ yl[i + 8] = yb[i + 16]/16.f;
1224
+ yl[i + 9] = yb[i + 17]/4096.f;
1225
  }
1226
 
1227
+ #pragma unroll
1228
  for (int row = 0; row < nr; row++) {
1229
+ sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
1230
  }
1231
 
1232
  yb += QK4_0 * 16;
 
1250
  constant uint64_t & nb00,
1251
  constant uint64_t & nb01,
1252
  constant uint64_t & nb02,
1253
+ constant uint64_t & nb03,
1254
  constant int64_t & ne10,
1255
  constant int64_t & ne11,
1256
  constant int64_t & ne12,
1257
  constant uint64_t & nb10,
1258
  constant uint64_t & nb11,
1259
  constant uint64_t & nb12,
1260
+ constant uint64_t & nb13,
1261
  constant int64_t & ne0,
1262
  constant int64_t & ne1,
1263
  constant uint & r2,
 
1265
  uint3 tgpig[[threadgroup_position_in_grid]],
1266
  uint tiisg[[thread_index_in_simdgroup]],
1267
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1268
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1269
  }
1270
 
1271
  kernel void kernel_mul_mv_q4_1_f32(
 
1278
  constant uint64_t & nb00,
1279
  constant uint64_t & nb01,
1280
  constant uint64_t & nb02,
1281
+ constant uint64_t & nb03,
1282
  constant int64_t & ne10,
1283
  constant int64_t & ne11,
1284
  constant int64_t & ne12,
1285
  constant uint64_t & nb10,
1286
  constant uint64_t & nb11,
1287
  constant uint64_t & nb12,
1288
+ constant uint64_t & nb13,
1289
  constant int64_t & ne0,
1290
  constant int64_t & ne1,
1291
  constant uint & r2,
 
1293
  uint3 tgpig[[threadgroup_position_in_grid]],
1294
  uint tiisg[[thread_index_in_simdgroup]],
1295
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1296
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1297
  }
1298
 
1299
  kernel void kernel_mul_mv_q5_0_f32(
 
1306
  constant uint64_t & nb00,
1307
  constant uint64_t & nb01,
1308
  constant uint64_t & nb02,
1309
+ constant uint64_t & nb03,
1310
  constant int64_t & ne10,
1311
  constant int64_t & ne11,
1312
  constant int64_t & ne12,
1313
  constant uint64_t & nb10,
1314
  constant uint64_t & nb11,
1315
  constant uint64_t & nb12,
1316
+ constant uint64_t & nb13,
1317
  constant int64_t & ne0,
1318
  constant int64_t & ne1,
1319
  constant uint & r2,
 
1321
  uint3 tgpig[[threadgroup_position_in_grid]],
1322
  uint tiisg[[thread_index_in_simdgroup]],
1323
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1324
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1325
  }
1326
 
1327
  kernel void kernel_mul_mv_q5_1_f32(
 
1334
  constant uint64_t & nb00,
1335
  constant uint64_t & nb01,
1336
  constant uint64_t & nb02,
1337
+ constant uint64_t & nb03,
1338
  constant int64_t & ne10,
1339
  constant int64_t & ne11,
1340
  constant int64_t & ne12,
1341
  constant uint64_t & nb10,
1342
  constant uint64_t & nb11,
1343
  constant uint64_t & nb12,
1344
+ constant uint64_t & nb13,
1345
  constant int64_t & ne0,
1346
  constant int64_t & ne1,
1347
  constant uint & r2,
 
1349
  uint3 tgpig[[threadgroup_position_in_grid]],
1350
  uint tiisg[[thread_index_in_simdgroup]],
1351
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1352
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1353
  }
1354
 
1355
 
 
1362
  int64_t ne00,
1363
  int64_t ne01,
1364
  int64_t ne02,
1365
+ uint64_t nb01,
1366
+ uint64_t nb02,
1367
+ uint64_t nb03,
1368
  int64_t ne10,
1369
  int64_t ne12,
1370
+ uint64_t nb11,
1371
+ uint64_t nb12,
1372
+ uint64_t nb13,
1373
  int64_t ne0,
1374
  int64_t ne1,
1375
  uint r2,
 
1392
  const uint i12 = im%ne12;
1393
  const uint i13 = im/ne12;
1394
 
1395
+ //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1396
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1397
 
1398
+ //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0);
1399
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
1400
+
1401
+ // pointers to src0 rows
1402
+ device const block_q8_0 * ax[nr];
1403
+ for (int row = 0; row < nr; ++row) {
1404
+ const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1405
+
1406
+ ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
1407
+ }
1408
 
1409
  float yl[NB_Q8_0];
1410
  float sumf[nr]={0.f};
 
1421
  }
1422
 
1423
  for (int row = 0; row < nr; row++) {
1424
+ device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il;
1425
  float sumq = 0.f;
1426
  for (int iq = 0; iq < NB_Q8_0; ++iq) {
1427
  sumq += qs[iq] * yl[iq];
1428
  }
1429
+ sumf[row] += sumq*ax[row][ib].d;
1430
  }
1431
 
1432
  yb += NB_Q8_0 * nw;
 
1451
  constant uint64_t & nb00,
1452
  constant uint64_t & nb01,
1453
  constant uint64_t & nb02,
1454
+ constant uint64_t & nb03,
1455
  constant int64_t & ne10,
1456
  constant int64_t & ne11,
1457
  constant int64_t & ne12,
1458
  constant uint64_t & nb10,
1459
  constant uint64_t & nb11,
1460
  constant uint64_t & nb12,
1461
+ constant uint64_t & nb13,
1462
  constant int64_t & ne0,
1463
  constant int64_t & ne1,
1464
  constant uint & r2,
 
1466
  uint3 tgpig[[threadgroup_position_in_grid]],
1467
  uint tiisg[[thread_index_in_simdgroup]],
1468
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1469
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
1470
  }
1471
 
1472
  #define N_MV_T_T 4
 
1482
  uint64_t nb00,
1483
  uint64_t nb01,
1484
  uint64_t nb02,
1485
+ uint64_t nb03,
1486
  int64_t ne10,
1487
  int64_t ne11,
1488
  int64_t ne12,
1489
  uint64_t nb10,
1490
  uint64_t nb11,
1491
  uint64_t nb12,
1492
+ uint64_t nb13,
1493
  int64_t ne0,
1494
  int64_t ne1,
1495
  uint r2,
 
1503
  const uint i12 = im%ne12;
1504
  const uint i13 = im/ne12;
1505
 
1506
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1507
 
1508
  device const T0 * x = (device const T0 *) (src0 + offset0);
1509
 
 
1514
  break;
1515
  }
1516
 
1517
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1518
+
1519
+ device const T1 * y = (device const T1 *) (src1 + offset1);
1520
 
1521
  float sumf = 0;
1522
  for (int i = tiisg; i < ne00; i += 32) {
 
1536
  break;
1537
  }
1538
 
1539
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1540
+
1541
+ device const T1 * y = (device const T1 *) (src1 + offset1);
1542
  device const T14 * y4 = (device const T14 *) y;
1543
 
1544
  float sumf = 0;
 
1566
  constant uint64_t & nb00,
1567
  constant uint64_t & nb01,
1568
  constant uint64_t & nb02,
1569
+ constant uint64_t & nb03,
1570
  constant int64_t & ne10,
1571
  constant int64_t & ne11,
1572
  constant int64_t & ne12,
1573
  constant uint64_t & nb10,
1574
  constant uint64_t & nb11,
1575
  constant uint64_t & nb12,
1576
+ constant uint64_t & nb13,
1577
  constant int64_t & ne0,
1578
  constant int64_t & ne1,
1579
  constant uint & r2,
 
1590
  nb00,
1591
  nb01,
1592
  nb02,
1593
+ nb03,
1594
  ne10,
1595
  ne11,
1596
  ne12,
1597
  nb10,
1598
  nb11,
1599
  nb12,
1600
+ nb13,
1601
  ne0,
1602
  ne1,
1603
  r2,
 
1623
  constant uint64_t & nb00,
1624
  constant uint64_t & nb01,
1625
  constant uint64_t & nb02,
1626
+ constant uint64_t & nb03,
1627
  constant int64_t & ne10,
1628
  constant int64_t & ne11,
1629
  constant int64_t & ne12,
1630
  constant uint64_t & nb10,
1631
  constant uint64_t & nb11,
1632
  constant uint64_t & nb12,
1633
+ constant uint64_t & nb13,
1634
  constant int64_t & ne0,
1635
  constant int64_t & ne1,
1636
  constant uint & r2,
 
1645
  const uint i12 = im%ne12;
1646
  const uint i13 = im/ne12;
1647
 
1648
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1649
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1650
 
1651
  device const T * x = (device const T *) (src0 + offset0);
1652
+ device const float * y = (device const float *) (src1 + offset1);
1653
 
1654
  float sumf = 0;
1655
  if (ne00 < 128) {
 
1693
  constant uint64_t & nb00,
1694
  constant uint64_t & nb01,
1695
  constant uint64_t & nb02,
1696
+ constant uint64_t & nb03,
1697
  constant int64_t & ne10,
1698
  constant int64_t & ne11,
1699
  constant int64_t & ne12,
1700
  constant uint64_t & nb10,
1701
  constant uint64_t & nb11,
1702
  constant uint64_t & nb12,
1703
+ constant uint64_t & nb13,
1704
  constant int64_t & ne0,
1705
  constant int64_t & ne1,
1706
  constant uint & r2,
 
1715
  const uint i12 = im%ne12;
1716
  const uint i13 = im/ne12;
1717
 
1718
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
1719
 
1720
  device const T4 * x4 = (device const T4 *) (src0 + offset0);
1721
 
1722
  for (int r1 = 0; r1 < nrows; ++r1) {
1723
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
1724
+
1725
+ device const float4 * y4 = (device const float4 *) (src1 + offset1);
1726
 
1727
  float sumf = 0;
1728
  for (int i = tiisg; i < ne00/4; i += 32) {
 
3482
  int64_t ne00,
3483
  int64_t ne01,
3484
  int64_t ne02,
3485
+ uint64_t nb01,
3486
+ uint64_t nb02,
3487
+ uint64_t nb03,
3488
  int64_t ne10,
3489
  int64_t ne12,
3490
+ uint64_t nb11,
3491
+ uint64_t nb12,
3492
+ uint64_t nb13,
3493
  int64_t ne0,
3494
  int64_t ne1,
3495
  uint r2,
 
3505
  const int im = tgpig.z;
3506
 
3507
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
3508
 
3509
  const uint i12 = im%ne12;
3510
  const uint i13 = im/ne12;
3511
 
3512
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
3513
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
3514
 
3515
+ device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0);
3516
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
3517
 
3518
  float yl[32];
3519
  float sumf[N_DST]={0.f}, all_sum;
3520
 
 
 
3521
  const int ix = tiisg/8; // 0...3
3522
  const int it = tiisg%8; // 0...7
3523
  const int iq = it/4; // 0 or 1
 
3562
  (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
3563
  dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
3564
 
3565
+ qs += nb01/2;
3566
+ sc += nb01;
3567
+ dh += nb01/2;
3568
  }
3569
 
3570
  y4 += 4 * QK_K;
 
3589
  constant uint64_t & nb00,
3590
  constant uint64_t & nb01,
3591
  constant uint64_t & nb02,
3592
+ constant uint64_t & nb03,
3593
  constant int64_t & ne10,
3594
  constant int64_t & ne11,
3595
  constant int64_t & ne12,
3596
  constant uint64_t & nb10,
3597
  constant uint64_t & nb11,
3598
  constant uint64_t & nb12,
3599
+ constant uint64_t & nb13,
3600
  constant int64_t & ne0,
3601
  constant int64_t & ne1,
3602
  constant uint & r2,
 
3605
  uint tiisg[[thread_index_in_simdgroup]],
3606
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3607
 
3608
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3609
  }
3610
 
3611
  void kernel_mul_mv_q3_K_f32_impl(
 
3615
  int64_t ne00,
3616
  int64_t ne01,
3617
  int64_t ne02,
3618
+ uint64_t nb01,
3619
+ uint64_t nb02,
3620
+ uint64_t nb03,
3621
  int64_t ne10,
3622
  int64_t ne12,
3623
+ uint64_t nb11,
3624
+ uint64_t nb12,
3625
+ uint64_t nb13,
3626
  int64_t ne0,
3627
  int64_t ne1,
3628
  uint r2,
 
3643
  const uint i12 = im%ne12;
3644
  const uint i13 = im/ne12;
3645
 
3646
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
3647
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
3648
 
3649
+ device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0);
3650
+ device const float * yy = (device const float *) ((device char *) src1 + offset1);
3651
 
3652
  float yl[32];
3653
 
 
3687
  const int q_offset = 32*ip + l0;
3688
  const int y_offset = 128*ip + 32*il + l0;
3689
 
 
 
3690
  device const float * y1 = yy + ix*QK_K + y_offset;
3691
 
3692
  uint32_t scales32, aux32;
 
3696
  float sumf1[2] = {0.f};
3697
  float sumf2[2] = {0.f};
3698
  for (int i = ix; i < nb; i += 4) {
 
3699
  for (int l = 0; l < 8; ++l) {
3700
  yl[l+ 0] = y1[l+ 0];
3701
  yl[l+ 8] = y1[l+16];
 
3709
  device const half * dh = &x[i].d;
3710
 
3711
  for (int row = 0; row < 2; ++row) {
 
3712
  const float d_all = (float)dh[0];
3713
 
3714
  scales16[0] = a[4];
 
3748
  sumf1[row] += d1 * (scales[1] - 32);
3749
  sumf2[row] += d2 * (scales[3] - 32);
3750
 
3751
+ q += nb01/2;
3752
+ h += nb01/2;
3753
+ a += nb01/2;
3754
+ dh += nb01/2;
 
3755
  }
3756
 
3757
  y1 += 4 * QK_K;
 
3758
  }
3759
 
3760
  for (int row = 0; row < 2; ++row) {
 
3779
  constant uint64_t & nb00,
3780
  constant uint64_t & nb01,
3781
  constant uint64_t & nb02,
3782
+ constant uint64_t & nb03,
3783
  constant int64_t & ne10,
3784
  constant int64_t & ne11,
3785
  constant int64_t & ne12,
3786
  constant uint64_t & nb10,
3787
  constant uint64_t & nb11,
3788
  constant uint64_t & nb12,
3789
+ constant uint64_t & nb13,
3790
  constant int64_t & ne0,
3791
  constant int64_t & ne1,
3792
  constant uint & r2,
 
3795
  uint tiisg[[thread_index_in_simdgroup]],
3796
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3797
 
3798
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3799
  }
3800
 
3801
  void kernel_mul_mv_q4_K_f32_impl(
 
3805
  int64_t ne00,
3806
  int64_t ne01,
3807
  int64_t ne02,
3808
+ uint64_t nb01,
3809
+ uint64_t nb02,
3810
+ uint64_t nb03,
3811
  int64_t ne10,
3812
  int64_t ne12,
3813
+ uint64_t nb11,
3814
+ uint64_t nb12,
3815
+ uint64_t nb13,
3816
  int64_t ne0,
3817
  int64_t ne1,
3818
  uint r2,
 
3837
  const int im = tgpig.z;
3838
  //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
3839
  const int first_row = r0 * N_DST;
 
3840
 
3841
  const uint i12 = im%ne12;
3842
  const uint i13 = im/ne12;
3843
 
3844
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
3845
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
3846
 
3847
+ device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0);
3848
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
3849
 
3850
  float yl[16];
3851
  float yh[16];
3852
  float sumf[N_DST]={0.f}, all_sum;
3853
 
 
 
3854
  device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
3855
 
3856
  uint16_t sc16[4];
3857
  thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
3858
 
3859
  for (int ib = ix; ib < nb; ib += 4) {
 
3860
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
3861
  for (int i = 0; i < 8; ++i) {
3862
  yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
 
3870
  device const half * dh = &x[ib].d;
3871
 
3872
  for (int row = 0; row < N_DST; row++) {
 
3873
  sc16[0] = sc[0] & kmask1;
3874
  sc16[1] = sc[2] & kmask1;
3875
  sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
 
3898
  (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
3899
  dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
3900
 
3901
+ q1 += nb01/2;
3902
+ sc += nb01/2;
3903
+ dh += nb01/2;
3904
  }
3905
 
3906
  y4 += 4 * QK_K;
 
3925
  constant uint64_t & nb00,
3926
  constant uint64_t & nb01,
3927
  constant uint64_t & nb02,
3928
+ constant uint64_t & nb03,
3929
  constant int64_t & ne10,
3930
  constant int64_t & ne11,
3931
  constant int64_t & ne12,
3932
  constant uint64_t & nb10,
3933
  constant uint64_t & nb11,
3934
  constant uint64_t & nb12,
3935
+ constant uint64_t & nb13,
3936
  constant int64_t & ne0,
3937
  constant int64_t & ne1,
3938
  constant uint & r2,
 
3941
  uint tiisg[[thread_index_in_simdgroup]],
3942
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
3943
 
3944
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3945
  }
3946
 
3947
  void kernel_mul_mv_q5_K_f32_impl(
 
3951
  int64_t ne00,
3952
  int64_t ne01,
3953
  int64_t ne02,
3954
+ uint64_t nb01,
3955
+ uint64_t nb02,
3956
+ uint64_t nb03,
3957
  int64_t ne10,
3958
  int64_t ne12,
3959
+ uint64_t nb11,
3960
+ uint64_t nb12,
3961
+ uint64_t nb13,
3962
  int64_t ne0,
3963
  int64_t ne1,
3964
  uint r2,
 
3979
  const uint i12 = im%ne12;
3980
  const uint i13 = im/ne12;
3981
 
3982
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
3983
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
3984
 
3985
+ device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0);
3986
+ device const float * yy = (device const float *) ((device char *) src1 + offset1);
3987
 
3988
  float sumf[2]={0.f};
3989
 
 
 
3990
  float yl[16], yh[16];
3991
 
3992
  const uint16_t kmask1 = 0x3f3f;
 
4014
  device const float * y1 = yy + ix*QK_K + y_offset;
4015
 
4016
  for (int i = ix; i < nb; i += 4) {
 
4017
  device const uint8_t * q1 = x[i].qs + q_offset;
4018
  device const uint8_t * qh = x[i].qh + l0;
4019
  device const half * dh = &x[i].d;
 
4029
  }
4030
 
4031
  for (int row = 0; row < 2; ++row) {
 
4032
  device const uint8_t * q2 = q1 + 64;
4033
 
4034
  sc16[0] = a[0] & kmask1;
 
4057
  sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
4058
  dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
4059
 
4060
+ q1 += nb01;
4061
+ qh += nb01;
4062
+ dh += nb01/2;
4063
+ a += nb01/2;
 
4064
  }
4065
 
4066
  y1 += 4 * QK_K;
 
4067
  }
4068
 
4069
  for (int row = 0; row < 2; ++row) {
 
4085
  constant uint64_t & nb00,
4086
  constant uint64_t & nb01,
4087
  constant uint64_t & nb02,
4088
+ constant uint64_t & nb03,
4089
  constant int64_t & ne10,
4090
  constant int64_t & ne11,
4091
  constant int64_t & ne12,
4092
  constant uint64_t & nb10,
4093
  constant uint64_t & nb11,
4094
  constant uint64_t & nb12,
4095
+ constant uint64_t & nb13,
4096
  constant int64_t & ne0,
4097
  constant int64_t & ne1,
4098
  constant uint & r2,
 
4101
  uint tiisg[[thread_index_in_simdgroup]],
4102
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4103
 
4104
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
4105
  }
4106
 
4107
  void kernel_mul_mv_q6_K_f32_impl(
 
4111
  int64_t ne00,
4112
  int64_t ne01,
4113
  int64_t ne02,
4114
+ uint64_t nb01,
4115
+ uint64_t nb02,
4116
+ uint64_t nb03,
4117
  int64_t ne10,
4118
  int64_t ne12,
4119
+ uint64_t nb11,
4120
+ uint64_t nb12,
4121
+ uint64_t nb13,
4122
  int64_t ne0,
4123
  int64_t ne1,
4124
  uint r2,
 
4144
  const uint i12 = im%ne12;
4145
  const uint i13 = im/ne12;
4146
 
4147
+ const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
4148
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
4149
 
4150
+ device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0);
4151
+ device const float * yy = (device const float *) ((device char *) src1 + offset1);
4152
 
4153
  float sumf = 0;
4154
 
 
4204
  constant uint64_t & nb00,
4205
  constant uint64_t & nb01,
4206
  constant uint64_t & nb02,
4207
+ constant uint64_t & nb03,
4208
  constant int64_t & ne10,
4209
  constant int64_t & ne11,
4210
  constant int64_t & ne12,
4211
  constant uint64_t & nb10,
4212
  constant uint64_t & nb11,
4213
  constant uint64_t & nb12,
4214
+ constant uint64_t & nb13,
4215
  constant int64_t & ne0,
4216
  constant int64_t & ne1,
4217
  constant uint & r2,
 
4220
  uint tiisg[[thread_index_in_simdgroup]],
4221
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4222
 
4223
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
4224
  }
4225
 
4226
  // ======================= "True" 2-bit
 
4232
  int64_t ne00,
4233
  int64_t ne01,
4234
  int64_t ne02,
4235
+ uint64_t nb01,
4236
+ uint64_t nb02,
4237
+ uint64_t nb03,
4238
  int64_t ne10,
4239
  int64_t ne12,
4240
+ uint64_t nb11,
4241
+ uint64_t nb12,
4242
+ uint64_t nb13,
4243
  int64_t ne0,
4244
  int64_t ne1,
4245
  uint r2,
 
4255
  const int im = tgpig.z;
4256
 
4257
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
4258
 
4259
  const uint i12 = im%ne12;
4260
  const uint i13 = im/ne12;
4261
 
4262
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
4263
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
4264
 
4265
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0);
4266
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
4267
 
4268
  float yl[32];
4269
  float sumf[N_DST]={0.f}, all_sum;
 
4316
  }
4317
  sumf[row] += d * sum;
4318
 
4319
+ dh += nb01/2;
4320
+ q2 += nb01/2;
4321
  }
4322
 
4323
  y4 += 32 * 32;
 
4342
  constant uint64_t & nb00,
4343
  constant uint64_t & nb01,
4344
  constant uint64_t & nb02,
4345
+ constant uint64_t & nb03,
4346
  constant int64_t & ne10,
4347
  constant int64_t & ne11,
4348
  constant int64_t & ne12,
4349
  constant uint64_t & nb10,
4350
  constant uint64_t & nb11,
4351
  constant uint64_t & nb12,
4352
+ constant uint64_t & nb13,
4353
  constant int64_t & ne0,
4354
  constant int64_t & ne1,
4355
  constant uint & r2,
 
4359
  uint tiisg[[thread_index_in_simdgroup]],
4360
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4361
 
4362
+ kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4363
  }
4364
 
4365
  void kernel_mul_mv_iq2_xs_f32_impl(
 
4369
  int64_t ne00,
4370
  int64_t ne01,
4371
  int64_t ne02,
4372
+ uint64_t nb01,
4373
+ uint64_t nb02,
4374
+ uint64_t nb03,
4375
  int64_t ne10,
4376
  int64_t ne12,
4377
+ uint64_t nb11,
4378
+ uint64_t nb12,
4379
+ uint64_t nb13,
4380
  int64_t ne0,
4381
  int64_t ne1,
4382
  uint r2,
 
4392
  const int im = tgpig.z;
4393
 
4394
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
4395
 
4396
  const uint i12 = im%ne12;
4397
  const uint i13 = im/ne12;
4398
 
4399
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
4400
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
4401
 
4402
+ device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0);
4403
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
4404
 
4405
  float yl[32];
4406
  float sumf[N_DST]={0.f}, all_sum;
 
4462
  }
4463
  sumf[row] += d1 * sum1 + d2 * sum2;
4464
 
4465
+ dh += nb01/2;
4466
+ q2 += nb01/2;
4467
+ sc += nb01;
4468
  }
4469
 
4470
  y4 += 32 * 32;
 
4489
  constant uint64_t & nb00,
4490
  constant uint64_t & nb01,
4491
  constant uint64_t & nb02,
4492
+ constant uint64_t & nb03,
4493
  constant int64_t & ne10,
4494
  constant int64_t & ne11,
4495
  constant int64_t & ne12,
4496
  constant uint64_t & nb10,
4497
  constant uint64_t & nb11,
4498
  constant uint64_t & nb12,
4499
+ constant uint64_t & nb13,
4500
  constant int64_t & ne0,
4501
  constant int64_t & ne1,
4502
  constant uint & r2,
 
4506
  uint tiisg[[thread_index_in_simdgroup]],
4507
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4508
 
4509
+ kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4510
  }
4511
 
4512
  void kernel_mul_mv_iq3_xxs_f32_impl(
 
4516
  int64_t ne00,
4517
  int64_t ne01,
4518
  int64_t ne02,
4519
+ uint64_t nb01,
4520
+ uint64_t nb02,
4521
+ uint64_t nb03,
4522
  int64_t ne10,
4523
  int64_t ne12,
4524
+ uint64_t nb11,
4525
+ uint64_t nb12,
4526
+ uint64_t nb13,
4527
  int64_t ne0,
4528
  int64_t ne1,
4529
  uint r2,
 
4539
  const int im = tgpig.z;
4540
 
4541
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
4542
 
4543
  const uint i12 = im%ne12;
4544
  const uint i13 = im/ne12;
4545
 
4546
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
4547
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
4548
 
4549
+ device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0);
4550
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
4551
 
4552
  float yl[32];
4553
  float sumf[N_DST]={0.f}, all_sum;
 
4602
  }
4603
  sumf[row] += d * (sum[0] + sum[1]);
4604
 
4605
+ dh += nb01/2;
4606
+ q3 += nb01;
4607
+ gas += nb01/2;
4608
  }
4609
 
4610
  y4 += 32 * 32;
 
4629
  constant uint64_t & nb00,
4630
  constant uint64_t & nb01,
4631
  constant uint64_t & nb02,
4632
+ constant uint64_t & nb03,
4633
  constant int64_t & ne10,
4634
  constant int64_t & ne11,
4635
  constant int64_t & ne12,
4636
  constant uint64_t & nb10,
4637
  constant uint64_t & nb11,
4638
  constant uint64_t & nb12,
4639
+ constant uint64_t & nb13,
4640
  constant int64_t & ne0,
4641
  constant int64_t & ne1,
4642
  constant uint & r2,
 
4646
  uint tiisg[[thread_index_in_simdgroup]],
4647
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4648
 
4649
+ kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4650
  }
4651
 
4652
  void kernel_mul_mv_iq3_s_f32_impl(
 
4656
  int64_t ne00,
4657
  int64_t ne01,
4658
  int64_t ne02,
4659
+ uint64_t nb01,
4660
+ uint64_t nb02,
4661
+ uint64_t nb03,
4662
  int64_t ne10,
4663
  int64_t ne12,
4664
+ uint64_t nb11,
4665
+ uint64_t nb12,
4666
+ uint64_t nb13,
4667
  int64_t ne0,
4668
  int64_t ne1,
4669
  uint r2,
 
4679
  const int im = tgpig.z;
4680
 
4681
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
4682
 
4683
  const uint i12 = im%ne12;
4684
  const uint i13 = im/ne12;
4685
 
4686
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
4687
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
4688
 
4689
+ device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0);
4690
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
4691
 
4692
  float yl[32];
4693
  float sumf[N_DST]={0.f}, all_sum;
 
4740
  }
4741
  sumf[row] += d * (sum[0] + sum[1]);
4742
 
4743
+ dh += nb01/2;
4744
+ qs += nb01;
4745
+ qh += nb01;
4746
+ sc += nb01;
4747
+ signs += nb01;
4748
  }
4749
 
4750
  y4 += 32 * 32;
 
4769
  constant uint64_t & nb00,
4770
  constant uint64_t & nb01,
4771
  constant uint64_t & nb02,
4772
+ constant uint64_t & nb03,
4773
  constant int64_t & ne10,
4774
  constant int64_t & ne11,
4775
  constant int64_t & ne12,
4776
  constant uint64_t & nb10,
4777
  constant uint64_t & nb11,
4778
  constant uint64_t & nb12,
4779
+ constant uint64_t & nb13,
4780
  constant int64_t & ne0,
4781
  constant int64_t & ne1,
4782
  constant uint & r2,
 
4786
  uint tiisg[[thread_index_in_simdgroup]],
4787
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4788
 
4789
+ kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4790
  }
4791
 
4792
  void kernel_mul_mv_iq2_s_f32_impl(
 
4796
  int64_t ne00,
4797
  int64_t ne01,
4798
  int64_t ne02,
4799
+ uint64_t nb01,
4800
+ uint64_t nb02,
4801
+ uint64_t nb03,
4802
  int64_t ne10,
4803
  int64_t ne12,
4804
+ uint64_t nb11,
4805
+ uint64_t nb12,
4806
+ uint64_t nb13,
4807
  int64_t ne0,
4808
  int64_t ne1,
4809
  uint r2,
 
4819
  const int im = tgpig.z;
4820
 
4821
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
4822
 
4823
  const uint i12 = im%ne12;
4824
  const uint i13 = im/ne12;
4825
 
4826
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
4827
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
4828
 
4829
+ device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0);
4830
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
4831
 
4832
  float yl[32];
4833
  float sumf[N_DST]={0.f}, all_sum;
 
4881
  }
4882
  sumf[row] += d1 * sum[0] + d2 * sum[1];
4883
 
4884
+ dh += nb01/2;
4885
+ qs += nb01;
4886
+ qh += nb01;
4887
+ sc += nb01;
4888
+ signs += nb01;
4889
  }
4890
 
4891
  y4 += 32 * 32;
 
4910
  constant uint64_t & nb00,
4911
  constant uint64_t & nb01,
4912
  constant uint64_t & nb02,
4913
+ constant uint64_t & nb03,
4914
  constant int64_t & ne10,
4915
  constant int64_t & ne11,
4916
  constant int64_t & ne12,
4917
  constant uint64_t & nb10,
4918
  constant uint64_t & nb11,
4919
  constant uint64_t & nb12,
4920
+ constant uint64_t & nb13,
4921
  constant int64_t & ne0,
4922
  constant int64_t & ne1,
4923
  constant uint & r2,
 
4927
  uint tiisg[[thread_index_in_simdgroup]],
4928
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4929
 
4930
+ kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4931
  }
4932
 
4933
  void kernel_mul_mv_iq1_s_f32_impl(
 
4937
  int64_t ne00,
4938
  int64_t ne01,
4939
  int64_t ne02,
4940
+ uint64_t nb01,
4941
+ uint64_t nb02,
4942
+ uint64_t nb03,
4943
  int64_t ne10,
4944
  int64_t ne12,
4945
+ uint64_t nb11,
4946
+ uint64_t nb12,
4947
+ uint64_t nb13,
4948
  int64_t ne0,
4949
  int64_t ne1,
4950
  uint r2,
 
4960
  const int im = tgpig.z;
4961
 
4962
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
4963
 
4964
  const uint i12 = im%ne12;
4965
  const uint i13 = im/ne12;
4966
 
4967
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
4968
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
4969
+
4970
+ device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0);
4971
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
4972
 
4973
  float yl[32];
4974
  float sumf[N_DST]={0.f}, all_sum;
 
5011
  }
5012
  sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
5013
 
5014
+ dh += nb01/2;
5015
+ qs += nb01;
5016
+ qh += nb01/2;
5017
  }
5018
 
5019
  y4 += 32 * 32;
 
5034
  int64_t ne00,
5035
  int64_t ne01,
5036
  int64_t ne02,
5037
+ uint64_t nb01,
5038
+ uint64_t nb02,
5039
+ uint64_t nb03,
5040
  int64_t ne10,
5041
  int64_t ne12,
5042
+ uint64_t nb11,
5043
+ uint64_t nb12,
5044
+ uint64_t nb13,
5045
  int64_t ne0,
5046
  int64_t ne1,
5047
  uint r2,
 
5057
  const int im = tgpig.z;
5058
 
5059
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
5060
 
5061
  const uint i12 = im%ne12;
5062
  const uint i13 = im/ne12;
5063
 
5064
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
5065
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
5066
+
5067
+ device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0);
5068
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
5069
 
5070
  float yl[32];
5071
  float sumf[N_DST]={0.f}, all_sum;
 
5117
  sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
5118
  (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
5119
 
5120
+ sc += nb01/2;
5121
+ qs += nb01;
5122
+ qh += nb01;
5123
  }
5124
 
5125
  y4 += 32 * 32;
 
5140
  int64_t ne00,
5141
  int64_t ne01,
5142
  int64_t ne02,
5143
+ uint64_t nb01,
5144
+ uint64_t nb02,
5145
+ uint64_t nb03,
5146
  int64_t ne10,
5147
  int64_t ne12,
5148
+ uint64_t nb11,
5149
+ uint64_t nb12,
5150
+ uint64_t nb13,
5151
  int64_t ne0,
5152
  int64_t ne1,
5153
  uint r2,
 
5163
  const int r1 = tgpig.y;
5164
  const int im = tgpig.z;
5165
  const int first_row = (r0 * 2 + sgitg) * 2;
 
5166
 
5167
  const uint i12 = im%ne12;
5168
  const uint i13 = im/ne12;
5169
 
5170
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
5171
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
5172
+
5173
+ device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0);
5174
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
5175
 
5176
  const int ix = tiisg/2; // 0...15
5177
  const int it = tiisg%2; // 0 or 1
 
5241
  int64_t ne00,
5242
  int64_t ne01,
5243
  int64_t ne02,
5244
+ uint64_t nb01,
5245
+ uint64_t nb02,
5246
+ uint64_t nb03,
5247
  int64_t ne10,
5248
  int64_t ne12,
5249
+ uint64_t nb11,
5250
+ uint64_t nb12,
5251
+ uint64_t nb13,
5252
  int64_t ne0,
5253
  int64_t ne1,
5254
  uint r2,
 
5264
  const int r1 = tgpig.y;
5265
  const int im = tgpig.z;
5266
  const int first_row = (r0 * 2 + sgitg) * 2;
 
5267
 
5268
  const uint i12 = im%ne12;
5269
  const uint i13 = im/ne12;
5270
 
5271
+ const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
5272
+ const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
5273
+
5274
+ device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0);
5275
+ device const float * y = (device const float *) ((device char *) src1 + offset1);
5276
 
5277
  const int ix = tiisg/16; // 0 or 1
5278
  const int it = tiisg%16; // 0...15
 
5347
  constant uint64_t & nb00,
5348
  constant uint64_t & nb01,
5349
  constant uint64_t & nb02,
5350
+ constant uint64_t & nb03,
5351
  constant int64_t & ne10,
5352
  constant int64_t & ne11,
5353
  constant int64_t & ne12,
5354
  constant uint64_t & nb10,
5355
  constant uint64_t & nb11,
5356
  constant uint64_t & nb12,
5357
+ constant uint64_t & nb13,
5358
  constant int64_t & ne0,
5359
  constant int64_t & ne1,
5360
  constant uint & r2,
 
5363
  uint tiisg[[thread_index_in_simdgroup]],
5364
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5365
 
5366
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
5367
  }
5368
 
5369
  [[host_name("kernel_mul_mv_iq1_m_f32")]]
 
5377
  constant uint64_t & nb00,
5378
  constant uint64_t & nb01,
5379
  constant uint64_t & nb02,
5380
+ constant uint64_t & nb03,
5381
  constant int64_t & ne10,
5382
  constant int64_t & ne11,
5383
  constant int64_t & ne12,
5384
  constant uint64_t & nb10,
5385
  constant uint64_t & nb11,
5386
  constant uint64_t & nb12,
5387
+ constant uint64_t & nb13,
5388
  constant int64_t & ne0,
5389
  constant int64_t & ne1,
5390
  constant uint & r2,
 
5393
  uint tiisg[[thread_index_in_simdgroup]],
5394
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5395
 
5396
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
5397
  }
5398
 
5399
  [[host_name("kernel_mul_mv_iq4_nl_f32")]]
 
5407
  constant uint64_t & nb00,
5408
  constant uint64_t & nb01,
5409
  constant uint64_t & nb02,
5410
+ constant uint64_t & nb03,
5411
  constant int64_t & ne10,
5412
  constant int64_t & ne11,
5413
  constant int64_t & ne12,
5414
  constant uint64_t & nb10,
5415
  constant uint64_t & nb11,
5416
  constant uint64_t & nb12,
5417
+ constant uint64_t & nb13,
5418
  constant int64_t & ne0,
5419
  constant int64_t & ne1,
5420
  constant uint & r2,
 
5424
  uint tiisg[[thread_index_in_simdgroup]],
5425
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5426
 
5427
+ kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5428
  }
5429
 
5430
  [[host_name("kernel_mul_mv_iq4_xs_f32")]]
 
5438
  constant uint64_t & nb00,
5439
  constant uint64_t & nb01,
5440
  constant uint64_t & nb02,
5441
+ constant uint64_t & nb03,
5442
  constant int64_t & ne10,
5443
  constant int64_t & ne11,
5444
  constant int64_t & ne12,
5445
  constant uint64_t & nb10,
5446
  constant uint64_t & nb11,
5447
  constant uint64_t & nb12,
5448
+ constant uint64_t & nb13,
5449
  constant int64_t & ne0,
5450
  constant int64_t & ne1,
5451
  constant uint & r2,
 
5455
  uint tiisg[[thread_index_in_simdgroup]],
5456
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5457
 
5458
+ kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5459
  }
5460
 
5461
  //============================= templates and their specializations =============================
 
6000
  constant int64_t & ne02,
6001
  constant uint64_t & nb01,
6002
  constant uint64_t & nb02,
6003
+ constant uint64_t & nb03,
6004
  constant int64_t & ne12,
6005
  constant uint64_t & nb10,
6006
  constant uint64_t & nb11,
6007
  constant uint64_t & nb12,
6008
+ constant uint64_t & nb13,
6009
  constant int64_t & ne0,
6010
  constant int64_t & ne1,
6011
  constant uint & r2,
 
6042
  const uint i12 = im%ne12;
6043
  const uint i13 = im/ne12;
6044
 
6045
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
6046
  ushort offset1 = il/nl;
6047
 
6048
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
6049
  device const float * y = (device const float *)(src1
6050
+ + nb13 * i13
6051
+ + nb12 * i12
6052
  + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
6053
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
6054
 
 
6427
  uint64_t nb00,
6428
  uint64_t nb01,
6429
  uint64_t nb02,
6430
+ uint64_t nb03,
6431
  int64_t ne10,
6432
  int64_t ne11,
6433
  int64_t ne12,
6434
  uint64_t nb10,
6435
  uint64_t nb11,
6436
  uint64_t nb12,
6437
+ uint64_t nb13,
6438
  int64_t ne0,
6439
  int64_t ne1,
6440
  uint r2,
 
6449
  int64_t ne00,
6450
  int64_t ne01,
6451
  int64_t ne02,
6452
+ uint64_t nb01,
6453
+ uint64_t nb02,
6454
+ uint64_t nb03,
6455
  int64_t ne10,
6456
  int64_t ne12,
6457
+ uint64_t nb11,
6458
+ uint64_t nb12,
6459
+ uint64_t nb13,
6460
  int64_t ne0,
6461
  int64_t ne1,
6462
  uint r2,
 
6477
  uint64_t nb00,
6478
  uint64_t nb01,
6479
  uint64_t nb02,
6480
+ uint64_t nb03,
6481
  int64_t ne10,
6482
  int64_t ne11,
6483
  int64_t ne12,
 
6485
  uint64_t nb10,
6486
  uint64_t nb11,
6487
  uint64_t nb12,
6488
+ uint64_t nb13,
6489
  int64_t ne0,
6490
  int64_t ne1,
6491
  uint64_t nb1,
 
6496
  uint tiitg,
6497
  uint tiisg,
6498
  uint sgitg) {
6499
+ impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg);
6500
  }
6501
 
6502
  template<kernel_mul_mv2_impl_t impl_fn>
 
6510
  uint64_t nb00,
6511
  uint64_t nb01,
6512
  uint64_t nb02,
6513
+ uint64_t nb03,
6514
  int64_t ne10,
6515
  int64_t ne11,
6516
  int64_t ne12,
 
6518
  uint64_t nb10,
6519
  uint64_t nb11,
6520
  uint64_t nb12,
6521
+ uint64_t nb13,
6522
  int64_t ne0,
6523
  int64_t ne1,
6524
  uint64_t nb1,
 
6529
  uint tiitg,
6530
  uint tiisg,
6531
  uint sgitg) {
6532
+ impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
6533
  }
6534
 
6535
  typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
 
6578
  const int64_t i2 = i12;
6579
 
6580
  device const char * src0_cur = src0s + i02*nb02;
6581
+ device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
6582
+ device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
6583
 
6584
  impl_fn(
6585
  /* src0 */ src0_cur,
 
6587
  /* dst */ dst_cur,
6588
  /* ne00 */ ne00,
6589
  /* ne01 */ ne01,
6590
+ /* ne02 */ 1, // ne02,
6591
  /* nb00 */ nb00,
6592
  /* nb01 */ nb01,
6593
  /* nb02 */ nb02,
6594
+ /* nb03 */ nb02, // ne02 == 1
6595
  /* ne10 */ ne10,
6596
+ /* ne11 */ 1, // ne11,
6597
+ /* ne12 */ 1, // ne12,
6598
+ /* ne13 */ 1, // ne13,
6599
  /* nb10 */ nb10,
6600
  /* nb11 */ nb11,
6601
  /* nb12 */ nb12,
6602
+ /* ne13 */ nb12, // ne12 == 1
6603
  /* ne0 */ ne0,
6604
+ /* ne1 */ 1, // ne1,
6605
  /* nb1 */ nb1,
6606
  /* r2 */ 1,
6607
  /* r3 */ 1,