ggerganov commited on
Commit
af0525c
·
1 Parent(s): a0ecefd

metal : move dequantize templates to beginning of MSL source (llama/0)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-metal.metal +910 -912
ggml/src/ggml-metal.metal CHANGED
@@ -12,435 +12,454 @@ using namespace metal;
12
 
13
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
14
 
15
- enum ggml_sort_order {
16
- GGML_SORT_ORDER_ASC,
17
- GGML_SORT_ORDER_DESC,
18
  };
19
 
20
- // general-purpose kernel for addition, subtraction, multiplication and division of two tensors
21
- // pros: works for non-contiguous tensors, supports broadcast across all dims
22
- // cons: not very efficient
23
- kernel void kernel_add(
24
- device const char * src0,
25
- device const char * src1,
26
- device char * dst,
27
- constant int64_t & ne00,
28
- constant int64_t & ne01,
29
- constant int64_t & ne02,
30
- constant int64_t & ne03,
31
- constant uint64_t & nb00,
32
- constant uint64_t & nb01,
33
- constant uint64_t & nb02,
34
- constant uint64_t & nb03,
35
- constant int64_t & ne10,
36
- constant int64_t & ne11,
37
- constant int64_t & ne12,
38
- constant int64_t & ne13,
39
- constant uint64_t & nb10,
40
- constant uint64_t & nb11,
41
- constant uint64_t & nb12,
42
- constant uint64_t & nb13,
43
- constant int64_t & ne0,
44
- constant int64_t & ne1,
45
- constant int64_t & ne2,
46
- constant int64_t & ne3,
47
- constant uint64_t & nb0,
48
- constant uint64_t & nb1,
49
- constant uint64_t & nb2,
50
- constant uint64_t & nb3,
51
- constant int64_t & offs,
52
- uint3 tgpig[[threadgroup_position_in_grid]],
53
- uint3 tpitg[[thread_position_in_threadgroup]],
54
- uint3 ntg[[threads_per_threadgroup]]) {
55
- const int64_t i03 = tgpig.z;
56
- const int64_t i02 = tgpig.y;
57
- const int64_t i01 = tgpig.x;
58
-
59
- const int64_t i13 = i03 % ne13;
60
- const int64_t i12 = i02 % ne12;
61
- const int64_t i11 = i01 % ne11;
62
-
63
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
64
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
65
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
66
 
67
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
68
- const int i10 = i0 % ne10;
69
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
 
 
70
  }
71
  }
72
 
73
- kernel void kernel_sub(
74
- device const char * src0,
75
- device const char * src1,
76
- device char * dst,
77
- constant int64_t & ne00,
78
- constant int64_t & ne01,
79
- constant int64_t & ne02,
80
- constant int64_t & ne03,
81
- constant uint64_t & nb00,
82
- constant uint64_t & nb01,
83
- constant uint64_t & nb02,
84
- constant uint64_t & nb03,
85
- constant int64_t & ne10,
86
- constant int64_t & ne11,
87
- constant int64_t & ne12,
88
- constant int64_t & ne13,
89
- constant uint64_t & nb10,
90
- constant uint64_t & nb11,
91
- constant uint64_t & nb12,
92
- constant uint64_t & nb13,
93
- constant int64_t & ne0,
94
- constant int64_t & ne1,
95
- constant int64_t & ne2,
96
- constant int64_t & ne3,
97
- constant uint64_t & nb0,
98
- constant uint64_t & nb1,
99
- constant uint64_t & nb2,
100
- constant uint64_t & nb3,
101
- constant int64_t & offs,
102
- uint3 tgpig[[threadgroup_position_in_grid]],
103
- uint3 tpitg[[thread_position_in_threadgroup]],
104
- uint3 ntg[[threads_per_threadgroup]]) {
105
- const int64_t i03 = tgpig.z;
106
- const int64_t i02 = tgpig.y;
107
- const int64_t i01 = tgpig.x;
108
 
109
- const int64_t i13 = i03 % ne13;
110
- const int64_t i12 = i02 % ne12;
111
- const int64_t i11 = i01 % ne11;
 
 
112
 
113
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
114
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
115
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
 
 
 
 
 
116
 
117
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
118
- const int i10 = i0 % ne10;
119
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
120
  }
121
  }
122
 
123
- kernel void kernel_mul(
124
- device const char * src0,
125
- device const char * src1,
126
- device char * dst,
127
- constant int64_t & ne00,
128
- constant int64_t & ne01,
129
- constant int64_t & ne02,
130
- constant int64_t & ne03,
131
- constant uint64_t & nb00,
132
- constant uint64_t & nb01,
133
- constant uint64_t & nb02,
134
- constant uint64_t & nb03,
135
- constant int64_t & ne10,
136
- constant int64_t & ne11,
137
- constant int64_t & ne12,
138
- constant int64_t & ne13,
139
- constant uint64_t & nb10,
140
- constant uint64_t & nb11,
141
- constant uint64_t & nb12,
142
- constant uint64_t & nb13,
143
- constant int64_t & ne0,
144
- constant int64_t & ne1,
145
- constant int64_t & ne2,
146
- constant int64_t & ne3,
147
- constant uint64_t & nb0,
148
- constant uint64_t & nb1,
149
- constant uint64_t & nb2,
150
- constant uint64_t & nb3,
151
- uint3 tgpig[[threadgroup_position_in_grid]],
152
- uint3 tpitg[[thread_position_in_threadgroup]],
153
- uint3 ntg[[threads_per_threadgroup]]) {
154
- const int64_t i03 = tgpig.z;
155
- const int64_t i02 = tgpig.y;
156
- const int64_t i01 = tgpig.x;
157
 
158
- const int64_t i13 = i03 % ne13;
159
- const int64_t i12 = i02 % ne12;
160
- const int64_t i11 = i01 % ne11;
161
 
162
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
163
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
164
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
165
 
166
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
167
- const int i10 = i0 % ne10;
168
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
 
 
 
 
 
 
 
 
 
 
 
169
  }
170
  }
171
 
172
- kernel void kernel_div(
173
- device const char * src0,
174
- device const char * src1,
175
- device char * dst,
176
- constant int64_t & ne00,
177
- constant int64_t & ne01,
178
- constant int64_t & ne02,
179
- constant int64_t & ne03,
180
- constant uint64_t & nb00,
181
- constant uint64_t & nb01,
182
- constant uint64_t & nb02,
183
- constant uint64_t & nb03,
184
- constant int64_t & ne10,
185
- constant int64_t & ne11,
186
- constant int64_t & ne12,
187
- constant int64_t & ne13,
188
- constant uint64_t & nb10,
189
- constant uint64_t & nb11,
190
- constant uint64_t & nb12,
191
- constant uint64_t & nb13,
192
- constant int64_t & ne0,
193
- constant int64_t & ne1,
194
- constant int64_t & ne2,
195
- constant int64_t & ne3,
196
- constant uint64_t & nb0,
197
- constant uint64_t & nb1,
198
- constant uint64_t & nb2,
199
- constant uint64_t & nb3,
200
- uint3 tgpig[[threadgroup_position_in_grid]],
201
- uint3 tpitg[[thread_position_in_threadgroup]],
202
- uint3 ntg[[threads_per_threadgroup]]) {
203
- const int64_t i03 = tgpig.z;
204
- const int64_t i02 = tgpig.y;
205
- const int64_t i01 = tgpig.x;
206
 
207
- const int64_t i13 = i03 % ne13;
208
- const int64_t i12 = i02 % ne12;
209
- const int64_t i11 = i01 % ne11;
210
 
211
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
212
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
213
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
214
 
215
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
216
- const int i10 = i0 % ne10;
217
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
218
- }
219
- }
220
 
221
- template<typename T>
222
- kernel void kernel_repeat(
223
- device const char * src0,
224
- device char * dst,
225
- constant int64_t & ne00,
226
- constant int64_t & ne01,
227
- constant int64_t & ne02,
228
- constant int64_t & ne03,
229
- constant uint64_t & nb00,
230
- constant uint64_t & nb01,
231
- constant uint64_t & nb02,
232
- constant uint64_t & nb03,
233
- constant int64_t & ne0,
234
- constant int64_t & ne1,
235
- constant int64_t & ne2,
236
- constant int64_t & ne3,
237
- constant uint64_t & nb0,
238
- constant uint64_t & nb1,
239
- constant uint64_t & nb2,
240
- constant uint64_t & nb3,
241
- uint3 tgpig[[threadgroup_position_in_grid]],
242
- uint3 tpitg[[thread_position_in_threadgroup]],
243
- uint3 ntg[[threads_per_threadgroup]]) {
244
- const int64_t i3 = tgpig.z;
245
- const int64_t i2 = tgpig.y;
246
- const int64_t i1 = tgpig.x;
247
-
248
- const int64_t i03 = i3 % ne03;
249
- const int64_t i02 = i2 % ne02;
250
- const int64_t i01 = i1 % ne01;
251
 
252
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
253
- device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
 
254
 
255
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
256
- const int i00 = i0 % ne00;
257
- *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
258
  }
259
  }
260
 
261
- typedef decltype(kernel_repeat<float>) kernel_repeat_t;
262
-
263
- template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
264
- template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
265
- template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
266
- template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
267
-
268
- // assumption: src1 is a row
269
- // broadcast src1 into src0
270
- kernel void kernel_add_row(
271
- device const float4 * src0,
272
- device const float4 * src1,
273
- device float4 * dst,
274
- constant uint64_t & nb [[buffer(28)]],
275
- uint tpig[[thread_position_in_grid]]) {
276
- dst[tpig] = src0[tpig] + src1[tpig % nb];
277
- }
278
 
279
- kernel void kernel_sub_row(
280
- device const float4 * src0,
281
- device const float4 * src1,
282
- device float4 * dst,
283
- constant uint64_t & nb [[buffer(28)]],
284
- uint tpig[[thread_position_in_grid]]) {
285
- dst[tpig] = src0[tpig] - src1[tpig % nb];
286
  }
287
 
288
- kernel void kernel_mul_row(
289
- device const float4 * src0,
290
- device const float4 * src1,
291
- device float4 * dst,
292
- constant uint64_t & nb [[buffer(28)]],
293
- uint tpig[[thread_position_in_grid]]) {
294
- dst[tpig] = src0[tpig] * src1[tpig % nb];
295
- }
296
 
297
- kernel void kernel_div_row(
298
- device const float4 * src0,
299
- device const float4 * src1,
300
- device float4 * dst,
301
- constant uint64_t & nb [[buffer(28)]],
302
- uint tpig[[thread_position_in_grid]]) {
303
- dst[tpig] = src0[tpig] / src1[tpig % nb];
304
- }
305
 
306
- kernel void kernel_scale(
307
- device const float * src0,
308
- device float * dst,
309
- constant float & scale,
310
- uint tpig[[thread_position_in_grid]]) {
311
- dst[tpig] = src0[tpig] * scale;
312
  }
313
 
314
- kernel void kernel_scale_4(
315
- device const float4 * src0,
316
- device float4 * dst,
317
- constant float & scale,
318
- uint tpig[[thread_position_in_grid]]) {
319
- dst[tpig] = src0[tpig] * scale;
320
- }
321
 
322
- kernel void kernel_clamp(
323
- device const float * src0,
324
- device float * dst,
325
- constant float & min,
326
- constant float & max,
327
- uint tpig[[thread_position_in_grid]]) {
328
- dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
329
- }
 
 
 
330
 
331
- kernel void kernel_relu(
332
- device const float * src0,
333
- device float * dst,
334
- uint tpig[[thread_position_in_grid]]) {
335
- dst[tpig] = max(0.0f, src0[tpig]);
336
- }
337
 
338
- kernel void kernel_sigmoid(
339
- device const float * src0,
340
- device float * dst,
341
- uint tpig[[thread_position_in_grid]]) {
342
- dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
343
  }
344
 
345
- kernel void kernel_tanh(
346
- device const float * src0,
347
- device float * dst,
348
- uint tpig[[thread_position_in_grid]]) {
349
- device const float & x = src0[tpig];
350
- dst[tpig] = precise::tanh(x);
351
  }
352
 
353
- constant float GELU_COEF_A = 0.044715f;
354
- constant float GELU_QUICK_COEF = -1.702f;
355
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
356
 
357
- kernel void kernel_gelu(
358
- device const float * src0,
359
- device float * dst,
360
- uint tpig[[thread_position_in_grid]]) {
361
- device const float & x = src0[tpig];
 
 
 
362
 
363
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 
 
 
364
  }
365
 
366
- kernel void kernel_gelu_4(
367
- device const float4 * src0,
368
- device float4 * dst,
369
- uint tpig[[thread_position_in_grid]]) {
370
- device const float4 & x = src0[tpig];
371
-
372
- // BEWARE !!!
373
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
374
- // This was observed with Falcon 7B and 40B models
375
- //
376
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
377
- }
378
 
379
- kernel void kernel_gelu_quick(
380
- device const float * src0,
381
- device float * dst,
382
- uint tpig[[thread_position_in_grid]]) {
383
- device const float & x = src0[tpig];
 
 
 
 
 
384
 
385
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
 
 
 
 
386
  }
387
 
388
- kernel void kernel_gelu_quick_4(
389
- device const float4 * src0,
390
- device float4 * dst,
391
- uint tpig[[thread_position_in_grid]]) {
392
- device const float4 & x = src0[tpig];
393
-
394
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
395
- }
396
 
397
- kernel void kernel_silu(
398
- device const float * src0,
399
- device float * dst,
400
- uint tpig[[thread_position_in_grid]]) {
401
- device const float & x = src0[tpig];
402
- dst[tpig] = x / (1.0f + exp(-x));
403
- }
404
 
405
- kernel void kernel_silu_4(
406
- device const float4 * src0,
407
- device float4 * dst,
408
- uint tpig[[thread_position_in_grid]]) {
409
- device const float4 & x = src0[tpig];
410
- dst[tpig] = x / (1.0f + exp(-x));
 
 
 
 
411
  }
412
 
413
- kernel void kernel_sqr(
414
- device const float * src0,
415
- device float * dst,
416
- uint tpig[[thread_position_in_grid]]) {
417
- dst[tpig] = src0[tpig] * src0[tpig];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  }
419
 
420
- kernel void kernel_sqrt(
421
- device const float * src0,
422
- device float * dst,
423
- uint tpig[[thread_position_in_grid]]) {
424
- dst[tpig] = sqrt(src0[tpig]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  }
426
 
427
- kernel void kernel_sin(
428
- device const float * src0,
429
- device float * dst,
430
- uint tpig[[thread_position_in_grid]]) {
431
- dst[tpig] = sin(src0[tpig]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  }
433
 
434
- kernel void kernel_cos(
435
- device const float * src0,
436
- device float * dst,
437
- uint tpig[[thread_position_in_grid]]) {
438
- dst[tpig] = cos(src0[tpig]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  }
440
 
441
- kernel void kernel_sum_rows(
442
- device const float * src0,
443
- device float * dst,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
  constant int64_t & ne00,
445
  constant int64_t & ne01,
446
  constant int64_t & ne02,
@@ -465,132 +484,446 @@ kernel void kernel_sum_rows(
465
  constant uint64_t & nb1,
466
  constant uint64_t & nb2,
467
  constant uint64_t & nb3,
468
- uint3 tpig[[thread_position_in_grid]]) {
469
- int64_t i3 = tpig.z;
470
- int64_t i2 = tpig.y;
471
- int64_t i1 = tpig.x;
 
 
 
472
 
473
- if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
474
- return;
475
- }
476
-
477
- device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
478
- device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
479
 
480
- float row_sum = 0;
 
 
481
 
482
- for (int64_t i0 = 0; i0 < ne00; i0++) {
483
- row_sum += src_row[i0];
 
484
  }
485
-
486
- dst_row[0] = row_sum;
487
  }
488
 
489
- template<typename T>
490
- kernel void kernel_soft_max(
491
- device const char * src0,
492
- device const char * src1,
493
- device char * dst,
494
- constant int64_t & ne00,
495
- constant int64_t & ne01,
496
- constant int64_t & ne02,
497
- constant float & scale,
498
- constant float & max_bias,
499
- constant float & m0,
500
- constant float & m1,
501
- constant uint32_t & n_head_log2,
502
- threadgroup float * buf [[threadgroup(0)]],
503
- uint tgpig[[threadgroup_position_in_grid]],
504
- uint tpitg[[thread_position_in_threadgroup]],
505
- uint sgitg[[simdgroup_index_in_threadgroup]],
506
- uint tiisg[[thread_index_in_simdgroup]],
507
- uint ntg[[threads_per_threadgroup]]) {
508
- const int64_t i03 = (tgpig) / (ne02*ne01);
509
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
510
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
 
 
 
 
 
 
 
 
 
 
 
 
511
 
512
- device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
513
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
514
- device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
515
 
516
- float slope = 1.0f;
 
 
517
 
518
- // ALiBi
519
- if (max_bias > 0.0f) {
520
- const int64_t h = i02;
 
 
521
 
522
- const float base = h < n_head_log2 ? m0 : m1;
523
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
525
- slope = pow(base, exp);
526
- }
 
527
 
528
- // parallel max
529
- float lmax = -INFINITY;
 
530
 
531
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
532
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
 
533
  }
 
534
 
535
- // find the max value in the block
536
- float max_val = simd_max(lmax);
537
- if (ntg > N_SIMDWIDTH) {
538
- if (sgitg == 0) {
539
- buf[tiisg] = -INFINITY;
540
- }
541
-
542
- threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
 
544
- if (tiisg == 0) {
545
- buf[sgitg] = max_val;
546
- }
547
 
548
- threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
549
 
550
- max_val = buf[tiisg];
551
- max_val = simd_max(max_val);
 
552
  }
 
553
 
554
- // parallel sum
555
- float lsum = 0.0f;
556
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
557
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
558
- lsum += exp_psrc0;
559
- pdst[i00] = exp_psrc0;
560
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
 
562
- // This barrier fixes a failing test
563
- // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
564
- threadgroup_barrier(mem_flags::mem_none);
565
 
566
- float sum = simd_sum(lsum);
 
567
 
568
- if (ntg > N_SIMDWIDTH) {
569
- if (sgitg == 0) {
570
- buf[tiisg] = 0.0f;
571
- }
 
572
 
573
- threadgroup_barrier(mem_flags::mem_threadgroup);
574
 
575
- if (tiisg == 0) {
576
- buf[sgitg] = sum;
577
- }
 
578
 
579
- threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
 
 
 
 
 
 
 
580
 
581
- sum = buf[tiisg];
582
- sum = simd_sum(sum);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
  }
584
 
585
- const float inv_sum = 1.0f/sum;
 
586
 
587
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
588
- pdst[i00] *= inv_sum;
 
 
589
  }
 
 
590
  }
591
 
592
  template<typename T>
593
- kernel void kernel_soft_max_4(
594
  device const char * src0,
595
  device const char * src1,
596
  device char * dst,
@@ -612,12 +945,13 @@ kernel void kernel_soft_max_4(
612
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
613
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
614
 
615
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
616
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
617
- device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
618
 
619
  float slope = 1.0f;
620
 
 
621
  if (max_bias > 0.0f) {
622
  const int64_t h = i02;
623
 
@@ -628,14 +962,13 @@ kernel void kernel_soft_max_4(
628
  }
629
 
630
  // parallel max
631
- float4 lmax4 = -INFINITY;
632
 
633
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
634
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
635
  }
636
 
637
- const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
638
-
639
  float max_val = simd_max(lmax);
640
  if (ntg > N_SIMDWIDTH) {
641
  if (sgitg == 0) {
@@ -655,14 +988,117 @@ kernel void kernel_soft_max_4(
655
  }
656
 
657
  // parallel sum
658
- float4 lsum4 = 0.0f;
659
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
660
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
661
- lsum4 += exp_psrc4;
662
- pdst4[i00] = exp_psrc4;
663
  }
664
 
665
- const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
666
 
667
  // This barrier fixes a failing test
668
  // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
@@ -3339,10 +3775,6 @@ static inline int best_index_int8(int n, constant float * val, float x) {
3339
  return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
3340
  }
3341
 
3342
- constexpr constant static float kvalues_iq4nl_f[16] = {
3343
- -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
3344
- };
3345
-
3346
  kernel void kernel_cpy_f32_iq4_nl(
3347
  device const float * src0,
3348
  device void * dst,
@@ -5457,440 +5889,6 @@ kernel void kernel_mul_mv_iq4_xs_f32(
5457
  kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5458
  }
5459
 
5460
- //============================= templates and their specializations =============================
5461
-
5462
- // NOTE: this is not dequantizing - we are simply fitting the template
5463
- template <typename type4x4>
5464
- void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
5465
- float4x4 temp = *(((device float4x4 *)src));
5466
- for (int i = 0; i < 16; i++){
5467
- reg[i/4][i%4] = temp[i/4][i%4];
5468
- }
5469
- }
5470
-
5471
- template <typename type4x4>
5472
- void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
5473
- half4x4 temp = *(((device half4x4 *)src));
5474
- for (int i = 0; i < 16; i++){
5475
- reg[i/4][i%4] = temp[i/4][i%4];
5476
- }
5477
- }
5478
-
5479
- template <typename type4x4>
5480
- void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
5481
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
5482
- const float d1 = il ? (xb->d / 16.h) : xb->d;
5483
- const float d2 = d1 / 256.f;
5484
- const float md = -8.h * xb->d;
5485
- const ushort mask0 = il ? 0x00F0 : 0x000F;
5486
- const ushort mask1 = mask0 << 8;
5487
-
5488
- for (int i=0;i<8;i++) {
5489
- reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
5490
- reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
5491
- }
5492
- }
5493
-
5494
- template <typename type4x4>
5495
- void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
5496
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
5497
- const float d1 = il ? (xb->d / 16.h) : xb->d;
5498
- const float d2 = d1 / 256.f;
5499
- const float m = xb->m;
5500
- const ushort mask0 = il ? 0x00F0 : 0x000F;
5501
- const ushort mask1 = mask0 << 8;
5502
-
5503
- for (int i=0;i<8;i++) {
5504
- reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
5505
- reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
5506
- }
5507
- }
5508
-
5509
- template <typename type4x4>
5510
- void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
5511
- device const uint16_t * qs = ((device const uint16_t *)xb + 3);
5512
- const float d = xb->d;
5513
- const float md = -16.h * xb->d;
5514
- const ushort mask = il ? 0x00F0 : 0x000F;
5515
-
5516
- const uint32_t qh = *((device const uint32_t *)xb->qh);
5517
-
5518
- const int x_mv = il ? 4 : 0;
5519
-
5520
- const int gh_mv = il ? 12 : 0;
5521
- const int gh_bk = il ? 0 : 4;
5522
-
5523
- for (int i = 0; i < 8; i++) {
5524
- // extract the 5-th bits for x0 and x1
5525
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
5526
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
5527
-
5528
- // combine the 4-bits from qs with the 5th bit
5529
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
5530
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
5531
-
5532
- reg[i/2][2*(i%2)+0] = d * x0 + md;
5533
- reg[i/2][2*(i%2)+1] = d * x1 + md;
5534
- }
5535
- }
5536
-
5537
- template <typename type4x4>
5538
- void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
5539
- device const uint16_t * qs = ((device const uint16_t *)xb + 4);
5540
- const float d = xb->d;
5541
- const float m = xb->m;
5542
- const ushort mask = il ? 0x00F0 : 0x000F;
5543
-
5544
- const uint32_t qh = *((device const uint32_t *)xb->qh);
5545
-
5546
- const int x_mv = il ? 4 : 0;
5547
-
5548
- const int gh_mv = il ? 12 : 0;
5549
- const int gh_bk = il ? 0 : 4;
5550
-
5551
- for (int i = 0; i < 8; i++) {
5552
- // extract the 5-th bits for x0 and x1
5553
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
5554
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
5555
-
5556
- // combine the 4-bits from qs with the 5th bit
5557
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
5558
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
5559
-
5560
- reg[i/2][2*(i%2)+0] = d * x0 + m;
5561
- reg[i/2][2*(i%2)+1] = d * x1 + m;
5562
- }
5563
- }
5564
-
5565
- template <typename type4x4>
5566
- void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
5567
- device const int8_t * qs = ((device const int8_t *)xb->qs);
5568
- const half d = xb->d;
5569
-
5570
- for (int i = 0; i < 16; i++) {
5571
- reg[i/4][i%4] = (qs[i + 16*il] * d);
5572
- }
5573
- }
5574
-
5575
- template <typename type4x4>
5576
- void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
5577
- const float d = xb->d;
5578
- const float min = xb->dmin;
5579
- device const uint8_t * q = (device const uint8_t *)xb->qs;
5580
- float dl, ml;
5581
- uint8_t sc = xb->scales[il];
5582
-
5583
- q = q + 32*(il/8) + 16*(il&1);
5584
- il = (il/2)%4;
5585
-
5586
- half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5587
- uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5588
- dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
5589
- for (int i = 0; i < 16; ++i) {
5590
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
5591
- }
5592
- }
5593
-
5594
- template <typename type4x4>
5595
- void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
5596
- const half d_all = xb->d;
5597
- device const uint8_t * q = (device const uint8_t *)xb->qs;
5598
- device const uint8_t * h = (device const uint8_t *)xb->hmask;
5599
- device const int8_t * scales = (device const int8_t *)xb->scales;
5600
-
5601
- q = q + 32 * (il/8) + 16 * (il&1);
5602
- h = h + 16 * (il&1);
5603
- uint8_t m = 1 << (il/2);
5604
- uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
5605
- ((il/4)>0 ? 12 : 3);
5606
- uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
5607
- uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
5608
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
5609
- : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
5610
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
5611
- const float ml = 4.f * dl;
5612
-
5613
- il = (il/2) & 3;
5614
- const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5615
- const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5616
- dl *= coef;
5617
-
5618
- for (int i = 0; i < 16; ++i) {
5619
- reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
5620
- }
5621
- }
5622
-
5623
- static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
5624
- return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
5625
- : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
5626
- }
5627
-
5628
- template <typename type4x4>
5629
- void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
5630
- device const uchar * q = xb->qs;
5631
-
5632
- short is = (il/4) * 2;
5633
- q = q + (il/4) * 32 + 16 * (il&1);
5634
- il = il & 3;
5635
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
5636
- const float d = il < 2 ? xb->d : xb->d / 16.h;
5637
- const float min = xb->dmin;
5638
- const float dl = d * sc[0];
5639
- const float ml = min * sc[1];
5640
-
5641
- const ushort mask = il<2 ? 0x0F : 0xF0;
5642
- for (int i = 0; i < 16; ++i) {
5643
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
5644
- }
5645
- }
5646
-
5647
- template <typename type4x4>
5648
- void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
5649
- device const uint8_t * q = xb->qs;
5650
- device const uint8_t * qh = xb->qh;
5651
-
5652
- short is = (il/4) * 2;
5653
- q = q + 32 * (il/4) + 16 * (il&1);
5654
- qh = qh + 16 * (il&1);
5655
- uint8_t ul = 1 << (il/2);
5656
- il = il & 3;
5657
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
5658
- const float d = il < 2 ? xb->d : xb->d / 16.f;
5659
- const float min = xb->dmin;
5660
- const float dl = d * sc[0];
5661
- const float ml = min * sc[1];
5662
-
5663
- const ushort mask = il<2 ? 0x0F : 0xF0;
5664
- const float qh_val = il<2 ? 16.f : 256.f;
5665
- for (int i = 0; i < 16; ++i) {
5666
- reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
5667
- }
5668
- }
5669
-
5670
- template <typename type4x4>
5671
- void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
5672
- const half d_all = xb->d;
5673
- device const uint8_t * ql = (device const uint8_t *)xb->ql;
5674
- device const uint8_t * qh = (device const uint8_t *)xb->qh;
5675
- device const int8_t * scales = (device const int8_t *)xb->scales;
5676
-
5677
- ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
5678
- qh = qh + 32*(il/8) + 16*(il&1);
5679
- float sc = scales[(il%2) + 2 * ((il/2))];
5680
- il = (il/2) & 3;
5681
-
5682
- const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5683
- const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
5684
- const float coef = il>1 ? 1.f/16.f : 1.f;
5685
- const float ml = d_all * sc * 32.f;
5686
- const float dl = d_all * sc * coef;
5687
- for (int i = 0; i < 16; ++i) {
5688
- const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
5689
- : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
5690
- reg[i/4][i%4] = dl * q - ml;
5691
- }
5692
- }
5693
-
5694
- template <typename type4x4>
5695
- void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
5696
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5697
- const float d = xb->d;
5698
- const int ib32 = il/2;
5699
- il = il%2;
5700
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
5701
- // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
5702
- device const uint16_t * q2 = xb->qs + 4*ib32;
5703
- const uint32_t aux32_g = q2[0] | (q2[1] << 16);
5704
- const uint32_t aux32_s = q2[2] | (q2[3] << 16);
5705
- thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
5706
- const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
5707
- constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
5708
- uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
5709
- for (int i = 0; i < 8; ++i) {
5710
- reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
5711
- }
5712
- grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
5713
- signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
5714
- for (int i = 0; i < 8; ++i) {
5715
- reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
5716
- }
5717
- }
5718
-
5719
- template <typename type4x4>
5720
- void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
5721
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5722
- const float d = xb->d;
5723
- const int ib32 = il/2;
5724
- il = il%2;
5725
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
5726
- device const uint16_t * q2 = xb->qs + 4*ib32;
5727
- const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
5728
- constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
5729
- uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
5730
- for (int i = 0; i < 8; ++i) {
5731
- reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
5732
- }
5733
- grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
5734
- signs = ksigns_iq2xs[q2[2*il+1] >> 9];
5735
- for (int i = 0; i < 8; ++i) {
5736
- reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
5737
- }
5738
- }
5739
-
5740
- template <typename type4x4>
5741
- void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
5742
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5743
- const float d = xb->d;
5744
- const int ib32 = il/2;
5745
- il = il%2;
5746
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
5747
- device const uint8_t * q3 = xb->qs + 8*ib32;
5748
- device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
5749
- const uint32_t aux32 = gas[0] | (gas[1] << 16);
5750
- const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
5751
- constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
5752
- constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
5753
- uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
5754
- for (int i = 0; i < 4; ++i) {
5755
- reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
5756
- reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
5757
- }
5758
- grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
5759
- grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
5760
- signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
5761
- for (int i = 0; i < 4; ++i) {
5762
- reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
5763
- reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
5764
- }
5765
- }
5766
-
5767
- template <typename type4x4>
5768
- void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
5769
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5770
- const float d = xb->d;
5771
- const int ib32 = il/2;
5772
- il = il%2;
5773
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
5774
- device const uint8_t * qs = xb->qs + 8*ib32;
5775
- device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
5776
- const uint8_t qh = xb->qh[ib32] >> 4*il;
5777
- const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
5778
- constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
5779
- constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
5780
- for (int i = 0; i < 4; ++i) {
5781
- reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
5782
- reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
5783
- }
5784
- grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
5785
- grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
5786
- for (int i = 0; i < 4; ++i) {
5787
- reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
5788
- reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
5789
- }
5790
- }
5791
-
5792
- template <typename type4x4>
5793
- void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
5794
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5795
- const float d = xb->d;
5796
- const int ib32 = il/2;
5797
- il = il%2;
5798
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
5799
- device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5800
- device const uint8_t * signs = qs + QK_K/8;
5801
- const uint8_t qh = xb->qh[ib32] >> 4*il;
5802
- const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
5803
- constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
5804
- constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
5805
- for (int i = 0; i < 8; ++i) {
5806
- reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
5807
- reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
5808
- }
5809
- }
5810
-
5811
- template <typename type4x4>
5812
- void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
5813
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5814
- const int ib32 = il/2;
5815
- il = il%2;
5816
- const float d = xb->d;
5817
- device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5818
- device const uint16_t * qh = xb->qh;
5819
- const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
5820
- const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
5821
- const uint16_t h = qh[ib32] >> 6*il;
5822
- constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
5823
- constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
5824
- for (int i = 0; i < 4; ++i) {
5825
- reg[0][i] = dl * (grid1[i] & 0xf) + ml;
5826
- reg[1][i] = dl * (grid1[i] >> 4) + ml;
5827
- reg[2][i] = dl * (grid2[i] & 0xf) + ml;
5828
- reg[3][i] = dl * (grid2[i] >> 4) + ml;
5829
- }
5830
- }
5831
-
5832
- template <typename type4x4>
5833
- void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
5834
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5835
- const int ib32 = il/2;
5836
- il = il%2;
5837
- device const uint16_t * sc = (device const uint16_t *)xb->scales;
5838
-
5839
- iq1m_scale_t scale;
5840
- scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5841
- const float d = scale.f16;
5842
-
5843
- device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5844
- device const uint8_t * qh = xb->qh + 2*ib32 + il;
5845
-
5846
- const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
5847
- const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5848
- const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5849
- constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5850
- constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
5851
- for (int i = 0; i < 4; ++i) {
5852
- reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
5853
- reg[1][i] = dl * (grid1[i] >> 4) + ml1;
5854
- reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
5855
- reg[3][i] = dl * (grid2[i] >> 4) + ml2;
5856
- }
5857
- }
5858
-
5859
- template <typename type4x4>
5860
- void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
5861
- device const uint16_t * q4 = (device const uint16_t *)xb->qs;
5862
- const float d = xb->d;
5863
- uint32_t aux32;
5864
- thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
5865
- for (int i = 0; i < 4; ++i) {
5866
- aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
5867
- reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
5868
- reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
5869
- reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
5870
- reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
5871
- }
5872
- }
5873
-
5874
- template <typename type4x4>
5875
- void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
5876
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5877
- const int ib32 = il/2;
5878
- il = il%2;
5879
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
5880
- device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
5881
- const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
5882
- const float d = (float)xb->d * (ls - 32);
5883
- uint32_t aux32;
5884
- thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
5885
- for (int i = 0; i < 4; ++i) {
5886
- aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
5887
- reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
5888
- reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
5889
- reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
5890
- reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
5891
- }
5892
- }
5893
-
5894
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
5895
  kernel void kernel_get_rows_q(
5896
  device const void * src0,
 
12
 
13
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
14
 
15
+ constexpr constant static float kvalues_iq4nl_f[16] = {
16
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
 
17
  };
18
 
19
+ // NOTE: this is not dequantizing - we are simply fitting the template
20
+ template <typename type4x4>
21
+ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
22
+ float4x4 temp = *(((device float4x4 *)src));
23
+ for (int i = 0; i < 16; i++){
24
+ reg[i/4][i%4] = temp[i/4][i%4];
25
+ }
26
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ template <typename type4x4>
29
+ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
30
+ half4x4 temp = *(((device half4x4 *)src));
31
+ for (int i = 0; i < 16; i++){
32
+ reg[i/4][i%4] = temp[i/4][i%4];
33
  }
34
  }
35
 
36
+ template <typename type4x4>
37
+ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
38
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
39
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
40
+ const float d2 = d1 / 256.f;
41
+ const float md = -8.h * xb->d;
42
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
43
+ const ushort mask1 = mask0 << 8;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ for (int i=0;i<8;i++) {
46
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
47
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
48
+ }
49
+ }
50
 
51
+ template <typename type4x4>
52
+ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
53
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
54
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
55
+ const float d2 = d1 / 256.f;
56
+ const float m = xb->m;
57
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
58
+ const ushort mask1 = mask0 << 8;
59
 
60
+ for (int i=0;i<8;i++) {
61
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
62
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
63
  }
64
  }
65
 
66
+ template <typename type4x4>
67
+ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
68
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
69
+ const float d = xb->d;
70
+ const float md = -16.h * xb->d;
71
+ const ushort mask = il ? 0x00F0 : 0x000F;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
 
 
74
 
75
+ const int x_mv = il ? 4 : 0;
 
 
76
 
77
+ const int gh_mv = il ? 12 : 0;
78
+ const int gh_bk = il ? 0 : 4;
79
+
80
+ for (int i = 0; i < 8; i++) {
81
+ // extract the 5-th bits for x0 and x1
82
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
83
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
84
+
85
+ // combine the 4-bits from qs with the 5th bit
86
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
87
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
88
+
89
+ reg[i/2][2*(i%2)+0] = d * x0 + md;
90
+ reg[i/2][2*(i%2)+1] = d * x1 + md;
91
  }
92
  }
93
 
94
+ template <typename type4x4>
95
+ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
96
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
97
+ const float d = xb->d;
98
+ const float m = xb->m;
99
+ const ushort mask = il ? 0x00F0 : 0x000F;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
 
 
102
 
103
+ const int x_mv = il ? 4 : 0;
 
 
104
 
105
+ const int gh_mv = il ? 12 : 0;
106
+ const int gh_bk = il ? 0 : 4;
 
 
 
107
 
108
+ for (int i = 0; i < 8; i++) {
109
+ // extract the 5-th bits for x0 and x1
110
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
111
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
+ // combine the 4-bits from qs with the 5th bit
114
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
115
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
116
 
117
+ reg[i/2][2*(i%2)+0] = d * x0 + m;
118
+ reg[i/2][2*(i%2)+1] = d * x1 + m;
 
119
  }
120
  }
121
 
122
+ template <typename type4x4>
123
+ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
124
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
125
+ const half d = xb->d;
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ for (int i = 0; i < 16; i++) {
128
+ reg[i/4][i%4] = (qs[i + 16*il] * d);
129
+ }
 
 
 
 
130
  }
131
 
132
+ template <typename type4x4>
133
+ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
134
+ const float d = xb->d;
135
+ const float min = xb->dmin;
136
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
137
+ float dl, ml;
138
+ uint8_t sc = xb->scales[il];
 
139
 
140
+ q = q + 32*(il/8) + 16*(il&1);
141
+ il = (il/2)%4;
 
 
 
 
 
 
142
 
143
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
144
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
145
+ dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
146
+ for (int i = 0; i < 16; ++i) {
147
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
148
+ }
149
  }
150
 
151
+ template <typename type4x4>
152
+ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
153
+ const half d_all = xb->d;
154
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
155
+ device const uint8_t * h = (device const uint8_t *)xb->hmask;
156
+ device const int8_t * scales = (device const int8_t *)xb->scales;
 
157
 
158
+ q = q + 32 * (il/8) + 16 * (il&1);
159
+ h = h + 16 * (il&1);
160
+ uint8_t m = 1 << (il/2);
161
+ uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
162
+ ((il/4)>0 ? 12 : 3);
163
+ uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
164
+ uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
165
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
166
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
167
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
168
+ const float ml = 4.f * dl;
169
 
170
+ il = (il/2) & 3;
171
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
172
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
173
+ dl *= coef;
 
 
174
 
175
+ for (int i = 0; i < 16; ++i) {
176
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
177
+ }
 
 
178
  }
179
 
180
+ static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
181
+ return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
182
+ : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
 
 
 
183
  }
184
 
185
+ template <typename type4x4>
186
+ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
187
+ device const uchar * q = xb->qs;
188
 
189
+ short is = (il/4) * 2;
190
+ q = q + (il/4) * 32 + 16 * (il&1);
191
+ il = il & 3;
192
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
193
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
194
+ const float min = xb->dmin;
195
+ const float dl = d * sc[0];
196
+ const float ml = min * sc[1];
197
 
198
+ const ushort mask = il<2 ? 0x0F : 0xF0;
199
+ for (int i = 0; i < 16; ++i) {
200
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
201
+ }
202
  }
203
 
204
+ template <typename type4x4>
205
+ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
206
+ device const uint8_t * q = xb->qs;
207
+ device const uint8_t * qh = xb->qh;
 
 
 
 
 
 
 
 
208
 
209
+ short is = (il/4) * 2;
210
+ q = q + 32 * (il/4) + 16 * (il&1);
211
+ qh = qh + 16 * (il&1);
212
+ uint8_t ul = 1 << (il/2);
213
+ il = il & 3;
214
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
215
+ const float d = il < 2 ? xb->d : xb->d / 16.f;
216
+ const float min = xb->dmin;
217
+ const float dl = d * sc[0];
218
+ const float ml = min * sc[1];
219
 
220
+ const ushort mask = il<2 ? 0x0F : 0xF0;
221
+ const float qh_val = il<2 ? 16.f : 256.f;
222
+ for (int i = 0; i < 16; ++i) {
223
+ reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
224
+ }
225
  }
226
 
227
+ template <typename type4x4>
228
+ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
229
+ const half d_all = xb->d;
230
+ device const uint8_t * ql = (device const uint8_t *)xb->ql;
231
+ device const uint8_t * qh = (device const uint8_t *)xb->qh;
232
+ device const int8_t * scales = (device const int8_t *)xb->scales;
 
 
233
 
234
+ ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
235
+ qh = qh + 32*(il/8) + 16*(il&1);
236
+ float sc = scales[(il%2) + 2 * ((il/2))];
237
+ il = (il/2) & 3;
 
 
 
238
 
239
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
240
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
241
+ const float coef = il>1 ? 1.f/16.f : 1.f;
242
+ const float ml = d_all * sc * 32.f;
243
+ const float dl = d_all * sc * coef;
244
+ for (int i = 0; i < 16; ++i) {
245
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
246
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
247
+ reg[i/4][i%4] = dl * q - ml;
248
+ }
249
  }
250
 
251
+ template <typename type4x4>
252
+ void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
253
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
254
+ const float d = xb->d;
255
+ const int ib32 = il/2;
256
+ il = il%2;
257
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
258
+ // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
259
+ device const uint16_t * q2 = xb->qs + 4*ib32;
260
+ const uint32_t aux32_g = q2[0] | (q2[1] << 16);
261
+ const uint32_t aux32_s = q2[2] | (q2[3] << 16);
262
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
263
+ const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
264
+ constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
265
+ uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
266
+ for (int i = 0; i < 8; ++i) {
267
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
268
+ }
269
+ grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
270
+ signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
271
+ for (int i = 0; i < 8; ++i) {
272
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
273
+ }
274
  }
275
 
276
+ template <typename type4x4>
277
+ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
278
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
279
+ const float d = xb->d;
280
+ const int ib32 = il/2;
281
+ il = il%2;
282
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
283
+ device const uint16_t * q2 = xb->qs + 4*ib32;
284
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
285
+ constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
286
+ uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
287
+ for (int i = 0; i < 8; ++i) {
288
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
289
+ }
290
+ grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
291
+ signs = ksigns_iq2xs[q2[2*il+1] >> 9];
292
+ for (int i = 0; i < 8; ++i) {
293
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
294
+ }
295
  }
296
 
297
+ template <typename type4x4>
298
+ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
299
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
300
+ const float d = xb->d;
301
+ const int ib32 = il/2;
302
+ il = il%2;
303
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
304
+ device const uint8_t * q3 = xb->qs + 8*ib32;
305
+ device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
306
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
307
+ const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
308
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
309
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
310
+ uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
311
+ for (int i = 0; i < 4; ++i) {
312
+ reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
313
+ reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
314
+ }
315
+ grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
316
+ grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
317
+ signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
318
+ for (int i = 0; i < 4; ++i) {
319
+ reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
320
+ reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
321
+ }
322
  }
323
 
324
+ template <typename type4x4>
325
+ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
326
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
327
+ const float d = xb->d;
328
+ const int ib32 = il/2;
329
+ il = il%2;
330
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
331
+ device const uint8_t * qs = xb->qs + 8*ib32;
332
+ device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
333
+ const uint8_t qh = xb->qh[ib32] >> 4*il;
334
+ const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
335
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
336
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
337
+ for (int i = 0; i < 4; ++i) {
338
+ reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
339
+ reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
340
+ }
341
+ grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
342
+ grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
343
+ for (int i = 0; i < 4; ++i) {
344
+ reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
345
+ reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
346
+ }
347
  }
348
 
349
+ template <typename type4x4>
350
+ void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
351
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
352
+ const float d = xb->d;
353
+ const int ib32 = il/2;
354
+ il = il%2;
355
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
356
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
357
+ device const uint8_t * signs = qs + QK_K/8;
358
+ const uint8_t qh = xb->qh[ib32] >> 4*il;
359
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
360
+ constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
361
+ constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
362
+ for (int i = 0; i < 8; ++i) {
363
+ reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
364
+ reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
365
+ }
366
+ }
367
+
368
+ template <typename type4x4>
369
+ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
370
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
371
+ const int ib32 = il/2;
372
+ il = il%2;
373
+ const float d = xb->d;
374
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
375
+ device const uint16_t * qh = xb->qh;
376
+ const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
377
+ const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
378
+ const uint16_t h = qh[ib32] >> 6*il;
379
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
380
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
381
+ for (int i = 0; i < 4; ++i) {
382
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml;
383
+ reg[1][i] = dl * (grid1[i] >> 4) + ml;
384
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml;
385
+ reg[3][i] = dl * (grid2[i] >> 4) + ml;
386
+ }
387
+ }
388
+
389
+ template <typename type4x4>
390
+ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
391
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
392
+ const int ib32 = il/2;
393
+ il = il%2;
394
+ device const uint16_t * sc = (device const uint16_t *)xb->scales;
395
+
396
+ iq1m_scale_t scale;
397
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
398
+ const float d = scale.f16;
399
+
400
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
401
+ device const uint8_t * qh = xb->qh + 2*ib32 + il;
402
+
403
+ const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
404
+ const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
405
+ const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
406
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
407
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
408
+ for (int i = 0; i < 4; ++i) {
409
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
410
+ reg[1][i] = dl * (grid1[i] >> 4) + ml1;
411
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
412
+ reg[3][i] = dl * (grid2[i] >> 4) + ml2;
413
+ }
414
+ }
415
+
416
+ template <typename type4x4>
417
+ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
418
+ device const uint16_t * q4 = (device const uint16_t *)xb->qs;
419
+ const float d = xb->d;
420
+ uint32_t aux32;
421
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
422
+ for (int i = 0; i < 4; ++i) {
423
+ aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
424
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
425
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
426
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
427
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
428
+ }
429
+ }
430
+
431
+ template <typename type4x4>
432
+ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
433
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
434
+ const int ib32 = il/2;
435
+ il = il%2;
436
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
437
+ device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
438
+ const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
439
+ const float d = (float)xb->d * (ls - 32);
440
+ uint32_t aux32;
441
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
442
+ for (int i = 0; i < 4; ++i) {
443
+ aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
444
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
445
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
446
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
447
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
448
+ }
449
+ }
450
+
451
+ enum ggml_sort_order {
452
+ GGML_SORT_ORDER_ASC,
453
+ GGML_SORT_ORDER_DESC,
454
+ };
455
+
456
+ // general-purpose kernel for addition, subtraction, multiplication and division of two tensors
457
+ // pros: works for non-contiguous tensors, supports broadcast across all dims
458
+ // cons: not very efficient
459
+ kernel void kernel_add(
460
+ device const char * src0,
461
+ device const char * src1,
462
+ device char * dst,
463
  constant int64_t & ne00,
464
  constant int64_t & ne01,
465
  constant int64_t & ne02,
 
484
  constant uint64_t & nb1,
485
  constant uint64_t & nb2,
486
  constant uint64_t & nb3,
487
+ constant int64_t & offs,
488
+ uint3 tgpig[[threadgroup_position_in_grid]],
489
+ uint3 tpitg[[thread_position_in_threadgroup]],
490
+ uint3 ntg[[threads_per_threadgroup]]) {
491
+ const int64_t i03 = tgpig.z;
492
+ const int64_t i02 = tgpig.y;
493
+ const int64_t i01 = tgpig.x;
494
 
495
+ const int64_t i13 = i03 % ne13;
496
+ const int64_t i12 = i02 % ne12;
497
+ const int64_t i11 = i01 % ne11;
 
 
 
498
 
499
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
500
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
501
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
502
 
503
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
504
+ const int i10 = i0 % ne10;
505
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
506
  }
 
 
507
  }
508
 
509
// general-purpose kernel for elementwise subtraction of two tensors
// src1 is broadcast into src0 across all dims (src1 indices wrapped with %)
// works for non-contiguous tensors; one threadgroup per dst row, threads stride over ne0
kernel void kernel_sub(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant  int64_t & ne03,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb03,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant uint64_t & nb13,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & ne2,
        constant  int64_t & ne3,
        constant uint64_t & nb0,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        constant uint64_t & nb3,
        constant  int64_t & offs,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    // one threadgroup handles one (i01, i02, i03) row of dst
    const int64_t i03 = tgpig.z;
    const int64_t i02 = tgpig.y;
    const int64_t i01 = tgpig.x;

    // broadcast: wrap src1 indices into its (possibly smaller) dims
    const int64_t i13 = i03 % ne13;
    const int64_t i12 = i02 % ne12;
    const int64_t i11 = i01 % ne11;

    // offs shifts into src0/dst (byte offset applied to both)
    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;

    // threads in the group stride across the row; src1 wraps along dim 0 too
    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
        const int i10 = i0 % ne10;
        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
    }
}
559
// general-purpose kernel for elementwise multiplication of two tensors
// src1 is broadcast into src0 across all dims (src1 indices wrapped with %)
// note: unlike kernel_add/kernel_sub, no byte-offset `offs` argument
kernel void kernel_mul(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant  int64_t & ne03,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb03,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant uint64_t & nb13,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & ne2,
        constant  int64_t & ne3,
        constant uint64_t & nb0,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        constant uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    // one threadgroup per (i01, i02, i03) row of dst
    const int64_t i03 = tgpig.z;
    const int64_t i02 = tgpig.y;
    const int64_t i01 = tgpig.x;

    // broadcast: wrap src1 indices into its dims
    const int64_t i13 = i03 % ne13;
    const int64_t i12 = i02 % ne12;
    const int64_t i11 = i01 % ne11;

    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;

    // threads stride across the row; src1 wraps along dim 0 as well
    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
        const int i10 = i0 % ne10;
        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
    }
}
 
608
// general-purpose kernel for elementwise division of two tensors
// src1 is broadcast into src0 across all dims (src1 indices wrapped with %)
// no division-by-zero guard: result follows IEEE float semantics (inf/NaN)
kernel void kernel_div(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant  int64_t & ne03,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb03,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant uint64_t & nb13,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & ne2,
        constant  int64_t & ne3,
        constant uint64_t & nb0,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        constant uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    // one threadgroup per (i01, i02, i03) row of dst
    const int64_t i03 = tgpig.z;
    const int64_t i02 = tgpig.y;
    const int64_t i01 = tgpig.x;

    // broadcast: wrap src1 indices into its dims
    const int64_t i13 = i03 % ne13;
    const int64_t i12 = i02 % ne12;
    const int64_t i11 = i01 % ne11;

    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;

    // threads stride across the row; src1 wraps along dim 0 as well
    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
        const int i10 = i0 % ne10;
        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
    }
}
 
657
// repeat (tile) src0 into the larger dst: each dst index is mapped back into
// src0 by wrapping with % along every dimension
template<typename T>
kernel void kernel_repeat(
        device const char * src0,
        device       char * dst,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant  int64_t & ne03,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb03,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & ne2,
        constant  int64_t & ne3,
        constant uint64_t & nb0,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        constant uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    // one threadgroup per (i1, i2, i3) row of dst
    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    // wrap dst indices back into src0's (smaller) dims
    const int64_t i03 = i3 % ne03;
    const int64_t i02 = i2 % ne02;
    const int64_t i01 = i1 % ne01;

    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1 ;

    // threads stride across the row; dim 0 wraps as well
    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
        const int i00 = i0 % ne00;
        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
    }
}

typedef decltype(kernel_repeat<float>) kernel_repeat_t;

// host-visible instantiations for each supported element type
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
703
 
704
// assumption: src1 is a row
// broadcast src1 into src0 (vectorized over float4; nb is the row length in float4 units)
kernel void kernel_add_row(
        device const float4 * src0,
        device const float4 * src1,
        device       float4 * dst,
        constant   uint64_t & nb [[buffer(28)]],
        uint tpig[[thread_position_in_grid]]) {
    const uint i1 = tpig % nb; // wrap into the broadcast row
    dst[tpig] = src0[tpig] + src1[i1];
}
714
 
715
// broadcast-subtract a single src1 row from src0 (float4 lanes; nb = row length in float4 units)
kernel void kernel_sub_row(
        device const float4 * src0,
        device const float4 * src1,
        device       float4 * dst,
        constant   uint64_t & nb [[buffer(28)]],
        uint tpig[[thread_position_in_grid]]) {
    const uint i1 = tpig % nb; // wrap into the broadcast row
    dst[tpig] = src0[tpig] - src1[i1];
}
723
+
724
// broadcast-multiply src0 by a single src1 row (float4 lanes; nb = row length in float4 units)
kernel void kernel_mul_row(
        device const float4 * src0,
        device const float4 * src1,
        device       float4 * dst,
        constant   uint64_t & nb [[buffer(28)]],
        uint tpig[[thread_position_in_grid]]) {
    const uint i1 = tpig % nb; // wrap into the broadcast row
    dst[tpig] = src0[tpig] * src1[i1];
}
732
+
733
// broadcast-divide src0 by a single src1 row (float4 lanes; nb = row length in float4 units)
kernel void kernel_div_row(
        device const float4 * src0,
        device const float4 * src1,
        device       float4 * dst,
        constant   uint64_t & nb [[buffer(28)]],
        uint tpig[[thread_position_in_grid]]) {
    const uint i1 = tpig % nb; // wrap into the broadcast row
    dst[tpig] = src0[tpig] / src1[i1];
}
741
+
742
// multiply every element by a constant scale factor
kernel void kernel_scale(
        device const float * src0,
        device       float * dst,
        constant     float & scale,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = x * scale;
}
749
+
750
// vectorized variant of kernel_scale: four floats per thread
kernel void kernel_scale_4(
        device const float4 * src0,
        device       float4 * dst,
        constant      float & scale,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];
    dst[tpig] = x * scale;
}
757
+
758
// clamp every element into [min, max]
// NOTE: parameter names shadow the builtin min/max functions, so the
// comparison is spelled out with a ternary (NaN input passes through)
kernel void kernel_clamp(
        device const float * src0,
        device       float * dst,
        constant     float & min,
        constant     float & max,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = x < min ? min : (x > max ? max : x);
}
766
+
767
// ReLU: y = max(0, x)
kernel void kernel_relu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = max(0.0f, x);
}
773
+
774
// logistic sigmoid: y = 1 / (1 + e^-x)
kernel void kernel_sigmoid(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = 1.0f / (1.0f + exp(-x));
}
780
+
781
// hyperbolic tangent; uses precise::tanh (not the fast-math variant)
kernel void kernel_tanh(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = precise::tanh(src0[tpig]);
}
788
+
789
// constants for the GELU activation approximations below
constant float GELU_COEF_A     = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f; // sqrt(2/pi)
792
+
793
// GELU (tanh approximation): y = 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2)))
// precise::tanh is required here — see the NaN note on kernel_gelu_4
kernel void kernel_gelu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    const float t = precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x));
    dst[tpig] = 0.5f*x*(1.0f + t);
}
801
+
802
// vectorized GELU (tanh approximation), four floats per thread
kernel void kernel_gelu_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];

    // BEWARE !!!
    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
    // This was observed with Falcon 7B and 40B models
    //
    const float4 t = precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x));
    dst[tpig] = 0.5f*x*(1.0f + t);
}
814
+
815
// quick GELU: y = x * sigmoid(1.702 * x)
kernel void kernel_gelu_quick(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
823
+
824
// vectorized quick GELU, four floats per thread
kernel void kernel_gelu_quick_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];
    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
832
+
833
// SiLU (a.k.a. swish): y = x * sigmoid(x) = x / (1 + e^-x)
kernel void kernel_silu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = x / (1.0f + exp(-x));
}
840
+
841
// vectorized SiLU, four floats per thread
kernel void kernel_silu_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];
    dst[tpig] = x / (1.0f + exp(-x));
}
848
+
849
// elementwise square: y = x^2
kernel void kernel_sqr(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = x * x;
}
855
+
856
// elementwise square root
kernel void kernel_sqrt(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = sqrt(x);
}
862
+
863
// elementwise sine
kernel void kernel_sin(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = sin(x);
}
869
+
870
// elementwise cosine
kernel void kernel_cos(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = cos(x);
}
876
+
877
// reduce each row of src0 along dim 0: dst[i1,i2,i3] = sum over i0 of src0[i0,i1,i2,i3]
// one thread per output element; the inner sum over ne00 is sequential
// note: the ne1x/nb1x and most ne/nb dst args are not used by this kernel —
// they appear to follow the common op argument layout (NOTE(review): confirm)
kernel void kernel_sum_rows(
        device const float * src0,
        device       float * dst,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant  int64_t & ne03,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb03,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant uint64_t & nb13,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & ne2,
        constant  int64_t & ne3,
        constant uint64_t & nb0,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        constant uint64_t & nb3,
        uint3 tpig[[thread_position_in_grid]]) {
    int64_t i3 = tpig.z;
    int64_t i2 = tpig.y;
    int64_t i1 = tpig.x;

    // guard against threads dispatched past the tensor bounds
    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
        return;
    }

    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);

    float row_sum = 0;

    // sequential accumulation over the row (assumes contiguous dim 0 for src_row indexing)
    for (int64_t i0 = 0; i0 < ne00; i0++) {
        row_sum += src_row[i0];
    }

    dst_row[0] = row_sum;
}
924
 
925
  template<typename T>
926
+ kernel void kernel_soft_max(
927
  device const char * src0,
928
  device const char * src1,
929
  device char * dst,
 
945
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
946
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
947
 
948
+ device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
949
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
950
+ device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
951
 
952
  float slope = 1.0f;
953
 
954
+ // ALiBi
955
  if (max_bias > 0.0f) {
956
  const int64_t h = i02;
957
 
 
962
  }
963
 
964
  // parallel max
965
+ float lmax = -INFINITY;
966
 
967
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
968
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
969
  }
970
 
971
+ // find the max value in the block
 
972
  float max_val = simd_max(lmax);
973
  if (ntg > N_SIMDWIDTH) {
974
  if (sgitg == 0) {
 
988
  }
989
 
990
  // parallel sum
991
+ float lsum = 0.0f;
992
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
993
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
994
+ lsum += exp_psrc0;
995
+ pdst[i00] = exp_psrc0;
996
  }
997
 
998
+ // This barrier fixes a failing test
999
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
1000
+ threadgroup_barrier(mem_flags::mem_none);
1001
+
1002
+ float sum = simd_sum(lsum);
1003
+
1004
+ if (ntg > N_SIMDWIDTH) {
1005
+ if (sgitg == 0) {
1006
+ buf[tiisg] = 0.0f;
1007
+ }
1008
+
1009
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1010
+
1011
+ if (tiisg == 0) {
1012
+ buf[sgitg] = sum;
1013
+ }
1014
+
1015
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1016
+
1017
+ sum = buf[tiisg];
1018
+ sum = simd_sum(sum);
1019
+ }
1020
+
1021
+ const float inv_sum = 1.0f/sum;
1022
+
1023
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1024
+ pdst[i00] *= inv_sum;
1025
+ }
1026
+ }
1027
+
1028
+ template<typename T>
1029
+ kernel void kernel_soft_max_4(
1030
+ device const char * src0,
1031
+ device const char * src1,
1032
+ device char * dst,
1033
+ constant int64_t & ne00,
1034
+ constant int64_t & ne01,
1035
+ constant int64_t & ne02,
1036
+ constant float & scale,
1037
+ constant float & max_bias,
1038
+ constant float & m0,
1039
+ constant float & m1,
1040
+ constant uint32_t & n_head_log2,
1041
+ threadgroup float * buf [[threadgroup(0)]],
1042
+ uint tgpig[[threadgroup_position_in_grid]],
1043
+ uint tpitg[[thread_position_in_threadgroup]],
1044
+ uint sgitg[[simdgroup_index_in_threadgroup]],
1045
+ uint tiisg[[thread_index_in_simdgroup]],
1046
+ uint ntg[[threads_per_threadgroup]]) {
1047
+ const int64_t i03 = (tgpig) / (ne02*ne01);
1048
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
1049
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
1050
+
1051
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
1052
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
1053
+ device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
1054
+
1055
+ float slope = 1.0f;
1056
+
1057
+ if (max_bias > 0.0f) {
1058
+ const int64_t h = i02;
1059
+
1060
+ const float base = h < n_head_log2 ? m0 : m1;
1061
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
1062
+
1063
+ slope = pow(base, exp);
1064
+ }
1065
+
1066
+ // parallel max
1067
+ float4 lmax4 = -INFINITY;
1068
+
1069
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1070
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1071
+ }
1072
+
1073
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
1074
+
1075
+ float max_val = simd_max(lmax);
1076
+ if (ntg > N_SIMDWIDTH) {
1077
+ if (sgitg == 0) {
1078
+ buf[tiisg] = -INFINITY;
1079
+ }
1080
+
1081
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1082
+
1083
+ if (tiisg == 0) {
1084
+ buf[sgitg] = max_val;
1085
+ }
1086
+
1087
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1088
+
1089
+ max_val = buf[tiisg];
1090
+ max_val = simd_max(max_val);
1091
+ }
1092
+
1093
+ // parallel sum
1094
+ float4 lsum4 = 0.0f;
1095
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1096
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1097
+ lsum4 += exp_psrc4;
1098
+ pdst4[i00] = exp_psrc4;
1099
+ }
1100
+
1101
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
1102
 
1103
  // This barrier fixes a failing test
1104
  // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
 
3775
  return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
3776
  }
3777
 
 
 
 
 
3778
  kernel void kernel_cpy_f32_iq4_nl(
3779
  device const float * src0,
3780
  device void * dst,
 
5889
  kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5890
  }
5891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5892
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
5893
  kernel void kernel_get_rows_q(
5894
  device const void * src0,