jeffbolznv committed on
Commit
ee122d3
·
1 Parent(s): d49a569

vulkan: optimize coopmat2 q4_k/q5_k dequant functions. (llama/11206)

Browse files
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp CHANGED
@@ -163,39 +163,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
163
  block_q4_K_packed16 block;
164
  };
165
 
 
 
 
 
166
  float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
167
  {
168
  decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
 
169
  const uint idx = coordInBlock[1];
170
 
171
  const uint b = (idx & 0x20) >> 5; // 0,1
172
  const uint is = (idx & 0xE0) >> 5; // 0..7
173
 
174
- const f16vec2 loadd = bl.block.d;
 
 
175
 
176
  uint32_t sc;
177
  uint32_t mbyte;
178
 
179
- uint32_t scidx0 = (is < 4) ? is : (is + 4);
180
- uint32_t scidx1 = (is < 4) ? is : (is - 4);
181
- uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
182
- uint32_t scidxshift1 = (is < 4) ? 0 : 2;
183
- uint32_t mbidx0 = is + 4;
184
- uint32_t mbidx1 = (is < 4) ? is + 4 : is;
185
- uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
186
- uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
187
- uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
188
- uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
189
 
190
- sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
191
- mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
 
 
 
 
 
 
 
 
 
192
 
193
  const float16_t d = loadd.x * float16_t(sc);
194
  const float16_t m = loadd.y * float16_t(mbyte);
195
 
196
  uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
197
- qs = (qs >> (b * 4)) & 0x0F0F;
198
- qs = unpack8(qs)[idx & 1];
199
 
200
  float16_t ret = d * float16_t(qs) - m;
201
 
@@ -210,47 +218,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
210
  block_q5_K_packed16 block;
211
  };
212
 
 
 
 
 
213
  float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
214
  {
215
  decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
 
216
  const uint idx = coordInBlock[1];
217
 
218
  const uint b = (idx & 0x20) >> 5; // 0,1
219
  const uint is = (idx & 0xE0) >> 5; // 0..7
220
 
221
- const uint32_t hm = 0x0101 << is;
222
 
223
- const f16vec2 loadd = bl.block.d;
224
 
225
  uint32_t sc;
226
  uint32_t mbyte;
227
 
228
- uint32_t scidx0 = (is < 4) ? is : (is + 4);
229
- uint32_t scidx1 = (is < 4) ? is : (is - 4);
230
- uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
231
- uint32_t scidxshift1 = (is < 4) ? 0 : 2;
232
- uint32_t mbidx0 = is + 4;
233
- uint32_t mbidx1 = (is < 4) ? is + 4 : is;
234
- uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
235
- uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
236
- uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
237
- uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
238
 
239
- sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
240
- mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
 
 
 
 
 
 
 
 
 
241
 
242
  const float16_t d = loadd.x * float16_t(sc);
243
  const float16_t m = loadd.y * float16_t(mbyte);
244
 
245
  uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
246
- qh = qh & hm;
247
- qh = unpack8(qh)[idx & 1];
248
 
249
  uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
250
  qs = (qs >> (b * 4)) & 0x0F0F;
251
- qs = unpack8(qs)[idx & 1];
252
 
253
- float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m;
254
 
255
  return ret;
256
  }
 
163
  block_q4_K_packed16 block;
164
  };
165
 
166
+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
167
+ block_q4_K_packed128 block;
168
+ };
169
+
170
  float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
171
  {
172
  decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
173
+ decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
174
  const uint idx = coordInBlock[1];
175
 
176
  const uint b = (idx & 0x20) >> 5; // 0,1
177
  const uint is = (idx & 0xE0) >> 5; // 0..7
178
 
179
+ uvec4 v = bl128.block.q4k[0];
180
+
181
+ const f16vec2 loadd = unpackFloat2x16(v.x);
182
 
183
  uint32_t sc;
184
  uint32_t mbyte;
185
 
186
+ uint32_t scale0 = v.y;
187
+ uint32_t scale4 = v.z;
188
+ uint32_t scale8 = v.w;
 
 
 
 
 
 
 
189
 
190
+ uint32_t sc_lo = scale0;
191
+ uint32_t mb_lo = scale4;
192
+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
193
+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
194
+
195
+ sc = is < 4 ? sc_lo : sc_hi;
196
+ mbyte = is < 4 ? mb_lo : mb_hi;
197
+ sc = sc >> (8 * (is & 3));
198
+ mbyte = mbyte >> (8 * (is & 3));
199
+ sc &= 0x3F;
200
+ mbyte &= 0x3F;
201
 
202
  const float16_t d = loadd.x * float16_t(sc);
203
  const float16_t m = loadd.y * float16_t(mbyte);
204
 
205
  uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
206
+ qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
 
207
 
208
  float16_t ret = d * float16_t(qs) - m;
209
 
 
218
  block_q5_K_packed16 block;
219
  };
220
 
221
+ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
222
+ block_q5_K_packed128 block;
223
+ };
224
+
225
  float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
226
  {
227
  decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
228
+ decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
229
  const uint idx = coordInBlock[1];
230
 
231
  const uint b = (idx & 0x20) >> 5; // 0,1
232
  const uint is = (idx & 0xE0) >> 5; // 0..7
233
 
234
+ uvec4 v = bl128.block.q5k[0];
235
 
236
+ const f16vec2 loadd = unpackFloat2x16(v.x);
237
 
238
  uint32_t sc;
239
  uint32_t mbyte;
240
 
241
+ uint32_t scale0 = v.y;
242
+ uint32_t scale4 = v.z;
243
+ uint32_t scale8 = v.w;
 
 
 
 
 
 
 
244
 
245
+ uint32_t sc_lo = scale0;
246
+ uint32_t mb_lo = scale4;
247
+ uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
248
+ uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
249
+
250
+ sc = is < 4 ? sc_lo : sc_hi;
251
+ mbyte = is < 4 ? mb_lo : mb_hi;
252
+ sc = sc >> (8 * (is & 3));
253
+ mbyte = mbyte >> (8 * (is & 3));
254
+ sc &= 0x3F;
255
+ mbyte &= 0x3F;
256
 
257
  const float16_t d = loadd.x * float16_t(sc);
258
  const float16_t m = loadd.y * float16_t(mbyte);
259
 
260
  uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
261
+ qh = ((qh >> is) & 0x101) << 4;
 
262
 
263
  uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
264
  qs = (qs >> (b * 4)) & 0x0F0F;
265
+ qs = unpack8(qs | qh)[idx & 1];
266
 
267
+ float16_t ret = d * (float16_t(qs)) - m;
268
 
269
  return ret;
270
  }
ggml/src/ggml-vulkan/vulkan-shaders/types.comp CHANGED
@@ -227,6 +227,11 @@ struct block_q4_K_packed32
227
  uint32_t qs[QUANT_K_Q4_K/2/4];
228
  };
229
 
 
 
 
 
 
230
  #if defined(DATA_A_Q4_K)
231
  #define QUANT_K QUANT_K_Q4_K
232
  #define A_TYPE block_q4_K
@@ -252,6 +257,11 @@ struct block_q5_K_packed16
252
  uint16_t qs[QUANT_K_Q5_K/2/2];
253
  };
254
 
 
 
 
 
 
255
  #if defined(DATA_A_Q5_K)
256
  #define QUANT_K QUANT_K_Q5_K
257
  #define A_TYPE block_q5_K
 
227
  uint32_t qs[QUANT_K_Q4_K/2/4];
228
  };
229
 
230
+ struct block_q4_K_packed128
231
+ {
232
+ uvec4 q4k[9];
233
+ };
234
+
235
  #if defined(DATA_A_Q4_K)
236
  #define QUANT_K QUANT_K_Q4_K
237
  #define A_TYPE block_q4_K
 
257
  uint16_t qs[QUANT_K_Q5_K/2/2];
258
  };
259
 
260
+ struct block_q5_K_packed128
261
+ {
262
+ uvec4 q5k[11];
263
+ };
264
+
265
  #if defined(DATA_A_Q5_K)
266
  #define QUANT_K QUANT_K_Q5_K
267
  #define A_TYPE block_q5_K