Eve committed on
Commit 03ab36f · 1 parent: ad8f031

vulkan: scale caching for k quants + misc fixes (llama/11081)


* q6_k scale caching

* 16 bit unpack

* q4_k test (slow)

* revert it

* q3_k

* q2_k

* little stuff

* try precalculating products of a and q2_k scales

* Revert "try precalculating products of a and q2_k scales"

This reverts commit 65110b81f23f66331a50c6e889a7c1ab9470a86b.

* unpack should be u16, add vim swap to gitignore (about time)

* better q4_k scales

* q5_k

* better q6_k with separate paths for all threads and partial threads in use, plus some more optimizations

* q2_k better dequant

* q3_k optimizations

* q3_k use hmask simd from cpu avx version

* make the caches happy

* q3_k separate out calculation

* q2_k separate out

* little stuff

* use calc_superblock everywhere

* q2_k optimize scale calculation

* more barriers
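The common thread in the shader changes below is a small workgroup-shared cache for the superblock scales: the 16 threads that cooperate on a block each stage one scale byte behind a pair of barriers, instead of every thread re-reading the packed scale words. A minimal sketch of the pattern, assuming the FLOAT_TYPE/BLOCK_SIZE macros and the data_a binding already provided by the mul_mat_vec shader headers (cache_scales is a hypothetical helper used only for illustration, not a function in these shaders):

    shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];

    void cache_scales(const uint ix, const uint itid, const uint ib) {
        barrier();                                               // previous iteration must be done reading the cache
        sccache[ix][itid] = FLOAT_TYPE(data_a[ib].scales[itid]); // one scale byte per thread
        barrier();                                               // all 16 scales visible before the dot products start
    }

The q2_k/q3_k/q6_k shaders below inline a variant of this directly in calc_superblock, with an extra all_threads/partial-threads split so the barriers stay uniform across the workgroup.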

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp CHANGED
@@ -5,6 +5,80 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -14,88 +88,28 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
14
  // 16 threads are used to process each block
15
  const uint it_size = gl_WorkGroupSize.x/16;
16
  const uint tid = gl_LocalInvocationID.x;
17
- const uint itid = tid%16; // 0...16
18
- const uint ix = tid/16;
19
-
20
- const uint step = 8;
21
 
22
- const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
23
- const uint v_in = itid - step*v_im; // 0...15 or 0...7
24
 
25
  const uint l0 = 2*v_in; // 0...15
26
  const uint q_offset = 32*v_im + l0;
27
- const uint s_offset = 8*v_im;
28
  const uint y_offset = 128*v_im + l0;
29
 
30
- FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
31
-
32
  [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
33
  [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
34
  temp[j][i] = FLOAT_TYPE(0);
35
  }
36
  }
37
 
38
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
39
- const uint y_idx = i * QUANT_K + y_offset;
40
-
41
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
42
- const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
43
- vec2 d = vec2(data_a[ib0 + i].d);
44
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
45
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
46
-
47
- uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
48
- uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
49
-
50
- uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
51
- uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
52
- uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
53
- uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
54
-
55
- uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
56
- uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
57
- uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
58
- uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
59
-
60
- uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
61
- uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
62
- uvec2 qs0 = uvec2(unpack8(qs0_u16));
63
- uvec2 qs16 = uvec2(unpack8(qs16_u16));
64
-
65
- [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
66
- vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
67
- vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
68
- vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
69
- vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
70
- vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
71
- vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
72
- vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
73
- vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
74
-
75
- FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
76
- FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
77
- [[unroll]] for (int l = 0; l < 2; ++l) {
78
- sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
79
- fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
80
- fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
81
- fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
82
- fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
83
- fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
84
- fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
85
- fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
86
- sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
87
- fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
88
- fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
89
- fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
90
- fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
91
- fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
92
- fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
93
- fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
94
- }
95
- temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
96
- }
97
- }
98
- }
99
 
100
  reduce_result(temp, d_offset, first_row, num_rows, tid);
101
  }
 
5
 
6
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
7
 
8
+ shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
9
+ shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
10
+
11
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
12
+
13
+ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
14
+ const uint y_idx = i * QUANT_K + y_offset;
15
+
16
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
17
+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
18
+
19
+ barrier();
20
+ if (!all_threads) { // when we don't have enough blocks to use all threads
21
+ if (i < num_blocks_per_row) {
22
+ const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
23
+ sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
24
+ sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
25
+ }
26
+ barrier();
27
+
28
+ if (i >= num_blocks_per_row)
29
+ continue;
30
+ } else {
31
+ const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
32
+ sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
33
+ sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
34
+ barrier();
35
+ }
36
+
37
+ const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
38
+ const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
39
+ const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
40
+ const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
41
+ const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
42
+
43
+ vec2 d = vec2(data_a[ib0 + i].d);
44
+ const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
45
+ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
46
+
47
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
48
+ vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
49
+ vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
50
+ vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
51
+ vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
52
+ vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
53
+ vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
54
+ vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
55
+ vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
56
+
57
+ FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
58
+ FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
59
+ [[unroll]] for (int l = 0; l < 2; ++l) {
60
+ sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ],
61
+ fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
62
+ fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ],
63
+ fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
64
+ fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ],
65
+ fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
66
+ fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ],
67
+ fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
68
+ sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im],
69
+ fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im],
70
+ fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im],
71
+ fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im],
72
+ fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im],
73
+ fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im],
74
+ fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im],
75
+ fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
76
+ }
77
+ temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
78
+ }
79
+ }
80
+ }
81
+
82
  void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
83
  uint a_offset, b_offset, d_offset;
84
  get_offsets(a_offset, b_offset, d_offset);
 
88
  // 16 threads are used to process each block
89
  const uint it_size = gl_WorkGroupSize.x/16;
90
  const uint tid = gl_LocalInvocationID.x;
91
+ const uint itid = tid%16; // 0...15
92
+ const uint ix = tid/16;
 
 
93
 
94
+ const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
95
+ const uint v_in = itid - 8*v_im; // 0...7
96
 
97
  const uint l0 = 2*v_in; // 0...15
98
  const uint q_offset = 32*v_im + l0;
 
99
  const uint y_offset = 128*v_im + l0;
100
 
 
 
101
  [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
102
  [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
103
  temp[j][i] = FLOAT_TYPE(0);
104
  }
105
  }
106
 
107
+ const uint nbr_par_th = num_blocks_per_row%it_size;
108
+ const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
109
+ uint i0 = 0;
110
+ [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
111
+ calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
112
+ calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
113
 
114
  reduce_result(temp, d_offset, first_row, num_rows, tid);
115
  }
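A note on the loop split at the end of the hunk above (the q3_k and q6_k shaders use the same shape): barrier() has to be reached by every invocation in the workgroup, so the per-row loop is split into full strides, where every thread owns a block and no bounds check is needed, plus one trailing call with all_threads = false, where out-of-range threads still reach the barriers and only skip the arithmetic. The same calls again, with comments added here for clarity:

    [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
        // full strides: every thread has a block, so scales are cached without bounds checks
        calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
    // tail: threads past num_blocks_per_row still hit the barriers, then skip the math
    calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);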
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp CHANGED
@@ -5,6 +5,74 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -14,76 +82,37 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
14
  // 16 threads are used to process each block
15
  const uint it_size = gl_WorkGroupSize.x/16;
16
  const uint tid = gl_LocalInvocationID.x;
17
- const uint itid = tid%16; // 0...16
18
- const uint ix = tid/16;
19
-
20
- const uint step = 8;
21
 
22
- const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
23
- const uint v_in = itid - step*v_im; // 0...15 or 0...7
 
24
 
25
- const uint8_t m = uint8_t(1 << (4 * v_im));
 
 
 
26
 
27
  const uint l0 = 2*v_in; // 0...15
28
  const uint q_offset = 32*v_im + l0;
29
  const uint y_offset = 128*v_im + l0;
30
 
31
- FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
32
-
33
  [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
34
  [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
35
  temp[j][i] = FLOAT_TYPE(0);
36
  }
37
  }
38
 
39
- const uint s_shift = 4 * v_im;
40
-
41
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
42
- const uint y_idx = i * QUANT_K + y_offset;
43
-
44
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
45
- const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
46
- const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
47
-
48
- uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0];
49
- uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1];
50
- uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2];
51
- uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3];
52
- uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4];
53
- uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5];
54
- u8vec2 s0 = unpack8(s0_16);
55
- u8vec2 s2 = unpack8(s2_16);
56
- u8vec2 s4 = unpack8(s4_16);
57
- u8vec2 s6 = unpack8(s6_16);
58
- u8vec2 s8 = unpack8(s8_16);
59
- u8vec2 s10 = unpack8(s10_16);
60
-
61
- [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
62
-
63
- vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
64
- vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
65
- vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
66
- vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
67
- vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
68
- vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
69
- vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
70
- vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
71
-
72
- FLOAT_TYPE sum = FLOAT_TYPE(0.0);
73
- [[unroll]] for (int l = 0; l < 2; ++l) {
74
- sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
75
- fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
76
- fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
77
- fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
78
- fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
79
- fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
80
- fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
81
- fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
82
- }
83
- temp[j][n] = fma(d, sum, temp[j][n]);
84
- }
85
- }
86
- }
87
 
88
  reduce_result(temp, d_offset, first_row, num_rows, tid);
89
  }
 
5
 
6
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
7
 
8
+ shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8];
9
+
10
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
11
+
12
+ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
13
+ const uint y_idx = i * QUANT_K + y_offset;
14
+
15
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
16
+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
17
+
18
+ if (!all_threads) { // when we don't have enough blocks to use all threads
19
+ barrier();
20
+ if (i < num_blocks_per_row)
21
+ sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
22
+ barrier();
23
+
24
+ if (i >= num_blocks_per_row)
25
+ continue;
26
+ }
27
+
28
+ const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16));
29
+ const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2));
30
+ const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2));
31
+ const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2));
32
+ const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2));
33
+
34
+ // 0, 1, 16, 17
35
+ uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
36
+ qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16;
37
+ const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
38
+ const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
39
+ const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
40
+ const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
41
+
42
+ if (all_threads) {
43
+ barrier();
44
+ sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
45
+ barrier();
46
+ }
47
+
48
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
49
+
50
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
51
+ vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
52
+ vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
53
+ vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
54
+ vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
55
+ vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
56
+ vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
57
+ vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
58
+ vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
59
+
60
+ FLOAT_TYPE sum = FLOAT_TYPE(0.0);
61
+ [[unroll]] for (int l = 0; l < 2; ++l) {
62
+ sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
63
+ fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
64
+ fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
65
+ fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
66
+ fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
67
+ fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
68
+ fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
69
+ fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
70
+ }
71
+ temp[j][n] = fma(d, sum, temp[j][n]);
72
+ }
73
+ }
74
+ }
75
+
76
  void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
77
  uint a_offset, b_offset, d_offset;
78
  get_offsets(a_offset, b_offset, d_offset);
 
82
  // 16 threads are used to process each block
83
  const uint it_size = gl_WorkGroupSize.x/16;
84
  const uint tid = gl_LocalInvocationID.x;
85
+ const uint itid = tid%16; // 0...15
86
+ const uint ix = tid/16;
87
+ const uint itid8 = itid%8;
 
88
 
89
+ const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
90
+ const uint v_im4 = v_im*4;
91
+ const uint v_in = itid - 8*v_im; // 0...7
92
 
93
+ const uint32_t m = 0x01010101 << (4 * v_im);
94
+ uint32_t hm_m[4];
95
+ [[unroll]] for (uint j = 0; j < 4; ++j)
96
+ hm_m[j] = m << j;
97
 
98
  const uint l0 = 2*v_in; // 0...15
99
  const uint q_offset = 32*v_im + l0;
100
  const uint y_offset = 128*v_im + l0;
101
 
 
 
102
  [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
103
  [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
104
  temp[j][i] = FLOAT_TYPE(0);
105
  }
106
  }
107
 
108
+ const uint s_shift = v_im4 + 2*(itid8/4);
109
+
110
+ const uint nbr_par_th = num_blocks_per_row%it_size;
111
+ const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
112
+ uint i0 = 0;
113
+ [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
114
+ calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
115
+ calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
116
 
117
  reduce_result(temp, d_offset, first_row, num_rows, tid);
118
  }
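Per the "q3_k use hmask simd from cpu avx version" commit message, the high-bit handling above no longer tests (hmask & m) per value to pick 0 or 4. Instead it packs the two 16-bit hmask halves into one word, complements it once, and turns each masked bit directly into the 0/4 offset with a shift and a << 2. Roughly, as I read the new calc_superblock (hmask_lo16/hmask_hi16 stand in for the two data_a_packed16.hmask loads):

    const uint32_t hmk = ~(uint32_t(hmask_lo16) | (uint32_t(hmask_hi16) << 16)); // a set bit now means "subtract 4"
    const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> (v_im4)) << 2));         // per-byte 0 or 4
    // dequantized value: (q & 3) - hmk_0[l]   replaces   (q & 3) - (((hmask & m) != 0) ? 0 : 4)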
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp CHANGED
@@ -6,6 +6,86 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -15,13 +95,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
15
  // 16 threads are used to process each block
16
  const uint it_size = gl_WorkGroupSize.x/16;
17
  const uint tid = gl_LocalInvocationID.x;
18
- const uint itid = tid%16; // 0...16
19
- const uint ix = tid/16;
20
 
21
- const uint step = 4;
22
-
23
- const uint il = itid/step; // 0...3
24
- const uint ir = itid - step*il; // 0...7 or 0...3
25
  const uint n = 4;
26
 
27
  const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
@@ -31,89 +109,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
31
  const uint q_offset = 32*v_im + l0;
32
  const uint y_offset = 64*v_im + l0;
33
 
34
- FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
35
-
36
  [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
37
  [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
38
  temp[j][i] = FLOAT_TYPE(0);
39
  }
40
  }
41
 
42
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
43
- const uint y1_idx = i * QUANT_K + y_offset;
44
- const uint y2_idx = y1_idx + 128;
45
-
46
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
47
- const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
48
- vec2 d = vec2(data_a[ib0 + i].d);
49
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
50
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
51
-
52
- uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
53
- uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
54
- uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
55
- uvec4 scale0 = uvec4(unpack8(scale0_u32));
56
- uvec4 scale4 = uvec4(unpack8(scale4_u32));
57
- uvec4 scale8 = uvec4(unpack8(scale8_u32));
58
-
59
- const uint32_t sc0 = ( scale0.x & 0x3f);
60
- const uint32_t sc1 = ( scale0.y & 0x3f);
61
- const uint32_t sc2 = ( scale4.x & 0x3f);
62
- const uint32_t sc3 = ( scale4.y & 0x3f);
63
- const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
64
- const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
65
- const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
66
- const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
67
-
68
- uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
69
- uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
70
-
71
- uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
72
- uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
73
- uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
74
- uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
75
-
76
- uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4));
77
- uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4));
78
- uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4));
79
- uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4));
80
-
81
- const uint32_t q4_0 = qs0_lo4.x;
82
- const uint32_t q4_1 = qs0_lo4.y;
83
- const uint32_t q4_2 = qs0_lo4.z;
84
- const uint32_t q4_3 = qs0_lo4.w;
85
- const uint32_t q4_4 = qs0_hi4.x;
86
- const uint32_t q4_5 = qs0_hi4.y;
87
- const uint32_t q4_6 = qs0_hi4.z;
88
- const uint32_t q4_7 = qs0_hi4.w;
89
- const uint32_t q4_8 = qs64_lo4.x;
90
- const uint32_t q4_9 = qs64_lo4.y;
91
- const uint32_t q4_10 = qs64_lo4.z;
92
- const uint32_t q4_11 = qs64_lo4.w;
93
- const uint32_t q4_12 = qs64_hi4.x;
94
- const uint32_t q4_13 = qs64_hi4.y;
95
- const uint32_t q4_14 = qs64_hi4.z;
96
- const uint32_t q4_15 = qs64_hi4.w;
97
-
98
- [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
99
- vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]);
100
- vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
101
- vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]);
102
- vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);
103
-
104
- const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
105
- const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
106
- const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
107
- const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
108
- const FLOAT_TYPE smin =
109
- fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
110
- fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
111
- fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
112
- fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
113
- temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
114
- }
115
- }
116
- }
117
 
118
  reduce_result(temp, d_offset, first_row, num_rows, tid);
119
  }
 
6
 
7
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
8
 
9
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
10
+
11
+ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
12
+ const uint y1_idx = i * QUANT_K + y_offset;
13
+ const uint y2_idx = y1_idx + 128;
14
+
15
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
16
+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
17
+ vec2 d = vec2(data_a[ib0 + i].d);
18
+ const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
19
+ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
20
+
21
+ const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
22
+ const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
23
+ const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
24
+
25
+ const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
26
+ const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
27
+ const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
28
+ const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
29
+
30
+ const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
31
+ const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
32
+ const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
33
+ const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
34
+ const FLOAT_TYPE sc4 = scale8_f.x;
35
+ const FLOAT_TYPE sc5 = scale8_f.y;
36
+ const FLOAT_TYPE sc6 = scale8_f.z;
37
+ const FLOAT_TYPE sc7 = scale8_f.w;
38
+
39
+ const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
40
+ const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
41
+
42
+ const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
43
+ const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
44
+ const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
45
+ const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
46
+
47
+ const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4));
48
+ const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4));
49
+ const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4));
50
+ const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4));
51
+
52
+ const FLOAT_TYPE q4_0 = qs0_lo4.x;
53
+ const FLOAT_TYPE q4_1 = qs0_lo4.y;
54
+ const FLOAT_TYPE q4_2 = qs0_lo4.z;
55
+ const FLOAT_TYPE q4_3 = qs0_lo4.w;
56
+ const FLOAT_TYPE q4_4 = qs0_hi4.x;
57
+ const FLOAT_TYPE q4_5 = qs0_hi4.y;
58
+ const FLOAT_TYPE q4_6 = qs0_hi4.z;
59
+ const FLOAT_TYPE q4_7 = qs0_hi4.w;
60
+ const FLOAT_TYPE q4_8 = qs64_lo4.x;
61
+ const FLOAT_TYPE q4_9 = qs64_lo4.y;
62
+ const FLOAT_TYPE q4_10 = qs64_lo4.z;
63
+ const FLOAT_TYPE q4_11 = qs64_lo4.w;
64
+ const FLOAT_TYPE q4_12 = qs64_hi4.x;
65
+ const FLOAT_TYPE q4_13 = qs64_hi4.y;
66
+ const FLOAT_TYPE q4_14 = qs64_hi4.z;
67
+ const FLOAT_TYPE q4_15 = qs64_hi4.w;
68
+
69
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
70
+ vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]);
71
+ vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
72
+ vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]);
73
+ vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);
74
+
75
+ const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
76
+ const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
77
+ const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
78
+ const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
79
+ const FLOAT_TYPE smin =
80
+ fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
81
+ fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
82
+ fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
83
+ fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
84
+ temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
85
+ }
86
+ }
87
+ }
88
+
89
  void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
90
  uint a_offset, b_offset, d_offset;
91
  get_offsets(a_offset, b_offset, d_offset);
 
95
  // 16 threads are used to process each block
96
  const uint it_size = gl_WorkGroupSize.x/16;
97
  const uint tid = gl_LocalInvocationID.x;
98
+ const uint itid = tid%16; // 0...15
99
+ const uint ix = tid/16;
100
 
101
+ const uint il = itid/4; // 0...3
102
+ const uint ir = itid - 4*il; // 0...3
 
 
103
  const uint n = 4;
104
 
105
  const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
 
109
  const uint q_offset = 32*v_im + l0;
110
  const uint y_offset = 64*v_im + l0;
111
 
 
 
112
  [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
113
  [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
114
  temp[j][i] = FLOAT_TYPE(0);
115
  }
116
  }
117
 
118
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)
119
+ calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);
120
 
121
  reduce_result(temp, d_offset, first_row, num_rows, tid);
122
  }
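The q4_k scale change (shared with q5_k below) repacks the 6-bit K-quant scales in two 32-bit words instead of extracting sc0..sc7 byte by byte: the low 6 bits of scales[v_im] and scales[v_im + 2] give sc0..sc3, and sc4..sc7 are rebuilt by OR-ing the nibbles of scales[v_im + 4] with the two high bits carried over from the first words. In outline, as used above (comments added here):

    const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;        // sc0..sc3 in the low 6 bits of each byte
    const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;        // their top 2 bits, moved into place for sc4..sc7
    const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));  // sc0, sc1, sc2, sc3
    const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); // sc4..sc7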
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp CHANGED
@@ -6,6 +6,118 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -15,11 +127,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
15
  // 16 threads are used to process each block
16
  const uint it_size = gl_WorkGroupSize.x/16;
17
  const uint tid = gl_LocalInvocationID.x;
18
- const uint itid = tid%16; // 0...16
19
- const uint ix = tid/16;
20
 
21
  const uint il = itid/4; // 0...3
22
- const uint ir = itid - 4*il; // 0...7 or 0...3
23
 
24
  const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
25
  const uint v_in = il % 2;
@@ -28,121 +140,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
28
  const uint q_offset = 32*v_im + l0;
29
  const uint y_offset = 64*v_im + l0;
30
 
31
- FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
32
-
33
  [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
34
  [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
35
  temp[j][i] = FLOAT_TYPE(0);
36
  }
37
  }
38
 
39
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
40
- const uint y1_idx = i * QUANT_K + y_offset;
41
- const uint y2_idx = y1_idx + 128;
42
-
43
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
44
- const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
45
- vec2 d = vec2(data_a[ib0 + i].d);
46
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
47
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
48
-
49
- uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
50
- uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
51
- uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
52
- uvec4 scale0 = uvec4(unpack8(scale0_u32));
53
- uvec4 scale4 = uvec4(unpack8(scale4_u32));
54
- uvec4 scale8 = uvec4(unpack8(scale8_u32));
55
-
56
- const uint32_t sc0 = ( scale0.x & 0x3f);
57
- const uint32_t sc1 = ( scale0.y & 0x3f);
58
- const uint32_t sc2 = ( scale4.x & 0x3f);
59
- const uint32_t sc3 = ( scale4.y & 0x3f);
60
- const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
61
- const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
62
- const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
63
- const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
64
-
65
- uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
66
- uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
67
-
68
- uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
69
- uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
70
- uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
71
- uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
72
-
73
- uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
74
-
75
- uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
76
- uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
77
- uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
78
- uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
79
-
80
- qs0_16_u32_lo4 += qs0_16_lo4_offset16;
81
- qs0_16_u32_hi4 += qs0_16_hi4_offset16;
82
- qs64_80_u32_lo4 += qs64_80_lo4_offset16;
83
- qs64_80_u32_hi4 += qs64_80_hi4_offset16;
84
-
85
- uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
86
- uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
87
- uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
88
- uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
89
-
90
- const uint32_t q4_0 = qs0_16_lo4.x;
91
- const uint32_t q4_1 = qs0_16_lo4.y;
92
- const uint32_t q4_2 = qs0_16_lo4.z;
93
- const uint32_t q4_3 = qs0_16_lo4.w;
94
- const uint32_t q4_4 = qs0_16_hi4.x;
95
- const uint32_t q4_5 = qs0_16_hi4.y;
96
- const uint32_t q4_6 = qs0_16_hi4.z;
97
- const uint32_t q4_7 = qs0_16_hi4.w;
98
- const uint32_t q4_8 = qs64_80_lo4.x;
99
- const uint32_t q4_9 = qs64_80_lo4.y;
100
- const uint32_t q4_10 = qs64_80_lo4.z;
101
- const uint32_t q4_11 = qs64_80_lo4.w;
102
- const uint32_t q4_12 = qs64_80_hi4.x;
103
- const uint32_t q4_13 = qs64_80_hi4.y;
104
- const uint32_t q4_14 = qs64_80_hi4.z;
105
- const uint32_t q4_15 = qs64_80_hi4.w;
106
-
107
- [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
108
- vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
109
- vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
110
- vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
111
- vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
112
- vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
113
- vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
114
- vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
115
- vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
116
-
117
- const FLOAT_TYPE sx =
118
- fma(FLOAT_TYPE(by10.x), q4_0,
119
- fma(FLOAT_TYPE(by10.y), q4_1,
120
- fma(FLOAT_TYPE(by116.x), q4_2,
121
- FLOAT_TYPE(by116.y) * q4_3)));
122
- const FLOAT_TYPE sy =
123
- fma(FLOAT_TYPE(by132.x), q4_4,
124
- fma(FLOAT_TYPE(by132.y), q4_5,
125
- fma(FLOAT_TYPE(by148.x), q4_6,
126
- FLOAT_TYPE(by148.y) * q4_7)));
127
- const FLOAT_TYPE sz =
128
- fma(FLOAT_TYPE(by20.x), q4_8,
129
- fma(FLOAT_TYPE(by20.y), q4_9,
130
- fma(FLOAT_TYPE(by216.x), q4_10,
131
- FLOAT_TYPE(by216.y) * q4_11)));
132
- const FLOAT_TYPE sw =
133
- fma(FLOAT_TYPE(by232.x), q4_12,
134
- fma(FLOAT_TYPE(by232.y), q4_13,
135
- fma(FLOAT_TYPE(by248.x), q4_14,
136
- FLOAT_TYPE(by248.y) * q4_15)));
137
- const FLOAT_TYPE smin =
138
- fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
139
- fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
140
- fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
141
- (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
142
- temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
143
- }
144
- }
145
- }
146
 
147
  reduce_result(temp, d_offset, first_row, num_rows, tid);
148
  }
 
6
 
7
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
8
 
9
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
10
+
11
+ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
12
+ const uint y1_idx = i * QUANT_K + y_offset;
13
+ const uint y2_idx = y1_idx + 128;
14
+
15
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
16
+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
17
+ vec2 d = vec2(data_a[ib0 + i].d);
18
+ const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
19
+ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
20
+
21
+ const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
22
+ const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
23
+ const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
24
+
25
+ const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
26
+ const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
27
+ const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
28
+ const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
29
+
30
+ const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
31
+ const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
32
+ const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
33
+ const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
34
+ const FLOAT_TYPE sc4 = scale8_f.x;
35
+ const FLOAT_TYPE sc5 = scale8_f.y;
36
+ const FLOAT_TYPE sc6 = scale8_f.z;
37
+ const FLOAT_TYPE sc7 = scale8_f.w;
38
+
39
+ const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
40
+ const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
41
+
42
+ uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
43
+ uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
44
+ uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
45
+ uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
46
+
47
+ const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
48
+
49
+ const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
50
+ const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
51
+ const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010);
52
+ const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
53
+
54
+ qs0_16_u32_lo4 += qs0_16_lo4_offset16;
55
+ qs0_16_u32_hi4 += qs0_16_hi4_offset16;
56
+ qs64_80_u32_lo4 += qs64_80_lo4_offset16;
57
+ qs64_80_u32_hi4 += qs64_80_hi4_offset16;
58
+
59
+ const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4));
60
+ const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4));
61
+ const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4));
62
+ const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4));
63
+
64
+ const FLOAT_TYPE q4_0 = qs0_16_lo4.x;
65
+ const FLOAT_TYPE q4_1 = qs0_16_lo4.y;
66
+ const FLOAT_TYPE q4_2 = qs0_16_lo4.z;
67
+ const FLOAT_TYPE q4_3 = qs0_16_lo4.w;
68
+ const FLOAT_TYPE q4_4 = qs0_16_hi4.x;
69
+ const FLOAT_TYPE q4_5 = qs0_16_hi4.y;
70
+ const FLOAT_TYPE q4_6 = qs0_16_hi4.z;
71
+ const FLOAT_TYPE q4_7 = qs0_16_hi4.w;
72
+ const FLOAT_TYPE q4_8 = qs64_80_lo4.x;
73
+ const FLOAT_TYPE q4_9 = qs64_80_lo4.y;
74
+ const FLOAT_TYPE q4_10 = qs64_80_lo4.z;
75
+ const FLOAT_TYPE q4_11 = qs64_80_lo4.w;
76
+ const FLOAT_TYPE q4_12 = qs64_80_hi4.x;
77
+ const FLOAT_TYPE q4_13 = qs64_80_hi4.y;
78
+ const FLOAT_TYPE q4_14 = qs64_80_hi4.z;
79
+ const FLOAT_TYPE q4_15 = qs64_80_hi4.w;
80
+
81
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
82
+ vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
83
+ vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
84
+ vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
85
+ vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
86
+ vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
87
+ vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
88
+ vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
89
+ vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
90
+
91
+ const FLOAT_TYPE sx =
92
+ fma(FLOAT_TYPE(by10.x), q4_0,
93
+ fma(FLOAT_TYPE(by10.y), q4_1,
94
+ fma(FLOAT_TYPE(by116.x), q4_2,
95
+ FLOAT_TYPE(by116.y) * q4_3)));
96
+ const FLOAT_TYPE sy =
97
+ fma(FLOAT_TYPE(by132.x), q4_4,
98
+ fma(FLOAT_TYPE(by132.y), q4_5,
99
+ fma(FLOAT_TYPE(by148.x), q4_6,
100
+ FLOAT_TYPE(by148.y) * q4_7)));
101
+ const FLOAT_TYPE sz =
102
+ fma(FLOAT_TYPE(by20.x), q4_8,
103
+ fma(FLOAT_TYPE(by20.y), q4_9,
104
+ fma(FLOAT_TYPE(by216.x), q4_10,
105
+ FLOAT_TYPE(by216.y) * q4_11)));
106
+ const FLOAT_TYPE sw =
107
+ fma(FLOAT_TYPE(by232.x), q4_12,
108
+ fma(FLOAT_TYPE(by232.y), q4_13,
109
+ fma(FLOAT_TYPE(by248.x), q4_14,
110
+ FLOAT_TYPE(by248.y) * q4_15)));
111
+ const FLOAT_TYPE smin =
112
+ fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
113
+ fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
114
+ fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
115
+ (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
116
+ temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
117
+ }
118
+ }
119
+ }
120
+
121
  void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
122
  uint a_offset, b_offset, d_offset;
123
  get_offsets(a_offset, b_offset, d_offset);
 
127
  // 16 threads are used to process each block
128
  const uint it_size = gl_WorkGroupSize.x/16;
129
  const uint tid = gl_LocalInvocationID.x;
130
+ const uint itid = tid%16; // 0...15
131
+ const uint ix = tid/16;
132
 
133
  const uint il = itid/4; // 0...3
134
+ const uint ir = itid - 4*il; // 0...3
135
 
136
  const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
137
  const uint v_in = il % 2;
 
140
  const uint q_offset = 32*v_im + l0;
141
  const uint y_offset = 64*v_im + l0;
142
 
 
 
143
  [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
144
  [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
145
  temp[j][i] = FLOAT_TYPE(0);
146
  }
147
  }
148
 
149
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)
150
+ calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);
151
 
152
  reduce_result(temp, d_offset, first_row, num_rows, tid);
153
  }
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp CHANGED
@@ -6,7 +6,77 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
@@ -15,13 +85,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
15
  // 16 threads are used to process each block
16
  const uint it_size = gl_WorkGroupSize.x/16;
17
  const uint tid = gl_LocalInvocationID.x;
18
- const uint itid = tid%16; // 0...16
19
- const uint ix = tid/16;
20
 
21
- const uint step = 8;
22
-
23
- const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
24
- const uint v_in = itid - step*v_im; // 0...15 or 0...7
25
 
26
  const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
27
  const uint is = v_in / 4;
@@ -31,68 +99,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
31
  const uint s_offset = 8*v_im + is;
32
  const uint y_offset = 128*v_im + l0;
33
 
34
- FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
35
-
36
  [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
37
  [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
38
  temp[j][i] = FLOAT_TYPE(0);
39
  }
40
  }
41
 
42
- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
43
- const uint y_idx = i * QUANT_K + y_offset;
44
-
45
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
46
- const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
47
- const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
48
-
49
- FLOAT_TYPE scales[4];
50
- scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
51
- scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
52
- scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
53
- scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
54
-
55
- uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
56
- uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
57
-
58
- uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
59
- uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
60
- uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
61
- uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
62
-
63
- uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
64
- uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
65
- uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
66
- uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0;
67
- uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
68
-
69
- uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
70
- uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
71
- uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
72
- uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
73
-
74
- uvec4 q0 = uvec4(unpack8(q0_u32));
75
- uvec4 q1 = uvec4(unpack8(q1_u32));
76
- uvec4 q2 = uvec4(unpack8(q2_u32));
77
- uvec4 q3 = uvec4(unpack8(q3_u32));
78
-
79
- [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
80
- vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
81
- vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
82
- vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
83
- vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
84
-
85
- FLOAT_TYPE sum = FLOAT_TYPE(0.0);
86
- [[unroll]] for (int l = 0; l < 4; ++l) {
87
- sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
88
- fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
89
- fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
90
- fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
91
- }
92
- temp[j][n] += sum * d;
93
- }
94
- }
95
- }
96
 
97
  reduce_result(temp, d_offset, first_row, num_rows, tid);
98
  }
 
6
 
7
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
8
 
9
+ shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
10
+
11
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
12
+
13
+ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
14
+ const uint y_idx = i * QUANT_K + y_offset;
15
+
16
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
17
+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
18
+
19
+ if (!all_threads) { // when we don't have enough blocks to use all threads
20
+ barrier();
21
+ if (i < num_blocks_per_row)
22
+ sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
23
+ barrier();
24
+
25
+ if (i >= num_blocks_per_row)
26
+ continue;
27
+ }
28
+
29
+ const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
30
+ const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
31
+
32
+ const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
33
+ const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
34
+ const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
35
+ const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
36
+
37
+ const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
38
+ const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
39
+ const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
40
+ const uint32_t qh4_u32 = (qh_u32 & 0x30303030);
41
+ const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
42
+
43
+ const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
44
+ const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
45
+ const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
46
+ const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
47
+
48
+ const vec4 q0 = vec4(unpack8(q0_u32)) - 32;
49
+ const vec4 q1 = vec4(unpack8(q1_u32)) - 32;
50
+ const vec4 q2 = vec4(unpack8(q2_u32)) - 32;
51
+ const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
52
+
53
+ if (all_threads) {
54
+ barrier();
55
+ sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
56
+ barrier();
57
+ }
58
+
59
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
60
+
61
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
62
+ vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
63
+ vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
64
+ vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
65
+ vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
66
+
67
+ FLOAT_TYPE sum[4] = {0, 0, 0, 0};
68
+ [[unroll]] for (uint l = 0; l < 4; ++l) {
69
+ sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]);
70
+ sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]);
71
+ sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
72
+ sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
73
+ }
74
+ temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
75
+ }
76
+ }
77
+ }
78
+
79
+ void compute_outputs(const uint first_row, const uint num_rows) {
80
  uint a_offset, b_offset, d_offset;
81
  get_offsets(a_offset, b_offset, d_offset);
82
 
 
85
  // 16 threads are used to process each block
86
  const uint it_size = gl_WorkGroupSize.x/16;
87
  const uint tid = gl_LocalInvocationID.x;
88
+ const uint itid = tid%16; // 0...15
89
+ const uint ix = tid/16;
90
 
91
+ const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
92
+ const uint v_in = itid - 8*v_im; // 0...7
 
 
93
 
94
  const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
95
  const uint is = v_in / 4;
 
99
  const uint s_offset = 8*v_im + is;
100
  const uint y_offset = 128*v_im + l0;
101
 
 
 
102
  [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
103
  [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
104
  temp[j][i] = FLOAT_TYPE(0);
105
  }
106
  }
107
 
108
+ const uint nbr_par_th = num_blocks_per_row%it_size;
109
+ const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
110
+ uint i0 = 0;
111
+ [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
112
+ calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
113
+ calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
114
 
115
  reduce_result(temp, d_offset, first_row, num_rows, tid);
116
  }