jeffbolznv commited on
Commit
dc0e685
·
1 Parent(s): abf6f22

vulkan: Optimize some mat-vec mul quant shaders (llama/10296)

Browse files

Compute two result elements per workgroup (for Q{4,5}_{0,1}). This reuses
the B loads across the rows and also reuses some addressing calculations.
This required manually partially unrolling the loop, since the compiler
is less willing to unroll outer loops.

Add bounds-checking on the last iteration of the loop. I think this was at
least partly broken before.

Optimize the Q4_K shader to vectorize most loads and reduce the number of
bit twiddling instructions.

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -1365,47 +1365,48 @@ static void ggml_vk_load_shaders(vk_device& device) {
1365
  }
1366
 
1367
  // mul mat vec
1368
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1369
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1370
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1371
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1372
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1373
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1374
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1375
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1376
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1377
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1378
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1379
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1380
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1381
-
1382
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1383
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1384
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1385
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1386
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1387
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1388
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1389
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1390
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1391
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1392
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1393
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1394
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1395
-
1396
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1397
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1398
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1399
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1400
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1401
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1402
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1403
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1404
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1405
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1406
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1407
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1408
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
1409
 
1410
  // dequant shaders
1411
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 
1365
  }
1366
 
1367
  // mul mat vec
1368
+ // computing two rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0.
1369
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1370
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1371
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1372
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1373
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1374
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1375
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
1376
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1377
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1378
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1379
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1380
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1381
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1382
+
1383
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1384
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1385
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1386
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1387
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1388
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1389
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
1390
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1391
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1392
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1393
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1394
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1395
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1);
1396
+
1397
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1398
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1399
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1400
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1401
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1402
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1403
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
1404
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1405
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1406
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1407
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1408
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1409
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1410
 
1411
  // dequant shaders
1412
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp CHANGED
@@ -3,54 +3,107 @@
3
  #ifdef FLOAT16
4
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
5
  #endif
 
 
 
6
 
7
  #include "mul_mat_vec_base.comp"
8
 
9
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
10
 
11
  layout (constant_id = 0) const uint BLOCK_SIZE = 32;
 
12
 
13
- shared FLOAT_TYPE tmp[BLOCK_SIZE];
14
 
15
- void main() {
16
- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
17
- const uint tid = gl_LocalInvocationID.x;
18
 
19
- // There are not enough cols to use all threads
20
- if (tid >= p.ncols) {
21
- return;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  }
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- const uint block_size = min(p.ncols, BLOCK_SIZE);
 
25
 
26
- uint a_offset, b_offset, d_offset;
27
  get_offsets(a_offset, b_offset, d_offset);
 
28
 
29
- const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
30
 
31
- tmp[tid] = FLOAT_TYPE(0.0f);
32
 
33
- [[unroll]] for (uint i = 0; i < p.ncols/block_size; i += 2) {
34
- const uint col = i*block_size + 2*tid;
35
- const uint ib = (row*p.ncols + col)/QUANT_K; // block index
36
- const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
37
- const uint iybs = col - col%QUANT_K; // y block start index
38
 
39
- vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);
 
40
 
41
- // matrix multiplication
42
- tmp[tid] = fma(FLOAT_TYPE(v.x), FLOAT_TYPE(data_b[b_offset + iybs + iqs]), fma(FLOAT_TYPE(v.y), FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]), tmp[tid]));
 
 
 
 
 
 
 
 
 
43
  }
44
 
45
  // sum up partial sums and write back result
 
 
 
46
  barrier();
47
- [[unroll]] for (uint s = block_size/2; s > 0; s >>= 1) {
48
  if (tid < s) {
49
- tmp[tid] += tmp[tid + s];
 
 
50
  }
51
  barrier();
52
  }
53
  if (tid == 0) {
54
- data_d[d_offset + row] = D_TYPE(tmp[0]);
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  }
56
  }
 
3
  #ifdef FLOAT16
4
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
5
  #endif
6
+ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
7
+
8
+ #extension GL_EXT_null_initializer : enable
9
 
10
  #include "mul_mat_vec_base.comp"
11
 
12
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
13
 
14
  layout (constant_id = 0) const uint BLOCK_SIZE = 32;
15
+ layout (constant_id = 1) const uint NUM_ROWS = 1;
16
 
17
+ uint a_offset, b_offset, d_offset, y_offset;
18
 
19
+ shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 
 
20
 
21
+ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
22
+ {
23
+ const uint col = i*BLOCK_SIZE + 2*tid;
24
+ const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
25
+ const uint iybs = col - col%QUANT_K; // y block start index
26
+
27
+ // Check if the second of the pair of elements is OOB, and don't fetch B or
28
+ // accumulate it. We still fetch a pair of elements for A, which is fine for
29
+ // quantized formats since they'll be within the same block. We should
30
+ // probably skip fetching the second element for F16/F32, but as of now we
31
+ // still do.
32
+ const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
33
+
34
+ FLOAT_TYPE b0 = 0, b1 = 0;
35
+ b0 = FLOAT_TYPE(data_b[b_offset + iybs + iqs]);
36
+ if (!OOB) {
37
+ b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
38
  }
39
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
40
+ const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index
41
+
42
+ const vec2 v = dequantize(ib, iqs, a_offset);
43
+
44
+ // matrix multiplication
45
+ temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
46
+ if (!OOB) {
47
+ temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
48
+ }
49
+ }
50
+ }
51
 
52
+ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
53
+ const uint tid = gl_LocalInvocationID.x;
54
 
 
55
  get_offsets(a_offset, b_offset, d_offset);
56
+ a_offset /= QUANT_K;
57
 
58
+ y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
59
 
60
+ FLOAT_TYPE temp[NUM_ROWS] = {};
61
 
62
+ const int unroll_count = 8;
 
 
 
 
63
 
64
+ const uint num_iters = (p.ncols >= 2*tid) ? ((p.ncols - 2*tid + BLOCK_SIZE - 1) / BLOCK_SIZE) : 0;
65
+ const uint unrolled_iters = num_iters & ~(2*unroll_count - 1);
66
 
67
+ uint i = 0;
68
+ while (i < unrolled_iters) {
69
+ // Manually partially unroll the loop
70
+ [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
71
+ iter(temp, first_row, num_rows, tid, i, false);
72
+ i += 2;
73
+ }
74
+ }
75
+ while (i < num_iters) {
76
+ iter(temp, first_row, num_rows, tid, i, true);
77
+ i += 2;
78
  }
79
 
80
  // sum up partial sums and write back result
81
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
82
+ tmpsh[n][tid] = temp[n];
83
+ }
84
  barrier();
85
+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
86
  if (tid < s) {
87
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
88
+ tmpsh[n][tid] += tmpsh[n][tid + s];
89
+ }
90
  }
91
  barrier();
92
  }
93
  if (tid == 0) {
94
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
95
+ data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
96
+ }
97
+ }
98
+ }
99
+
100
+ void main() {
101
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
102
+
103
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
104
+ if (first_row + NUM_ROWS <= p.stride_d) {
105
+ compute_outputs(first_row, NUM_ROWS);
106
+ } else {
107
+ compute_outputs(first_row, p.stride_d - first_row);
108
  }
109
  }
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp CHANGED
@@ -1,11 +1,34 @@
1
  #version 450
2
 
 
 
3
  #include "mul_mat_vec_base.comp"
4
 
5
  layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
6
 
7
  shared FLOAT_TYPE tmp[32];
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  void main() {
10
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
11
 
@@ -31,79 +54,81 @@ void main() {
31
  const uint q_offset = 32*v_im + l0;
32
  const uint y_offset = 64*v_im + l0;
33
 
34
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
35
 
36
  [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
37
  const uint y1_idx = i * QUANT_K + y_offset;
38
  const uint y2_idx = y1_idx + 128;
39
 
40
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
41
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
42
-
43
- const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
44
- const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
45
- const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
46
- const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
47
- const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
48
- const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
49
- const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
50
- const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
51
-
52
- #if K_QUANTS_PER_ITERATION == 2
53
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
54
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
55
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf);
56
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] & 0xf);
57
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
58
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
59
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] >> 4);
60
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] >> 4);
61
- const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
62
- const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
63
- const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf);
64
- const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf);
65
- const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
66
- const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
67
- const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4);
68
- const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4);
69
-
70
- const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx]), q4_0, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), q4_1, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3)));
71
- const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_4, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), q4_5, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), q4_6, FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7)));
72
- const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx]), q4_8, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), q4_9, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), q4_10, FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11)));
73
- const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_12, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), q4_13, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), q4_14, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15)));
74
- const FLOAT_TYPE smin =
75
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
76
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), sc7,
77
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), sc7,
78
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7)))))))))))))));
79
- const uint tmp_idx = 16 * ix + tid;
80
- tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
81
- #else
82
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
83
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
84
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
85
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
86
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
87
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
88
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
89
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
90
-
91
- const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1);
92
- const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
93
- const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5);
94
- const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
 
 
 
 
95
  const FLOAT_TYPE smin =
96
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
97
- + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7)))))));
98
-
99
- tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) +
100
- sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
101
- const uint tmp_idx = 16 * ix + tid;
102
- tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f),
103
- fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx]));
104
- #endif
105
  }
106
 
 
 
107
  // sum up partial sums and write back result
108
  barrier();
109
  [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
 
1
  #version 450
2
 
3
+ #extension GL_EXT_shader_explicit_arithmetic_types : require
4
+
5
  #include "mul_mat_vec_base.comp"
6
 
7
  layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
8
 
9
  shared FLOAT_TYPE tmp[32];
10
 
11
+ // Declare aliased versions of A and B bindings that can use 16b/32b loads for
12
+ // the quantized values, and vec4 loads for B.
13
+ struct block_q4_K_u32
14
+ {
15
+ f16vec2 d;
16
+ uint32_t scales[3*QUANT_K/64/4];
17
+ uint32_t qs[QUANT_K/2/4];
18
+ };
19
+
20
+ struct block_q4_K_u16
21
+ {
22
+ f16vec2 d;
23
+ uint16_t scales[3*QUANT_K/64/2];
24
+ uint16_t qs[QUANT_K/2/2];
25
+ };
26
+
27
+ layout (binding = 0) readonly buffer A_u32 {block_q4_K_u32 data_a_u32[];};
28
+ layout (binding = 0) readonly buffer A_u16 {block_q4_K_u16 data_a_u16[];};
29
+ layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
30
+
31
+ // This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
32
  void main() {
33
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
34
 
 
54
  const uint q_offset = 32*v_im + l0;
55
  const uint y_offset = 64*v_im + l0;
56
 
57
+ FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
58
 
59
  [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
60
  const uint y1_idx = i * QUANT_K + y_offset;
61
  const uint y2_idx = y1_idx + 128;
62
 
63
+ f16vec2 d = data_a[ib0 + i].d;
64
+ const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
65
+ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
66
+
67
+ uint32_t scale0_u32 = data_a_u16[ib0 + i].scales[v_im ];
68
+ uint32_t scale4_u32 = data_a_u16[ib0 + i].scales[v_im + 2];
69
+ uint32_t scale8_u32 = data_a_u16[ib0 + i].scales[v_im + 4];
70
+ uvec4 scale0 = uvec4(unpack8(scale0_u32));
71
+ uvec4 scale4 = uvec4(unpack8(scale4_u32));
72
+ uvec4 scale8 = uvec4(unpack8(scale8_u32));
73
+
74
+ const uint32_t sc0 = ( scale0.x & 0x3f);
75
+ const uint32_t sc1 = ( scale0.y & 0x3f);
76
+ const uint32_t sc2 = ( scale4.x & 0x3f);
77
+ const uint32_t sc3 = ( scale4.y & 0x3f);
78
+ const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
79
+ const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
80
+ const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
81
+ const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
82
+
83
+ uint32_t qs0_u32 = data_a_u32[ib0 + i].qs[q_offset / 4];
84
+ uint32_t qs64_u32 = data_a_u32[ib0 + i].qs[q_offset / 4 + 16];
85
+
86
+ uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
87
+ uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
88
+ uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
89
+ uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
90
+
91
+ uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4));
92
+ uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4));
93
+ uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4));
94
+ uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4));
95
+
96
+ const uint32_t q4_0 = qs0_lo4.x;
97
+ const uint32_t q4_1 = qs0_lo4.y;
98
+ const uint32_t q4_2 = qs0_lo4.z;
99
+ const uint32_t q4_3 = qs0_lo4.w;
100
+ const uint32_t q4_4 = qs0_hi4.x;
101
+ const uint32_t q4_5 = qs0_hi4.y;
102
+ const uint32_t q4_6 = qs0_hi4.z;
103
+ const uint32_t q4_7 = qs0_hi4.w;
104
+ const uint32_t q4_8 = qs64_lo4.x;
105
+ const uint32_t q4_9 = qs64_lo4.y;
106
+ const uint32_t q4_10 = qs64_lo4.z;
107
+ const uint32_t q4_11 = qs64_lo4.w;
108
+ const uint32_t q4_12 = qs64_hi4.x;
109
+ const uint32_t q4_13 = qs64_hi4.y;
110
+ const uint32_t q4_14 = qs64_hi4.z;
111
+ const uint32_t q4_15 = qs64_hi4.w;
112
+
113
+ B_TYPE_VEC4 by10 = data_b_v4[(b_offset + y1_idx) / 4];
114
+ B_TYPE_VEC4 by132 = data_b_v4[(b_offset + y1_idx) / 4 + 8];
115
+ B_TYPE_VEC4 by20 = data_b_v4[(b_offset + y2_idx) / 4];
116
+ B_TYPE_VEC4 by232 = data_b_v4[(b_offset + y2_idx) / 4 + 8];
117
+
118
+ const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
119
+ const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
120
+ const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
121
+ const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
122
  const FLOAT_TYPE smin =
123
+ fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
124
+ fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
125
+ fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
126
+ fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
127
+ temp = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp));
 
 
 
 
128
  }
129
 
130
+ tmp[gl_LocalInvocationID.x] = temp;
131
+
132
  // sum up partial sums and write back result
133
  barrier();
134
  [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -317,10 +317,10 @@ void process_shaders() {
317
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
318
  std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
319
 
320
- string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
321
- string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
322
 
323
- string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
324
 
325
  // Dequant shaders
326
  if (tname != "f16") {
 
317
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
318
  std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
319
 
320
+ string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
321
+ string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
322
 
323
+ string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
324
 
325
  // Dequant shaders
326
  if (tname != "f16") {