jeffbolznv committed
Commit 50a2978 · Parent: e0bf47c

vulkan: further optimize mul_mat_vec using larger loads (llama/10387)


* vulkan: Use pipeline_robustness to disable robustness in mul_mat_vec.

Add some early returns for nonexistent rows in mul_mat_vec shaders. These
can only be hit when dispatching a 2D grid of workgroups. Fix the logic
for the 2D grid of workgroups to round up.

Enable the pipeline robustness extension if it's available, and use it to
disable robustness for these pipelines. The bounds-checking instructions
contend for the same ALU resources as the bit-twiddling dequant instructions.
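
For context, chaining the robustness override onto a compute pipeline looks roughly
like the following vulkan.hpp sketch, mirroring the change below. The helper name and
the ext_available flag are illustrative only, and VK_EXT_pipeline_robustness still has
to be enabled (and its feature queried) at device-creation time.

    #include <vulkan/vulkan.hpp>

    // Sketch: build a compute pipeline with robustness disabled for its storage and
    // uniform buffer accesses when VK_EXT_pipeline_robustness is usable.
    vk::Pipeline create_compute_pipeline(vk::Device dev, vk::PipelineLayout layout,
                                         vk::PipelineShaderStageCreateInfo stage,
                                         bool ext_available) {
        vk::ComputePipelineCreateInfo ci(vk::PipelineCreateFlags(), stage, layout);

        vk::PipelineRobustnessCreateInfoEXT rci;
        if (ext_available) {
            rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
            rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
            ci.setPNext(&rci);  // only chained when the extension was enabled
        }
        return dev.createComputePipeline(VK_NULL_HANDLE, ci).value;
    }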

* vulkan: Add GLSL structure aliases for quant types to allow larger loads

In Vulkan it's not possible to cast pointer types, so instead you have to
declare an aliased binding for the memory with a different type. This
commit adds aliases for the quant formats using 16-bit ints and, in a few
places where the struct size is a multiple of 4 bytes, also 32-bit ints.
Currently only q4_k's aliases are used, but others will be used in
subsequent commits.
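
In GLSL terms the aliasing looks roughly like this. This is a hypothetical sketch of
the pattern, not the literal contents of types.comp; the struct names and array sizes
shown here are for q4_k (QUANT_K = 256) and assume the 8/16-bit type extensions these
shaders already enable.

    // The same SSBO (binding 0) declared twice with differently-typed elements.
    // GLSL buffers can't be cast, but aliased bindings let each access pick the
    // widest convenient load.
    struct block_q4_K          { f16vec2 d; uint8_t  scales[12]; uint8_t  qs[128]; };
    struct block_q4_K_packed16 { f16vec2 d; uint16_t scales[6];  uint16_t qs[64];  };

    layout (binding = 0) readonly buffer A          {block_q4_K          data_a[];};
    layout (binding = 0) readonly buffer A_PACKED16 {block_q4_K_packed16 data_a_packed16[];};

    // One 16-bit load then replaces two 8-bit loads, e.g.:
    //   uint32_t two_scales = data_a_packed16[ib].scales[idx];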

* vulkan: use larger loads in q5_k and q6_k shaders.

Similar to the optimization I did in q4_k recently, this vectorizes some loads
and reduces the number of bit-twiddling instructions.
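
In the q5_k shader, for example, the per-byte qs[] reads are replaced by 16-bit loads
through the packed alias, with the nibbles split by wide masks afterwards (condensed
from the diff below):

    // Two 16-bit loads assemble 32 bits of quant data; wide masks then split the
    // nibbles instead of sixteen separate byte loads and shifts.
    uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2])
                        | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
    uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
    uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
    uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); // four low nibbles, one per byte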

* vulkan: use larger K step per iteration in mul_mat_vec.

Add vec4 dequantization functions, and use them to do K=8 per iteration in
mul_mat_vec. This uses 16-bit loads for the quant values and 128-bit loads for B,
which helps reduce the load on the memory system.
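
The Q4_0 case added in dequant_funcs.comp (below) shows the pattern: one 16-bit load
covers four packed nibbles, and a single vec4 dequantize call expands them;
data_a_packed16 is the aliased 16-bit view of binding 0.

    // Dequantize four Q4_0 values at once from a single 16-bit load.
    vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
        const float d   = float(data_a_packed16[a_offset + ib].d);
        const uint  vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
        return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f) * d;
    }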

The K_PER_ITER==2 logic is still there, just for F16/F32, and really only
because they support unaligned sizes.

Tweak the num_iters/unrolling logic to be simpler and catch a couple of missed
unrolling opportunities.
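
Condensed from the mul_mat_vec.comp change below, the per-thread iteration count is
now rounded up explicitly and then consumed by an unroll-by-4 pass, an unroll-by-2
pass, and a remainder loop:

    // Each thread covers K_PER_ITER columns per iteration; round up so a partial
    // final block is still processed.
    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
        num_iters++;
    }
    int  unroll_count   = 4;                              // first unrolled pass
    uint unrolled_iters = num_iters & ~(unroll_count - 1);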

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -158,6 +158,7 @@ struct vk_device_struct {
158
  std::string name;
159
  uint64_t max_memory_allocation_size;
160
  bool fp16;
 
161
  vk::Device device;
162
  uint32_t vendor_id;
163
  vk_queue compute_queue;
@@ -654,7 +655,7 @@ static uint32_t compile_count = 0;
654
  static std::mutex compile_count_mutex;
655
  static std::condition_variable compile_count_cond;
656
 
657
- static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
658
  VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
659
  GGML_ASSERT(parameter_count > 0);
660
  GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
@@ -724,6 +725,15 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
724
  vk::PipelineCreateFlags(),
725
  pipeline_shader_create_info,
726
  pipeline->layout);
727
  pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
728
 
729
  {
@@ -1261,7 +1271,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1261
  device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
1262
 
1263
  std::vector<std::future<void>> compiles;
1264
- auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align) {
1265
  {
1266
  // wait until fewer than N compiles are in progress
1267
  uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@@ -1271,7 +1281,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1271
  }
1272
  compile_count++;
1273
  }
1274
- compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
1275
  };
1276
 
1277
  if (device->fp16) {
@@ -1370,45 +1380,45 @@ static void ggml_vk_load_shaders(vk_device& device) {
1370
  // computing two rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0.
1371
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1372
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1373
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1374
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1375
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1376
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1377
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
1378
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1379
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1380
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1381
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1382
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1383
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1384
 
1385
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1386
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1387
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1388
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1389
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1390
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1391
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
1392
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1393
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1394
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1395
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1396
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1397
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1);
1398
 
1399
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1400
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1401
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1402
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1403
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1404
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1405
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
1406
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1407
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1408
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1409
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1410
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1411
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1412
 
1413
  // dequant shaders
1414
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -1591,12 +1601,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
1591
 
1592
  bool fp16_storage = false;
1593
  bool fp16_compute = false;
 
1594
 
1595
  for (const auto& properties : ext_props) {
1596
  if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1597
  fp16_storage = true;
1598
  } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1599
  fp16_compute = true;
 
 
1600
  }
1601
  }
1602
 
@@ -1642,10 +1655,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
1642
  vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1643
  vk11_features.pNext = &vk12_features;
1644
 
1645
  vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
1646
 
1647
  device->fp16 = device->fp16 && vk12_features.shaderFloat16;
1648
 
 
 
1649
  if (!vk11_features.storageBuffer16BitAccess) {
1650
  std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
1651
  throw std::runtime_error("Unsupported device");
@@ -3190,7 +3215,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
3190
 
3191
  if (ne01 > max_groups_x) {
3192
  groups_z = 64;
3193
- groups_x /= groups_z;
3194
  }
3195
 
3196
  // compute
@@ -3767,7 +3792,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
3767
 
3768
  if (ne01 > max_groups_x) {
3769
  groups_z = 64;
3770
- groups_x /= groups_z;
3771
  }
3772
 
3773
  // compute
 
158
  std::string name;
159
  uint64_t max_memory_allocation_size;
160
  bool fp16;
161
+ bool pipeline_robustness;
162
  vk::Device device;
163
  uint32_t vendor_id;
164
  vk_queue compute_queue;
 
655
  static std::mutex compile_count_mutex;
656
  static std::condition_variable compile_count_cond;
657
 
658
+ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align, bool disable_robustness) {
659
  VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
660
  GGML_ASSERT(parameter_count > 0);
661
  GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
 
725
  vk::PipelineCreateFlags(),
726
  pipeline_shader_create_info,
727
  pipeline->layout);
728
+
729
+ vk::PipelineRobustnessCreateInfoEXT rci;
730
+
731
+ if (device->pipeline_robustness && disable_robustness) {
732
+ rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
733
+ rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
734
+ compute_pipeline_create_info.setPNext(&rci);
735
+ }
736
+
737
  pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
738
 
739
  {
 
1271
  device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
1272
 
1273
  std::vector<std::future<void>> compiles;
1274
+ auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) {
1275
  {
1276
  // wait until fewer than N compiles are in progress
1277
  uint32_t N = std::max(1u, std::thread::hardware_concurrency());
 
1281
  }
1282
  compile_count++;
1283
  }
1284
+ compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
1285
  };
1286
 
1287
  if (device->fp16) {
 
1380
  // computing two rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0.
1381
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1382
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1383
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1384
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1385
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1386
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1387
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
1388
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1389
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1390
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1391
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1392
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1393
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1394
 
1395
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1396
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1397
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1398
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1399
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1400
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1401
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
1402
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1403
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1404
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1405
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1406
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1407
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1, true);
1408
 
1409
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1410
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1411
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1412
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1413
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1414
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1415
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
1416
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1417
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1418
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1419
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1420
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1421
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
1422
 
1423
  // dequant shaders
1424
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 
1601
 
1602
  bool fp16_storage = false;
1603
  bool fp16_compute = false;
1604
+ bool pipeline_robustness = false;
1605
 
1606
  for (const auto& properties : ext_props) {
1607
  if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1608
  fp16_storage = true;
1609
  } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1610
  fp16_compute = true;
1611
+ } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
1612
+ pipeline_robustness = true;
1613
  }
1614
  }
1615
 
 
1655
  vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1656
  vk11_features.pNext = &vk12_features;
1657
 
1658
+ VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
1659
+ pl_robustness_features.pNext = nullptr;
1660
+ pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
1661
+ pl_robustness_features.pipelineRobustness = VK_FALSE;
1662
+
1663
+ if (pipeline_robustness) {
1664
+ vk12_features.pNext = &pl_robustness_features;
1665
+ device_extensions.push_back("VK_EXT_pipeline_robustness");
1666
+ }
1667
+
1668
  vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
1669
 
1670
  device->fp16 = device->fp16 && vk12_features.shaderFloat16;
1671
 
1672
+ device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
1673
+
1674
  if (!vk11_features.storageBuffer16BitAccess) {
1675
  std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
1676
  throw std::runtime_error("Unsupported device");
 
3215
 
3216
  if (ne01 > max_groups_x) {
3217
  groups_z = 64;
3218
+ groups_x = CEIL_DIV(groups_x, groups_z);
3219
  }
3220
 
3221
  // compute
 
3792
 
3793
  if (ne01 > max_groups_x) {
3794
  groups_z = 64;
3795
+ groups_x = CEIL_DIV(groups_x, groups_z);
3796
  }
3797
 
3798
  // compute
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp CHANGED
@@ -2,6 +2,15 @@
2
  #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
3
  #endif
4
 
5
  #if defined(DATA_A_F32)
6
  vec2 dequantize(uint ib, uint iqs, uint a_offset) {
7
  return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
@@ -20,6 +29,11 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
20
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
21
  return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
22
  }
23
  #endif
24
 
25
  #if defined(DATA_A_Q4_1)
@@ -29,6 +43,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
29
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
30
  return vec2(vui & 0xF, vui >> 4) * d + m;
31
  }
32
  #endif
33
 
34
  #if defined(DATA_A_Q5_0)
@@ -39,6 +59,14 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
39
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
40
  return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
41
  }
42
  #endif
43
 
44
  #if defined(DATA_A_Q5_1)
@@ -50,6 +78,15 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
50
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
51
  return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
52
  }
53
  #endif
54
 
55
  #if defined(DATA_A_Q8_0)
@@ -57,6 +94,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
57
  const float d = float(data_a[a_offset + ib].d);
58
  return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
59
  }
60
  #endif
61
 
62
  #if defined(DATA_A_IQ4_NL)
@@ -65,4 +108,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
65
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
66
  return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
67
  }
68
  #endif
 
2
  #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
3
  #endif
4
 
5
+ #include "types.comp"
6
+
7
+ #if defined(A_TYPE_PACKED16)
8
+ layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
9
+ #endif
10
+ #if defined(A_TYPE_PACKED32)
11
+ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
12
+ #endif
13
+
14
  #if defined(DATA_A_F32)
15
  vec2 dequantize(uint ib, uint iqs, uint a_offset) {
16
  return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
 
29
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
30
  return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
31
  }
32
+ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
33
+ const float d = float(data_a_packed16[a_offset + ib].d);
34
+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
35
+ return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f) * d;
36
+ }
37
  #endif
38
 
39
  #if defined(DATA_A_Q4_1)
 
43
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
44
  return vec2(vui & 0xF, vui >> 4) * d + m;
45
  }
46
+ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
47
+ const float d = float(data_a_packed16[a_offset + ib].d);
48
+ const float m = float(data_a_packed16[a_offset + ib].m);
49
+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
50
+ return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) * d + m;
51
+ }
52
  #endif
53
 
54
  #if defined(DATA_A_Q5_0)
 
59
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
60
  return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
61
  }
62
+ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
63
+ const float d = float(data_a_packed16[a_offset + ib].d);
64
+ const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
65
+ const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
66
+ const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
67
+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
68
+ return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f) * d;
69
+ }
70
  #endif
71
 
72
  #if defined(DATA_A_Q5_1)
 
78
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
79
  return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
80
  }
81
+ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
82
+ const float d = float(data_a_packed16[a_offset + ib].d);
83
+ const float m = float(data_a_packed16[a_offset + ib].m);
84
+ const uint uint_qh = data_a_packed16[a_offset + ib].qh;
85
+ const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
86
+ const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
87
+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
88
+ return vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * d + m;
89
+ }
90
  #endif
91
 
92
  #if defined(DATA_A_Q8_0)
 
94
  const float d = float(data_a[a_offset + ib].d);
95
  return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
96
  }
97
+ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
98
+ const float d = float(data_a_packed16[a_offset + ib].d);
99
+ uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
100
+ uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
101
+ return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF)) * d;
102
+ }
103
  #endif
104
 
105
  #if defined(DATA_A_IQ4_NL)
 
108
  const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
109
  return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
110
  }
111
+ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
112
+ const float d = float(data_a_packed16[a_offset + ib].d);
113
+ const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
114
+ return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[(vui >> 12) & 0xF]) * d;
115
+ }
116
  #endif
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp CHANGED
@@ -3,7 +3,7 @@
3
  #ifdef FLOAT16
4
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
5
  #endif
6
- #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
7
 
8
  #include "mul_mat_vec_base.comp"
9
 
@@ -12,16 +12,48 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
12
  layout (constant_id = 0) const uint BLOCK_SIZE = 32;
13
  layout (constant_id = 1) const uint NUM_ROWS = 1;
14
 
15
  uint a_offset, b_offset, d_offset, y_offset;
16
 
17
  shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
18
 
19
  void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
20
  {
21
- const uint col = i*BLOCK_SIZE + 2*tid;
22
  const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
23
  const uint iybs = col - col%QUANT_K; // y block start index
24
 
25
  // Check if the second of the pair of elements is OOB, and don't fetch B or
26
  // accumulate it. We still fetch a pair of elements for A, which is fine for
27
  // quantized formats since they'll be within the same block. We should
@@ -34,9 +66,24 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
34
  if (!OOB) {
35
  b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
36
  }
 
37
  [[unroll]] for (uint n = 0; n < num_rows; ++n) {
38
  const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index
39
 
40
  const vec2 v = dequantize(ib, iqs, a_offset);
41
 
42
  // matrix multiplication
@@ -44,6 +91,7 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
44
  if (!OOB) {
45
  temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
46
  }
 
47
  }
48
  }
49
 
@@ -61,22 +109,33 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
61
  temp[i] = FLOAT_TYPE(0);
62
  }
63
 
64
- const int unroll_count = 8;
65
-
66
- const uint num_iters = (p.ncols >= 2*tid) ? ((p.ncols - 2*tid + BLOCK_SIZE - 1) / BLOCK_SIZE) : 0;
67
- const uint unrolled_iters = num_iters & ~(2*unroll_count - 1);
 
 
68
 
69
  uint i = 0;
70
  while (i < unrolled_iters) {
71
  // Manually partially unroll the loop
72
  [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
73
- iter(temp, first_row, num_rows, tid, i, false);
74
- i += 2;
 
  }
76
  }
77
  while (i < num_iters) {
78
- iter(temp, first_row, num_rows, tid, i, true);
79
- i += 2;
80
  }
81
 
82
  // sum up partial sums and write back result
@@ -106,6 +165,9 @@ void main() {
106
  if (first_row + NUM_ROWS <= p.stride_d) {
107
  compute_outputs(first_row, NUM_ROWS);
108
  } else {
109
  compute_outputs(first_row, p.stride_d - first_row);
110
  }
111
  }
 
3
  #ifdef FLOAT16
4
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
5
  #endif
6
+ #extension GL_EXT_shader_explicit_arithmetic_types : require
7
 
8
  #include "mul_mat_vec_base.comp"
9
 
 
12
  layout (constant_id = 0) const uint BLOCK_SIZE = 32;
13
  layout (constant_id = 1) const uint NUM_ROWS = 1;
14
 
15
+ #if !defined(DATA_A_F32) && !defined(DATA_A_F16)
16
+ #define K_PER_ITER 8
17
+ #else
18
+ #define K_PER_ITER 2
19
+ #endif
20
+
21
+
22
  uint a_offset, b_offset, d_offset, y_offset;
23
 
24
  shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
25
 
26
  void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
27
  {
28
+ const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
29
  const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
30
  const uint iybs = col - col%QUANT_K; // y block start index
31
 
32
+ #if K_PER_ITER == 8
33
+ #if QUANT_R == 2
34
+ B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
35
+ B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
36
+ FLOAT_TYPE b0 = FLOAT_TYPE(bv02.x);
37
+ FLOAT_TYPE b1 = FLOAT_TYPE(bv13.x);
38
+ FLOAT_TYPE b2 = FLOAT_TYPE(bv02.y);
39
+ FLOAT_TYPE b3 = FLOAT_TYPE(bv13.y);
40
+ FLOAT_TYPE b4 = FLOAT_TYPE(bv02.z);
41
+ FLOAT_TYPE b5 = FLOAT_TYPE(bv13.z);
42
+ FLOAT_TYPE b6 = FLOAT_TYPE(bv02.w);
43
+ FLOAT_TYPE b7 = FLOAT_TYPE(bv13.w);
44
+ #else
45
+ B_TYPE_VEC4 bv0 = data_b_v4[(b_offset + iybs + iqs) / 4];
46
+ B_TYPE_VEC4 bv1 = data_b_v4[(b_offset + iybs + iqs) / 4 + 1];
47
+ FLOAT_TYPE b0 = FLOAT_TYPE(bv0.x);
48
+ FLOAT_TYPE b1 = FLOAT_TYPE(bv0.y);
49
+ FLOAT_TYPE b2 = FLOAT_TYPE(bv0.z);
50
+ FLOAT_TYPE b3 = FLOAT_TYPE(bv0.w);
51
+ FLOAT_TYPE b4 = FLOAT_TYPE(bv1.x);
52
+ FLOAT_TYPE b5 = FLOAT_TYPE(bv1.y);
53
+ FLOAT_TYPE b6 = FLOAT_TYPE(bv1.z);
54
+ FLOAT_TYPE b7 = FLOAT_TYPE(bv1.w);
55
+ #endif
56
+ #else
57
  // Check if the second of the pair of elements is OOB, and don't fetch B or
58
  // accumulate it. We still fetch a pair of elements for A, which is fine for
59
  // quantized formats since they'll be within the same block. We should
 
66
  if (!OOB) {
67
  b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
68
  }
69
+ #endif
70
  [[unroll]] for (uint n = 0; n < num_rows; ++n) {
71
  const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index
72
 
73
+ #if K_PER_ITER == 8
74
+ const vec4 v = dequantize4(ib, iqs, a_offset);
75
+ const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
76
+
77
+ // matrix multiplication
78
+ temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
79
+ temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
80
+ temp[n] = fma(FLOAT_TYPE(v.z), b2, temp[n]);
81
+ temp[n] = fma(FLOAT_TYPE(v.w), b3, temp[n]);
82
+ temp[n] = fma(FLOAT_TYPE(v2.x), b4, temp[n]);
83
+ temp[n] = fma(FLOAT_TYPE(v2.y), b5, temp[n]);
84
+ temp[n] = fma(FLOAT_TYPE(v2.z), b6, temp[n]);
85
+ temp[n] = fma(FLOAT_TYPE(v2.w), b7, temp[n]);
86
+ #else
87
  const vec2 v = dequantize(ib, iqs, a_offset);
88
 
89
  // matrix multiplication
 
91
  if (!OOB) {
92
  temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
93
  }
94
+ #endif
95
  }
96
  }
97
 
 
109
  temp[i] = FLOAT_TYPE(0);
110
  }
111
 
112
+ uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
113
+ if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
114
+ num_iters++;
115
+ }
116
+ int unroll_count = 4;
117
+ uint unrolled_iters = num_iters & ~(unroll_count - 1);
118
 
119
  uint i = 0;
120
  while (i < unrolled_iters) {
121
  // Manually partially unroll the loop
122
  [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
123
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
124
+ i++;
125
+ }
126
+ }
127
+ unroll_count = 2;
128
+ unrolled_iters = num_iters & ~(unroll_count - 1);
129
+ while (i < unrolled_iters) {
130
+ // Manually partially unroll the loop
131
+ [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
132
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
133
+ i++;
134
  }
135
  }
136
  while (i < num_iters) {
137
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
138
+ i++;
139
  }
140
 
141
  // sum up partial sums and write back result
 
165
  if (first_row + NUM_ROWS <= p.stride_d) {
166
  compute_outputs(first_row, NUM_ROWS);
167
  } else {
168
+ if (first_row >= p.stride_d) {
169
+ return;
170
+ }
171
  compute_outputs(first_row, p.stride_d - first_row);
172
  }
173
  }
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp CHANGED
@@ -12,6 +12,9 @@
12
 
13
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
14
  layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
15
  layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
16
  #ifdef MUL_MAT_ID
17
  layout (binding = 3) readonly buffer IDS {int data_ids[];};
 
12
 
13
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
14
  layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
15
+ layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
16
+ layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
17
+
18
  layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
19
  #ifdef MUL_MAT_ID
20
  layout (binding = 3) readonly buffer IDS {int data_ids[];};
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp CHANGED
@@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
9
  void main() {
10
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
11
 
12
  uint a_offset, b_offset, d_offset;
13
  get_offsets(a_offset, b_offset, d_offset);
14
 
 
9
  void main() {
10
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
11
 
12
+ if (row >= p.stride_d) {
13
+ return;
14
+ }
15
+
16
  uint a_offset, b_offset, d_offset;
17
  get_offsets(a_offset, b_offset, d_offset);
18
 
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp CHANGED
@@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
9
  void main() {
10
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
11
 
12
  uint a_offset, b_offset, d_offset;
13
  get_offsets(a_offset, b_offset, d_offset);
14
 
 
9
  void main() {
10
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
11
 
12
+ if (row >= p.stride_d) {
13
+ return;
14
+ }
15
+
16
  uint a_offset, b_offset, d_offset;
17
  get_offsets(a_offset, b_offset, d_offset);
18
 
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp CHANGED
@@ -8,30 +8,14 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
8
 
9
  shared FLOAT_TYPE tmp[32];
10
 
11
- // Declare aliased versions of A and B bindings that can use 16b/32b loads for
12
- // the quantized values, and vec4 loads for B.
13
- struct block_q4_K_u32
14
- {
15
- f16vec2 d;
16
- uint32_t scales[3*QUANT_K/64/4];
17
- uint32_t qs[QUANT_K/2/4];
18
- };
19
-
20
- struct block_q4_K_u16
21
- {
22
- f16vec2 d;
23
- uint16_t scales[3*QUANT_K/64/2];
24
- uint16_t qs[QUANT_K/2/2];
25
- };
26
-
27
- layout (binding = 0) readonly buffer A_u32 {block_q4_K_u32 data_a_u32[];};
28
- layout (binding = 0) readonly buffer A_u16 {block_q4_K_u16 data_a_u16[];};
29
- layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
30
-
31
  // This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
32
  void main() {
33
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
34
 
 
35
  uint a_offset, b_offset, d_offset;
36
  get_offsets(a_offset, b_offset, d_offset);
37
 
@@ -64,9 +48,9 @@ void main() {
64
  const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
65
  const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
66
 
67
- uint32_t scale0_u32 = data_a_u16[ib0 + i].scales[v_im ];
68
- uint32_t scale4_u32 = data_a_u16[ib0 + i].scales[v_im + 2];
69
- uint32_t scale8_u32 = data_a_u16[ib0 + i].scales[v_im + 4];
70
  uvec4 scale0 = uvec4(unpack8(scale0_u32));
71
  uvec4 scale4 = uvec4(unpack8(scale4_u32));
72
  uvec4 scale8 = uvec4(unpack8(scale8_u32));
@@ -80,8 +64,8 @@ void main() {
80
  const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
81
  const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
82
 
83
- uint32_t qs0_u32 = data_a_u32[ib0 + i].qs[q_offset / 4];
84
- uint32_t qs64_u32 = data_a_u32[ib0 + i].qs[q_offset / 4 + 16];
85
 
86
  uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
87
  uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
 
8
 
9
  shared FLOAT_TYPE tmp[32];
10
 
11
  // This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
12
  void main() {
13
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
14
 
15
+ if (row >= p.stride_d) {
16
+ return;
17
+ }
18
+
19
  uint a_offset, b_offset, d_offset;
20
  get_offsets(a_offset, b_offset, d_offset);
21
 
 
48
  const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
49
  const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
50
 
51
+ uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
52
+ uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
53
+ uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
54
  uvec4 scale0 = uvec4(unpack8(scale0_u32));
55
  uvec4 scale4 = uvec4(unpack8(scale4_u32));
56
  uvec4 scale8 = uvec4(unpack8(scale8_u32));
 
64
  const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
65
  const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
66
 
67
+ uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
68
+ uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
69
 
70
  uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
71
  uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp CHANGED
@@ -1,5 +1,7 @@
1
  #version 450
2
 
 
 
3
  #include "mul_mat_vec_base.comp"
4
 
5
  layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
@@ -9,6 +11,10 @@ shared FLOAT_TYPE tmp[32];
9
  void main() {
10
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
11
 
12
  uint a_offset, b_offset, d_offset;
13
  get_offsets(a_offset, b_offset, d_offset);
14
 
@@ -31,70 +37,106 @@ void main() {
31
  const uint8_t hm1 = uint8_t(1 << (2*v_im));
32
  const uint8_t hm2 = uint8_t(hm1 << 4);
33
 
34
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
35
 
36
  [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
37
  const uint y1_idx = i * QUANT_K + y_offset;
38
  const uint y2_idx = y1_idx + 128;
39
 
40
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
41
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
42
-
43
- const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
44
- const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
45
- const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
46
- const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
47
- const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
48
- const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
49
- const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
50
- const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
51
-
52
- const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
53
- const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
54
- const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
55
- const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
56
- const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
57
- const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
58
- const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4);
59
- const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4);
60
- const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
61
- const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
62
- const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
63
- const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
64
- const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
65
- const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
66
- const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4);
67
- const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4);
68
 
69
  const FLOAT_TYPE sx =
70
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)),
71
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)),
72
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)),
73
- FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)))));
74
  const FLOAT_TYPE sy =
75
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)),
76
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)),
77
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)),
78
- FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)))));
79
  const FLOAT_TYPE sz =
80
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)),
81
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)),
82
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)),
83
- FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)))));
84
  const FLOAT_TYPE sw =
85
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)),
86
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)),
87
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)),
88
- FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)))));
89
  const FLOAT_TYPE smin =
90
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2,
91
- fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3,
92
- fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6,
93
- (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7)));
94
- const uint tmp_idx = 16 * ix + tid;
95
- tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
96
  }
97
 
 
 
98
  // sum up partial sums and write back result
99
  barrier();
100
  [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
 
1
  #version 450
2
 
3
+ #extension GL_EXT_shader_explicit_arithmetic_types : require
4
+
5
  #include "mul_mat_vec_base.comp"
6
 
7
  layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 
11
  void main() {
12
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
13
 
14
+ if (row >= p.stride_d) {
15
+ return;
16
+ }
17
+
18
  uint a_offset, b_offset, d_offset;
19
  get_offsets(a_offset, b_offset, d_offset);
20
 
 
37
  const uint8_t hm1 = uint8_t(1 << (2*v_im));
38
  const uint8_t hm2 = uint8_t(hm1 << 4);
39
 
40
+ FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
41
 
42
  [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
43
  const uint y1_idx = i * QUANT_K + y_offset;
44
  const uint y2_idx = y1_idx + 128;
45
 
46
+ f16vec2 d = data_a[ib0 + i].d;
47
+ const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
48
+ const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
49
+
50
+ uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
51
+ uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
52
+ uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
53
+ uvec4 scale0 = uvec4(unpack8(scale0_u32));
54
+ uvec4 scale4 = uvec4(unpack8(scale4_u32));
55
+ uvec4 scale8 = uvec4(unpack8(scale8_u32));
56
+
57
+ const uint32_t sc0 = ( scale0.x & 0x3f);
58
+ const uint32_t sc1 = ( scale0.y & 0x3f);
59
+ const uint32_t sc2 = ( scale4.x & 0x3f);
60
+ const uint32_t sc3 = ( scale4.y & 0x3f);
61
+ const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
62
+ const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
63
+ const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
64
+ const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
65
+
66
+ uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
67
+ uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
68
+
69
+ uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
70
+ uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
71
+ uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
72
+ uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
73
+
74
+ uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
75
+ uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
76
+ uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
77
+ uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
78
+
79
+ const uint32_t q4_0 = qs0_16_lo4.x;
80
+ const uint32_t q4_1 = qs0_16_lo4.y;
81
+ const uint32_t q4_2 = qs0_16_lo4.z;
82
+ const uint32_t q4_3 = qs0_16_lo4.w;
83
+ const uint32_t q4_4 = qs0_16_hi4.x;
84
+ const uint32_t q4_5 = qs0_16_hi4.y;
85
+ const uint32_t q4_6 = qs0_16_hi4.z;
86
+ const uint32_t q4_7 = qs0_16_hi4.w;
87
+ const uint32_t q4_8 = qs64_80_lo4.x;
88
+ const uint32_t q4_9 = qs64_80_lo4.y;
89
+ const uint32_t q4_10 = qs64_80_lo4.z;
90
+ const uint32_t q4_11 = qs64_80_lo4.w;
91
+ const uint32_t q4_12 = qs64_80_hi4.x;
92
+ const uint32_t q4_13 = qs64_80_hi4.y;
93
+ const uint32_t q4_14 = qs64_80_hi4.z;
94
+ const uint32_t q4_15 = qs64_80_hi4.w;
95
+
96
+ B_TYPE_VEC2 by10 = data_b_v2[(b_offset + y1_idx) / 2];
97
+ B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8];
98
+ B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16];
99
+ B_TYPE_VEC2 by148 = data_b_v2[(b_offset + y1_idx) / 2 + 24];
100
+ B_TYPE_VEC2 by20 = data_b_v2[(b_offset + y2_idx) / 2];
101
+ B_TYPE_VEC2 by216 = data_b_v2[(b_offset + y2_idx) / 2 + 8];
102
+ B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
103
+ B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
104
+
105
+ uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2];
106
+ uint32_t qh1 = qh0 >> 8;
107
+ uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8];
108
+ uint32_t qh17 = qh16 >> 8;
109
 
110
  const FLOAT_TYPE sx =
111
+ fma(FLOAT_TYPE(by10.x), (q4_0 + (((qh0 & hm1) != 0) ? 16 : 0)),
112
+ fma(FLOAT_TYPE(by10.y), (q4_1 + (((qh1 & hm1) != 0) ? 16 : 0)),
113
+ fma(FLOAT_TYPE(by116.x), (q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)),
114
+ FLOAT_TYPE(by116.y) * (q4_3 + (((qh17 & hm1) != 0) ? 16 : 0)))));
115
  const FLOAT_TYPE sy =
116
+ fma(FLOAT_TYPE(by132.x), (q4_4 + (((qh0 & (hm1 << 1)) != 0) ? 16 : 0)),
117
+ fma(FLOAT_TYPE(by132.y), (q4_5 + (((qh1 & (hm1 << 1)) != 0) ? 16 : 0)),
118
+ fma(FLOAT_TYPE(by148.x), (q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)),
119
+ FLOAT_TYPE(by148.y) * (q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0)))));
120
  const FLOAT_TYPE sz =
121
+ fma(FLOAT_TYPE(by20.x), (q4_8 + (((qh0 & hm2) != 0) ? 16 : 0)),
122
+ fma(FLOAT_TYPE(by20.y), (q4_9 + (((qh1 & hm2) != 0) ? 16 : 0)),
123
+ fma(FLOAT_TYPE(by216.x), (q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)),
124
+ FLOAT_TYPE(by216.y) * (q4_11 + (((qh17 & hm2) != 0) ? 16 : 0)))));
125
  const FLOAT_TYPE sw =
126
+ fma(FLOAT_TYPE(by232.x), (q4_12 + (((qh0 & (hm2 << 1)) != 0) ? 16 : 0)),
127
+ fma(FLOAT_TYPE(by232.y), (q4_13 + (((qh1 & (hm2 << 1)) != 0) ? 16 : 0)),
128
+ fma(FLOAT_TYPE(by248.x), (q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)),
129
+ FLOAT_TYPE(by248.y) * (q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0)))));
130
  const FLOAT_TYPE smin =
131
+ fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
132
+ fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
133
+ fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
134
+ (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
135
+ temp = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp));
 
136
  }
137
 
138
+ tmp[gl_LocalInvocationID.x] = temp;
139
+
140
  // sum up partial sums and write back result
141
  barrier();
142
  [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
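Note on the vectorized dequant above: the shader merges two 16-bit loads of qs into one 32-bit word and peels off eight 4-bit quant values with two masks, instead of eight separate byte loads. A minimal CPU-side C++ sketch of the same trick (the helper name and test values are illustrative only, not code from ggml):

    #include <cstdint>
    #include <cstdio>

    // Combine two 16-bit quant words (first word in the low half, as the shader
    // does with data_a_packed16) and split them into four low and four high nibbles.
    static void unpack_nibbles(uint16_t lo16, uint16_t hi16, uint8_t lo4[4], uint8_t hi4[4]) {
        uint32_t qs   = (uint32_t)lo16 | ((uint32_t)hi16 << 16);
        uint32_t lo_w = qs & 0x0F0F0F0Fu;        // low nibble of each byte
        uint32_t hi_w = (qs >> 4) & 0x0F0F0F0Fu; // high nibble of each byte
        for (int b = 0; b < 4; ++b) {            // per-byte extraction, like unpack8() in GLSL
            lo4[b] = (lo_w >> (8 * b)) & 0xFF;
            hi4[b] = (hi_w >> (8 * b)) & 0xFF;
        }
    }

    int main() {
        uint8_t lo[4], hi[4];
        unpack_nibbles(0x4321, 0x8765, lo, hi);
        for (int b = 0; b < 4; ++b) {
            printf("%u %u\n", lo[b], hi[b]);     // prints 1 2, 3 4, 5 6, 7 8
        }
        return 0;
    }
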
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp CHANGED
@@ -1,5 +1,7 @@
1
  #version 450
2
 
 
 
3
  #include "mul_mat_vec_base.comp"
4
 
5
  layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
@@ -9,6 +11,10 @@ shared FLOAT_TYPE tmp[32];
9
  void main() {
10
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
11
 
 
 
 
 
12
  uint a_offset, b_offset, d_offset;
13
  get_offsets(a_offset, b_offset, d_offset);
14
 
@@ -36,41 +42,66 @@ void main() {
36
  const uint s_offset = 8*v_im + is;
37
  const uint y_offset = 128*v_im + l0;
38
 
39
- tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
40
 
41
  [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
42
  const uint y_idx = i * QUANT_K + y_offset;
43
 
44
  const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
45
 
46
- #if K_QUANTS_PER_ITERATION == 1
47
- const uint tmp_idx = 16 * ix + tid;
48
- tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32),
49
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32),
50
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32),
51
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32),
52
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32),
53
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32),
54
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32),
55
- fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx]))))))));
56
- #else
57
  FLOAT_TYPE sum = FLOAT_TYPE(0.0);
58
  [[unroll]] for (int l = 0; l < 4; ++l) {
59
- sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32),
60
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32),
61
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32),
62
- fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum))));
63
  }
64
- tmp[16 * ix + tid] += sum;
65
- #endif
66
  }
67
 
 
 
68
  // sum up partial sums and write back result
69
  barrier();
70
  [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
71
  if (tid < s) {
72
  tmp[tid] += tmp[tid + s];
73
- }
74
  barrier();
75
  }
76
  if (tid == 0) {
 
1
  #version 450
2
 
3
+ #extension GL_EXT_shader_explicit_arithmetic_types : require
4
+
5
  #include "mul_mat_vec_base.comp"
6
 
7
  layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 
11
  void main() {
12
  const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
13
 
14
+ if (row >= p.stride_d) {
15
+ return;
16
+ }
17
+
18
  uint a_offset, b_offset, d_offset;
19
  get_offsets(a_offset, b_offset, d_offset);
20
 
 
42
  const uint s_offset = 8*v_im + is;
43
  const uint y_offset = 128*v_im + l0;
44
 
45
+ FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
46
 
47
  [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
48
  const uint y_idx = i * QUANT_K + y_offset;
49
 
50
  const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
51
 
52
+ FLOAT_TYPE scales[4];
53
+ scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
54
+ scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
55
+ scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
56
+ scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
57
+
58
+ uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
59
+ uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
60
+
61
+ uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
62
+ uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
63
+ uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
64
+ uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
65
+
66
+ uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
67
+ uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
68
+ uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
69
+ uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0;
70
+ uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
71
+
72
+ uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
73
+ uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
74
+ uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
75
+ uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
76
+
77
+ uvec4 q0 = uvec4(unpack8(q0_u32));
78
+ uvec4 q1 = uvec4(unpack8(q1_u32));
79
+ uvec4 q2 = uvec4(unpack8(q2_u32));
80
+ uvec4 q3 = uvec4(unpack8(q3_u32));
81
+
82
+ B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4];
83
+ B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8];
84
+ B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
85
+ B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
86
+
87
  FLOAT_TYPE sum = FLOAT_TYPE(0.0);
88
  [[unroll]] for (int l = 0; l < 4; ++l) {
89
+ sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
90
+ fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
91
+ fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
92
+ fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
93
  }
94
+ temp += sum * d;
 
95
  }
96
 
97
+ tmp[gl_LocalInvocationID.x] = temp;
98
+
99
  // sum up partial sums and write back result
100
  barrier();
101
  [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
102
  if (tid < s) {
103
  tmp[tid] += tmp[tid + s];
104
+ }
105
  barrier();
106
  }
107
  if (tid == 0) {
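Note on the q6_K reconstruction above: each 6-bit quant combines a 4-bit value from ql with a 2-bit pair from qh and is re-centred by subtracting 32, which is what the qh0_u32..qh6_u32 shifts and the final int8_t(q[l]) - 32 implement four bytes at a time. A scalar C++ sketch of one value (function name and sample inputs are illustrative only):

    #include <cstdint>
    #include <cstdio>

    // Reconstruct one q6_K value: the low 4 bits come from ql (groups 0/1 use the low
    // nibble of ql[l] / ql[l+32], groups 2/3 the high nibble), the top 2 bits come from
    // the matching bit pair of qh[l], and the result is re-centred to [-32, 31].
    static int q6_value(uint8_t ql0, uint8_t ql32, uint8_t qh, int group) {
        uint8_t ql  = (group == 0 || group == 2) ? ql0 : ql32;
        uint8_t lo4 = (group < 2) ? (ql & 0x0F) : (ql >> 4);
        uint8_t hi2 = (qh >> (2 * group)) & 0x03;
        return (int)((hi2 << 4) | lo4) - 32;
    }

    int main() {
        // ql0 = 0x21, ql32 = 0x43, qh = 0xE4: groups 0..3 give -31, -13, 2, 20.
        for (int g = 0; g < 4; ++g) {
            printf("%d\n", q6_value(0x21, 0x43, 0xE4, g));
        }
        return 0;
    }
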
ggml/src/ggml-vulkan/vulkan-shaders/types.comp CHANGED
@@ -1,6 +1,8 @@
1
- #if !defined(DATA_A_F32) && !defined(DATA_A_F16)
2
- #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
3
- #endif
 
 
4
 
5
  #if defined(DATA_A_F32)
6
  #define QUANT_K 1
@@ -38,8 +40,14 @@ struct block_q4_0
38
  float16_t d;
39
  uint8_t qs[16];
40
  };
 
 
 
 
 
41
 
42
  #define A_TYPE block_q4_0
 
43
  #endif
44
 
45
  #if defined(DATA_A_Q4_1)
@@ -54,7 +62,15 @@ struct block_q4_1
54
  uint8_t qs[16];
55
  };
56
 
 
 
 
 
 
 
 
57
  #define A_TYPE block_q4_1
 
58
  #endif
59
 
60
  #if defined(DATA_A_Q5_0)
@@ -70,7 +86,15 @@ struct block_q5_0
70
  uint8_t qs[16];
71
  };
72
 
 
 
 
 
 
 
 
73
  #define A_TYPE block_q5_0
 
74
  #endif
75
 
76
  #if defined(DATA_A_Q5_1)
@@ -87,7 +111,16 @@ struct block_q5_1
87
  uint8_t qs[16];
88
  };
89
 
 
 
 
 
 
 
 
 
90
  #define A_TYPE block_q5_1
 
91
  #endif
92
 
93
  #if defined(DATA_A_Q8_0)
@@ -100,8 +133,14 @@ struct block_q8_0
100
  float16_t d;
101
  int8_t qs[32];
102
  };
 
 
 
 
 
103
 
104
  #define A_TYPE block_q8_0
 
105
  #endif
106
 
107
  // K-quants
@@ -116,7 +155,23 @@ struct block_q2_K
116
  f16vec2 d;
117
  };
118
 
119
  #define A_TYPE block_q2_K
 
 
120
  #endif
121
 
122
  #if defined(DATA_A_Q3_K)
@@ -131,7 +186,16 @@ struct block_q3_K
131
  float16_t d;
132
  };
133
 
 
 
 
 
 
 
 
 
134
  #define A_TYPE block_q3_K
 
135
  #endif
136
 
137
  #if defined(DATA_A_Q4_K)
@@ -145,7 +209,23 @@ struct block_q4_K
145
  uint8_t qs[QUANT_K/2];
146
  };
147
 
148
  #define A_TYPE block_q4_K
 
 
149
  #endif
150
 
151
  #if defined(DATA_A_Q5_K)
@@ -160,7 +240,16 @@ struct block_q5_K
160
  uint8_t qs[QUANT_K/2];
161
  };
162
 
 
 
 
 
 
 
 
 
163
  #define A_TYPE block_q5_K
 
164
  #endif
165
 
166
  #if defined(DATA_A_Q6_K)
@@ -175,7 +264,16 @@ struct block_q6_K
175
  float16_t d;
176
  };
177
 
 
 
 
 
 
 
 
 
178
  #define A_TYPE block_q6_K
 
179
  #endif
180
 
181
  // IQuants
@@ -191,10 +289,19 @@ struct block_iq4_nl
191
  uint8_t qs[QUANT_K/2];
192
  };
193
 
 
 
 
 
 
 
194
  #define A_TYPE block_iq4_nl
 
195
 
196
  const int8_t kvalues_iq4nl[16] = {
197
  int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
198
  int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
199
  };
200
  #endif
 
 
 
1
+
2
+ #if !defined(GGML_TYPES_COMP)
3
+ #define GGML_TYPES_COMP
4
+
5
+ #extension GL_EXT_shader_explicit_arithmetic_types : require
6
 
7
  #if defined(DATA_A_F32)
8
  #define QUANT_K 1
 
40
  float16_t d;
41
  uint8_t qs[16];
42
  };
43
+ struct block_q4_0_packed16
44
+ {
45
+ float16_t d;
46
+ uint16_t qs[16/2];
47
+ };
48
 
49
  #define A_TYPE block_q4_0
50
+ #define A_TYPE_PACKED16 block_q4_0_packed16
51
  #endif
52
 
53
  #if defined(DATA_A_Q4_1)
 
62
  uint8_t qs[16];
63
  };
64
 
65
+ struct block_q4_1_packed16
66
+ {
67
+ float16_t d;
68
+ float16_t m;
69
+ uint16_t qs[16/2];
70
+ };
71
+
72
  #define A_TYPE block_q4_1
73
+ #define A_TYPE_PACKED16 block_q4_1_packed16
74
  #endif
75
 
76
  #if defined(DATA_A_Q5_0)
 
86
  uint8_t qs[16];
87
  };
88
 
89
+ struct block_q5_0_packed16
90
+ {
91
+ float16_t d;
92
+ uint16_t qh[2];
93
+ uint16_t qs[16/2];
94
+ };
95
+
96
  #define A_TYPE block_q5_0
97
+ #define A_TYPE_PACKED16 block_q5_0_packed16
98
  #endif
99
 
100
  #if defined(DATA_A_Q5_1)
 
111
  uint8_t qs[16];
112
  };
113
 
114
+ struct block_q5_1_packed16
115
+ {
116
+ float16_t d;
117
+ float16_t m;
118
+ uint qh;
119
+ uint16_t qs[16/2];
120
+ };
121
+
122
  #define A_TYPE block_q5_1
123
+ #define A_TYPE_PACKED16 block_q5_1_packed16
124
  #endif
125
 
126
  #if defined(DATA_A_Q8_0)
 
133
  float16_t d;
134
  int8_t qs[32];
135
  };
136
+ struct block_q8_0_packed16
137
+ {
138
+ float16_t d;
139
+ uint16_t qs[32/2];
140
+ };
141
 
142
  #define A_TYPE block_q8_0
143
+ #define A_TYPE_PACKED16 block_q8_0_packed16
144
  #endif
145
 
146
  // K-quants
 
155
  f16vec2 d;
156
  };
157
 
158
+ struct block_q2_K_packed16
159
+ {
160
+ uint16_t scales[QUANT_K/16/2];
161
+ uint16_t qs[QUANT_K/4/2];
162
+ f16vec2 d;
163
+ };
164
+
165
+ struct block_q2_K_packed32
166
+ {
167
+ uint32_t scales[QUANT_K/16/4];
168
+ uint32_t qs[QUANT_K/4/4];
169
+ f16vec2 d;
170
+ };
171
+
172
  #define A_TYPE block_q2_K
173
+ #define A_TYPE_PACKED16 block_q2_K_packed16
174
+ #define A_TYPE_PACKED32 block_q2_K_packed32
175
  #endif
176
 
177
  #if defined(DATA_A_Q3_K)
 
186
  float16_t d;
187
  };
188
 
189
+ struct block_q3_K_packed16
190
+ {
191
+ uint16_t hmask[QUANT_K/8/2];
192
+ uint16_t qs[QUANT_K/4/2];
193
+ uint16_t scales[12/2];
194
+ float16_t d;
195
+ };
196
+
197
  #define A_TYPE block_q3_K
198
+ #define A_TYPE_PACKED16 block_q3_K_packed16
199
  #endif
200
 
201
  #if defined(DATA_A_Q4_K)
 
209
  uint8_t qs[QUANT_K/2];
210
  };
211
 
212
+ struct block_q4_K_packed16
213
+ {
214
+ f16vec2 d;
215
+ uint16_t scales[3*QUANT_K/64/2];
216
+ uint16_t qs[QUANT_K/2/2];
217
+ };
218
+
219
+ struct block_q4_K_packed32
220
+ {
221
+ f16vec2 d;
222
+ uint32_t scales[3*QUANT_K/64/4];
223
+ uint32_t qs[QUANT_K/2/4];
224
+ };
225
+
226
  #define A_TYPE block_q4_K
227
+ #define A_TYPE_PACKED16 block_q4_K_packed16
228
+ #define A_TYPE_PACKED32 block_q4_K_packed32
229
  #endif
230
 
231
  #if defined(DATA_A_Q5_K)
 
240
  uint8_t qs[QUANT_K/2];
241
  };
242
 
243
+ struct block_q5_K_packed16
244
+ {
245
+ f16vec2 d;
246
+ uint16_t scales[12/2];
247
+ uint16_t qh[QUANT_K/8/2];
248
+ uint16_t qs[QUANT_K/2/2];
249
+ };
250
+
251
  #define A_TYPE block_q5_K
252
+ #define A_TYPE_PACKED16 block_q5_K_packed16
253
  #endif
254
 
255
  #if defined(DATA_A_Q6_K)
 
264
  float16_t d;
265
  };
266
 
267
+ struct block_q6_K_packed16
268
+ {
269
+ uint16_t ql[QUANT_K/2/2];
270
+ uint16_t qh[QUANT_K/4/2];
271
+ int8_t scales[QUANT_K/16];
272
+ float16_t d;
273
+ };
274
+
275
  #define A_TYPE block_q6_K
276
+ #define A_TYPE_PACKED16 block_q6_K_packed16
277
  #endif
278
 
279
  // IQuants
 
289
  uint8_t qs[QUANT_K/2];
290
  };
291
 
292
+ struct block_iq4_nl_packed16
293
+ {
294
+ float16_t d;
295
+ uint16_t qs[QUANT_K/2/2];
296
+ };
297
+
298
  #define A_TYPE block_iq4_nl
299
+ #define A_TYPE_PACKED16 block_iq4_nl_packed16
300
 
301
  const int8_t kvalues_iq4nl[16] = {
302
  int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
303
  int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
304
  };
305
  #endif
306
+
307
+ #endif // !defined(GGML_TYPES_COMP)
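The packed16/packed32 structs are only safe to bind against the same buffer because they cover exactly the same bytes as the scalar blocks. A rough host-side sanity check for the q4_K case (illustrative only, not code from the repo; GLSL float16_t/f16vec2 are modelled with 16-bit integers and QUANT_K = 256):

    #include <cstdint>

    constexpr int QUANT_K = 256; // K-quant super-block size

    // Scalar block and its 16-/32-bit aliases, mirroring the GLSL declarations above.
    struct block_q4_K          { uint16_t d[2]; uint8_t  scales[3*QUANT_K/64];   uint8_t  qs[QUANT_K/2];   };
    struct block_q4_K_packed16 { uint16_t d[2]; uint16_t scales[3*QUANT_K/64/2]; uint16_t qs[QUANT_K/2/2]; };
    struct block_q4_K_packed32 { uint16_t d[2]; uint32_t scales[3*QUANT_K/64/4]; uint32_t qs[QUANT_K/2/4]; };

    // 4 bytes of d/dmin + 12 scale bytes + 128 quant bytes = 144 bytes per block.
    static_assert(sizeof(block_q4_K) == 144, "unexpected q4_K block size");
    static_assert(sizeof(block_q4_K_packed16) == sizeof(block_q4_K), "16-bit alias must match the scalar layout");
    static_assert(sizeof(block_q4_K_packed32) == sizeof(block_q4_K), "32-bit alias must match the scalar layout");

    int main() { return 0; }
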
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -317,10 +317,10 @@ void process_shaders() {
317
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
318
  std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
319
 
320
- string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
321
- string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
322
 
323
- string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
324
 
325
  // Dequant shaders
326
  if (tname != "f16") {
 
317
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
318
  std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
319
 
320
+ string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
321
+ string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
322
 
323
+ string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
324
 
325
  // Dequant shaders
326
  if (tname != "f16") {