Spaces:
Running
Running
Commit
·
ee122d3
1
Parent(s):
d49a569
vulkan: optimize coopmat2 q4_k/q5_k dequant functions. (llama/11206)
Browse files
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
CHANGED
|
@@ -163,39 +163,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
|
|
| 163 |
block_q4_K_packed16 block;
|
| 164 |
};
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 167 |
{
|
| 168 |
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
|
|
|
|
| 169 |
const uint idx = coordInBlock[1];
|
| 170 |
|
| 171 |
const uint b = (idx & 0x20) >> 5; // 0,1
|
| 172 |
const uint is = (idx & 0xE0) >> 5; // 0..7
|
| 173 |
|
| 174 |
-
|
|
|
|
|
|
|
| 175 |
|
| 176 |
uint32_t sc;
|
| 177 |
uint32_t mbyte;
|
| 178 |
|
| 179 |
-
uint32_t
|
| 180 |
-
uint32_t
|
| 181 |
-
uint32_t
|
| 182 |
-
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
|
| 183 |
-
uint32_t mbidx0 = is + 4;
|
| 184 |
-
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
|
| 185 |
-
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
| 186 |
-
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
|
| 187 |
-
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
| 188 |
-
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
|
| 189 |
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
const float16_t d = loadd.x * float16_t(sc);
|
| 194 |
const float16_t m = loadd.y * float16_t(mbyte);
|
| 195 |
|
| 196 |
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
|
| 197 |
-
qs = (qs >> (b * 4)) &
|
| 198 |
-
qs = unpack8(qs)[idx & 1];
|
| 199 |
|
| 200 |
float16_t ret = d * float16_t(qs) - m;
|
| 201 |
|
|
@@ -210,47 +218,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
|
|
| 210 |
block_q5_K_packed16 block;
|
| 211 |
};
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 214 |
{
|
| 215 |
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
|
|
|
|
| 216 |
const uint idx = coordInBlock[1];
|
| 217 |
|
| 218 |
const uint b = (idx & 0x20) >> 5; // 0,1
|
| 219 |
const uint is = (idx & 0xE0) >> 5; // 0..7
|
| 220 |
|
| 221 |
-
|
| 222 |
|
| 223 |
-
const f16vec2 loadd =
|
| 224 |
|
| 225 |
uint32_t sc;
|
| 226 |
uint32_t mbyte;
|
| 227 |
|
| 228 |
-
uint32_t
|
| 229 |
-
uint32_t
|
| 230 |
-
uint32_t
|
| 231 |
-
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
|
| 232 |
-
uint32_t mbidx0 = is + 4;
|
| 233 |
-
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
|
| 234 |
-
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
| 235 |
-
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
|
| 236 |
-
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
| 237 |
-
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
const float16_t d = loadd.x * float16_t(sc);
|
| 243 |
const float16_t m = loadd.y * float16_t(mbyte);
|
| 244 |
|
| 245 |
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
|
| 246 |
-
qh = qh &
|
| 247 |
-
qh = unpack8(qh)[idx & 1];
|
| 248 |
|
| 249 |
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
|
| 250 |
qs = (qs >> (b * 4)) & 0x0F0F;
|
| 251 |
-
qs = unpack8(qs)[idx & 1];
|
| 252 |
|
| 253 |
-
float16_t ret = d * (float16_t(qs)
|
| 254 |
|
| 255 |
return ret;
|
| 256 |
}
|
|
|
|
| 163 |
block_q4_K_packed16 block;
|
| 164 |
};
|
| 165 |
|
| 166 |
+
// Buffer-reference type (std430, 16-byte reference alignment) whose only
// member is a block_q4_K_packed128. dequantFuncQ4_K casts its decodeBufQ4_K
// argument to this type so it can read the block's first uvec4.
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
|
| 167 |
+
block_q4_K_packed128 block;
|
| 168 |
+
};
|
| 169 |
+
|
| 170 |
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 171 |
{
|
| 172 |
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
|
| 173 |
+
decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
|
| 174 |
const uint idx = coordInBlock[1];
|
| 175 |
|
| 176 |
const uint b = (idx & 0x20) >> 5; // 0,1
|
| 177 |
const uint is = (idx & 0xE0) >> 5; // 0..7
|
| 178 |
|
| 179 |
+
uvec4 v = bl128.block.q4k[0];
|
| 180 |
+
|
| 181 |
+
const f16vec2 loadd = unpackFloat2x16(v.x);
|
| 182 |
|
| 183 |
uint32_t sc;
|
| 184 |
uint32_t mbyte;
|
| 185 |
|
| 186 |
+
uint32_t scale0 = v.y;
|
| 187 |
+
uint32_t scale4 = v.z;
|
| 188 |
+
uint32_t scale8 = v.w;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
+
uint32_t sc_lo = scale0;
|
| 191 |
+
uint32_t mb_lo = scale4;
|
| 192 |
+
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
|
| 193 |
+
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
|
| 194 |
+
|
| 195 |
+
sc = is < 4 ? sc_lo : sc_hi;
|
| 196 |
+
mbyte = is < 4 ? mb_lo : mb_hi;
|
| 197 |
+
sc = sc >> (8 * (is & 3));
|
| 198 |
+
mbyte = mbyte >> (8 * (is & 3));
|
| 199 |
+
sc &= 0x3F;
|
| 200 |
+
mbyte &= 0x3F;
|
| 201 |
|
| 202 |
const float16_t d = loadd.x * float16_t(sc);
|
| 203 |
const float16_t m = loadd.y * float16_t(mbyte);
|
| 204 |
|
| 205 |
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
|
| 206 |
+
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
|
|
|
|
| 207 |
|
| 208 |
float16_t ret = d * float16_t(qs) - m;
|
| 209 |
|
|
|
|
| 218 |
block_q5_K_packed16 block;
|
| 219 |
};
|
| 220 |
|
| 221 |
+
// Buffer-reference type (std430, 16-byte reference alignment) whose only
// member is a block_q5_K_packed128. dequantFuncQ5_K casts its decodeBufQ5_K
// argument to this type so it can read the block's first uvec4.
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
|
| 222 |
+
block_q5_K_packed128 block;
|
| 223 |
+
};
|
| 224 |
+
|
| 225 |
// Dequantizes a single element of a Q5_K block and returns it as float16.
//
// The result is d * q - m, where
//   q = 4-bit value taken from qs, with one extra bit from qh placed at bit 4,
//   d = loadd.x * sc    (sc    = 6-bit value selected by `is`),
//   m = loadd.y * mbyte (mbyte = 6-bit value selected by `is`).
//
// Parameters:
//   bl           - buffer reference to the block; re-cast below to a packed16
//                  view and a packed128 view of the same reference.
//   blockCoords  - not read by this function.
//   coordInBlock - only element [1] is read; its bits 0..7 are used as the
//                  element index `idx`.
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
| 226 |
{
|
| 227 |
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
|
| 228 |
+
decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
|
| 229 |
const uint idx = coordInBlock[1];
|
| 230 |
|
| 231 |
// b is bit 5 of idx: selects the low (0) or high (1) nibble of each qs byte below.
const uint b = (idx & 0x20) >> 5; // 0,1
|
| 232 |
// is is bits 5..7 of idx: selects which of the 8 (sc, mbyte) pairs applies,
// and which qh bit is used.
const uint is = (idx & 0xE0) >> 5; // 0..7
|
| 233 |
|
| 234 |
+
// Read the first uvec4 of the block: .x feeds loadd, .y/.z/.w feed the scale words.
uvec4 v = bl128.block.q5k[0];
|
| 235 |
|
| 236 |
+
// v.x is interpreted as two packed float16 values (loadd.x, loadd.y).
const f16vec2 loadd = unpackFloat2x16(v.x);
|
| 237 |
|
| 238 |
uint32_t sc;
|
| 239 |
uint32_t mbyte;
|
| 240 |
|
| 241 |
+
uint32_t scale0 = v.y;
|
| 242 |
+
uint32_t scale4 = v.z;
|
| 243 |
+
uint32_t scale8 = v.w;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
+
// Build, for all four byte lanes at once, the values used when is < 4 (*_lo)
// and when is >= 4 (*_hi). The *_lo words still carry bits 6..7 in each byte;
// those are removed by the final & 0x3F.
uint32_t sc_lo = scale0;
|
| 246 |
+
uint32_t mb_lo = scale4;
|
| 247 |
+
// Per byte: bits 0..3 = low nibble of scale8, bits 4..5 = bits 6..7 of scale0.
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
|
| 248 |
+
// Per byte: bits 0..3 = high nibble of scale8, bits 4..5 = bits 6..7 of scale4.
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
|
| 249 |
+
|
| 250 |
+
sc = is < 4 ? sc_lo : sc_hi;
|
| 251 |
+
mbyte = is < 4 ? mb_lo : mb_hi;
|
| 252 |
+
// Move byte lane (is & 3) down to bits 0..7 ...
sc = sc >> (8 * (is & 3));
|
| 253 |
+
mbyte = mbyte >> (8 * (is & 3));
|
| 254 |
+
// ... and keep only its low 6 bits.
sc &= 0x3F;
|
| 255 |
+
mbyte &= 0x3F;
|
| 256 |
|
| 257 |
const float16_t d = loadd.x * float16_t(sc);
|
| 258 |
const float16_t m = loadd.y * float16_t(mbyte);
|
| 259 |
|
| 260 |
// qh word index is idx bits 1..4. NOTE(review): the element type of qh is not
// visible here; presumably a 16-bit word holding two bytes - confirm in types.comp.
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
|
| 261 |
+
// Take bit `is` from byte 0 and from byte 1 (bits 0 and 8 after the shift)
// and move each to bit 4 of its byte, so it can be OR-ed above the 4-bit qs value.
qh = ((qh >> is) & 0x101) << 4;
|
|
|
|
| 262 |
|
| 263 |
// qs word index: idx bits 6..7 select a group of 16 words, idx bits 1..4 select
// the word inside that group.
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
|
| 264 |
// Keep the low nibble (b == 0) or the high nibble (b == 1) of both bytes.
qs = (qs >> (b * 4)) & 0x0F0F;
|
| 265 |
+
// Combine nibble and extra bit, then pick component (idx & 1).
// NOTE(review): unpack8 is defined outside this view; presumably it splits the
// word into bytes - confirm.
qs = unpack8(qs | qh)[idx & 1];
|
| 266 |
|
| 267 |
+
float16_t ret = d * (float16_t(qs)) - m;
|
| 268 |
|
| 269 |
return ret;
|
| 270 |
}
|
ggml/src/ggml-vulkan/vulkan-shaders/types.comp
CHANGED
|
@@ -227,6 +227,11 @@ struct block_q4_K_packed32
|
|
| 227 |
uint32_t qs[QUANT_K_Q4_K/2/4];
|
| 228 |
};
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
#if defined(DATA_A_Q4_K)
|
| 231 |
#define QUANT_K QUANT_K_Q4_K
|
| 232 |
#define A_TYPE block_q4_K
|
|
@@ -252,6 +257,11 @@ struct block_q5_K_packed16
|
|
| 252 |
uint16_t qs[QUANT_K_Q5_K/2/2];
|
| 253 |
};
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
#if defined(DATA_A_Q5_K)
|
| 256 |
#define QUANT_K QUANT_K_Q5_K
|
| 257 |
#define A_TYPE block_q5_K
|
|
|
|
| 227 |
uint32_t qs[QUANT_K_Q4_K/2/4];
|
| 228 |
};
|
| 229 |
|
| 230 |
+
// Q4_K block expressed as nine uvec4 (9 x 16 bytes = 144 bytes).
// dequantFuncQ4_K reads q4k[0] with a single uvec4 access: .x is unpacked into
// two float16 values and .y/.z/.w are used as the three packed scale words.
// NOTE(review): presumably overlays the same bytes as the other block_q4_K
// layouts - confirm the sizes match.
struct block_q4_K_packed128
|
| 231 |
+
{
|
| 232 |
+
uvec4 q4k[9];
|
| 233 |
+
};
|
| 234 |
+
|
| 235 |
#if defined(DATA_A_Q4_K)
|
| 236 |
#define QUANT_K QUANT_K_Q4_K
|
| 237 |
#define A_TYPE block_q4_K
|
|
|
|
| 257 |
uint16_t qs[QUANT_K_Q5_K/2/2];
|
| 258 |
};
|
| 259 |
|
| 260 |
+
// Q5_K block expressed as eleven uvec4 (11 x 16 bytes = 176 bytes).
// dequantFuncQ5_K reads q5k[0] with a single uvec4 access: .x is unpacked into
// two float16 values and .y/.z/.w are used as the three packed scale words.
// NOTE(review): presumably overlays the same bytes as the other block_q5_K
// layouts - confirm the sizes match.
struct block_q5_K_packed128
|
| 261 |
+
{
|
| 262 |
+
uvec4 q5k[11];
|
| 263 |
+
};
|
| 264 |
+
|
| 265 |
#if defined(DATA_A_Q5_K)
|
| 266 |
#define QUANT_K QUANT_K_Q5_K
|
| 267 |
#define A_TYPE block_q5_K
|