jeffbolznv committed
Commit 6868981 · 1 Parent(s): 8707beb

vulkan: Optimize mul_mat_vec p021 and nc shaders (llama/12505)


* tests: add mul_mat perf/functional tests for p021/nc vulkan shaders

* vulkan: Optimize mul_mat_vec p021 and nc shaders.

These shaders are used in attention calculations, and when the KV cache grows
large they start to dominate the run time. For the nc shader (which is called
with large 'k' dimension), use unrolling and vector loads. For the p021 shader
(which is called with large 'm' and small 'k' dimensions), take advantage of
grouped query attention to reuse loads from the A matrix for the whole group,
and reduce the number of workgroups (too much overhead from tiny dispatches).
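To make the workgroup savings concrete, here is a minimal, hypothetical C++ sketch of the gqa_ratio clamp and dispatch arithmetic that the backend change below performs. The shape names (ne02 = KV channels, ne12 = Q channels) follow ggml conventions; everything else is illustrative only, not the backend API.

    // Sketch: clamp the Q:KV channel ratio to [1, 8] and shrink the z dimension
    // of the dispatch accordingly (assumption: mirrors the change below).
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t ne02 = 8;   // KV channels (A matrix)
        const uint32_t ne12 = 32;  // Q channels (B matrix)

        uint32_t gqa_ratio = ne12 / ne02;
        if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
            gqa_ratio = 1;  // fall back to one row per invocation
        }

        uint32_t workgroups_z = ne12;
        if (gqa_ratio > 1) {
            workgroups_z /= gqa_ratio;  // one workgroup covers a whole GQA group
        }

        // prints: gqa_ratio=4, dispatch z: 32 -> 8
        std::printf("gqa_ratio=%u, dispatch z: %u -> %u\n", gqa_ratio, ne12, workgroups_z);
        return 0;
    }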

Using subgroupAdd in the p021 shader also helps; use it conditionally, only where the device supports subgroup arithmetic in compute shaders.
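A minimal sketch of how the subgroupAdd variant is gated, assuming the Vulkan 1.1 subgroup properties queried in the backend change below; pick_p021_variant is a hypothetical helper for illustration, not part of ggml.

    // Sketch: choose the subgroupAdd pipeline only when compute-stage subgroup
    // arithmetic is available and full subgroups can be required.
    #include <vulkan/vulkan.hpp>

    const char * pick_p021_variant(const vk::PhysicalDeviceVulkan11Properties & vk11_props,
                                   bool subgroup_require_full_support) {
        const bool subgroup_add =
            (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
            (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);

        // Otherwise fall back to the shared-memory tree reduction variant.
        return (subgroup_add && subgroup_require_full_support)
            ? "mul_mat_vec_p021_f16_f32_subgroup_add"
            : "mul_mat_vec_p021_f16_f32";
    }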

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED

@@ -149,6 +149,7 @@ class vk_perf_logger;
 static void ggml_vk_destroy_buffer(vk_buffer& buf);
 
 static constexpr uint32_t mul_mat_vec_max_cols = 8;
+static constexpr uint32_t p021_max_gqa_ratio = 8;
 
 enum vk_device_architecture {
     OTHER,
@@ -231,6 +232,7 @@ struct vk_device_struct {
     bool uma;
     bool prefer_host_memory;
     bool float_controls_rte_fp16;
+    bool subgroup_add;
 
     bool subgroup_size_control;
     uint32_t subgroup_min_size;
@@ -277,7 +279,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
 
-    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
+    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
     vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
@@ -2265,7 +2267,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+    for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
+        if (device->subgroup_add && device->subgroup_require_full_support) {
+            ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
+        } else {
+            ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
+        }
+    }
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -2479,13 +2487,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
     vk::PhysicalDeviceDriverProperties driver_props;
     vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
     vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
+    vk::PhysicalDeviceVulkan11Properties vk11_props;
     vk::PhysicalDeviceVulkan12Properties vk12_props;
     vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
 
     props2.pNext = &props3;
     props3.pNext = &subgroup_props;
     subgroup_props.pNext = &driver_props;
-    driver_props.pNext = &vk12_props;
+    driver_props.pNext = &vk11_props;
+    vk11_props.pNext = &vk12_props;
 
     VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
 
@@ -2549,6 +2559,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
     }
     device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
 
+    device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                           (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
+
     const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
     device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -4635,9 +4648,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t d_sz = sizeof(float) * d_ne;
 
+    // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
+    if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
+        gqa_ratio = 1;
+    }
+
     if (dryrun) {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
+        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
         return;
     }
 
@@ -4661,8 +4680,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
 
     // compute
     const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
+
+    uint32_t workgroups_z = (uint32_t)ne12;
+    // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
+    if (gqa_ratio > 1) {
+        workgroups_z /= gqa_ratio;
+    }
+
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
 }
 
 static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp CHANGED

@@ -12,6 +12,9 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
 
+layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
 layout (push_constant) uniform parameter
 {
     uint ncols_x;
@@ -37,25 +40,66 @@ void main() {
 
     const uint idst = channel*nrows_dst + row_dst;
 
-    tmp[tid] = 0.0f;
-
-    for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
-        const uint col_x = col_x0 + tid;
-
-        if (col_x >= p.ncols_x) {
-            break;
-        }
-
-        const uint row_y = col_x;
-
-        const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
-        const uint iy = channel*nrows_y + row_y;
-
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
-
-        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
+    FLOAT_TYPE temp = 0.0f;
+
+    // Detect alignment for vector loads
+    bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;
+
+    for (uint col_x0 = 0; col_x0 < p.ncols_x;) {
+
+        // Unroll 2x and do vec4 loads if aligned
+        const uint unroll_count = 2;
+        if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
+            [[unroll]] for (uint i = 0; i < unroll_count; ++i) {
+                const uint col_x = col_x0 + 4*tid;
+
+                const uint row_y = col_x;
+
+                const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+                const uint iy = channel*nrows_y + row_y;
+
+                const vec4 av4 = vec4(data_a_v4[ix / 4]);
+                const vec4 bv4 = vec4(data_b_v4[iy / 4]);
+
+                temp += dot(av4, bv4);
+
+                col_x0 += 4*BLOCK_SIZE;
+            }
+        // do vec4 loads if aligned
+        } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
+            const uint col_x = col_x0 + 4*tid;
+
+            const uint row_y = col_x;
+
+            const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+            const uint iy = channel*nrows_y + row_y;
+
+            const vec4 av4 = vec4(data_a_v4[ix / 4]);
+            const vec4 bv4 = vec4(data_b_v4[iy / 4]);
+
+            temp += dot(av4, bv4);
+
+            col_x0 += 4*BLOCK_SIZE;
+        } else {
+            const uint col_x = col_x0 + tid;
+            if (col_x >= p.ncols_x) {
+                break;
+            }
+
+            const uint row_y = col_x;
+
+            const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+            const uint iy = channel*nrows_y + row_y;
+
+            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+
+            temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
+            col_x0 += BLOCK_SIZE;
+        }
     }
 
+    tmp[tid] = temp;
+
     // sum up partial sums and write back result
     barrier();
     [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp CHANGED

@@ -2,16 +2,25 @@
 
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_16bit_storage : require
+#if USE_SUBGROUP_ADD
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#endif
 
-#define BLOCK_SIZE 32
 #define FLOAT_TYPE float
 
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
 
+layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
+layout(constant_id = 0) const int BLOCK_SIZE = 32;
+// gqa_ratio is in the range [1,8]
+layout(constant_id = 1) const uint gqa_ratio = 1;
+
 layout (push_constant) uniform parameter
 {
     uint ncols_x;
@@ -22,52 +31,124 @@ layout (push_constant) uniform parameter
     uint d_offset;
 } p;
 
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
+#if !USE_SUBGROUP_ADD
+shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
+#endif
 
 void main() {
     const uint tid = gl_LocalInvocationID.x;
     const uint row_x = gl_GlobalInvocationID.y;
-    const uint channel = gl_GlobalInvocationID.z;
-    const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
+
+    uint channel, channel_x;
+
+    // When gqa_ratio > 1, each invocation does multiple rows.
+    // The row in the A matrix starts at channel / gqa_ratio and the
+    // rows in the B matrix are [channel, channel+gqa_ratio).
+    // When gqa_ratio is 1, each invocation does one row.
+    if (gqa_ratio > 1) {
+        channel_x = gl_GlobalInvocationID.z;
+        channel = channel_x * gqa_ratio;
+    } else {
+        channel = gl_GlobalInvocationID.z;
+        channel_x = channel / (p.nchannels_y / p.nchannels_x);
+    }
 
     const uint nrows_y = p.ncols_x;
     const uint nrows_dst = p.nrows_x;
     const uint row_dst = row_x;
 
-    tmp[tid] = FLOAT_TYPE(0.0f);
+    FLOAT_TYPE temp[8];
+    [[unroll]] for (uint i = 0; i < 8; ++i) {
+        temp[i] = FLOAT_TYPE(0.0f);
+    }
+
+    // Detect alignment for vector loads
+    bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
 
     for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
-        const uint col_x = col_x0 + tid;
-
-        if (col_x >= p.ncols_x) {
-            break;
-        }
-
-        // x is transposed and permuted
-        const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
-
-        const uint row_y = col_x;
-
-        // y is not transposed but permuted
-        const uint iy = channel*nrows_y + row_y;
-
-        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
-    }
-
-    // dst is not transposed and not permuted
-    const uint idst = channel*nrows_dst + row_dst;
+
+        // Use vec4 loads if aligned
+        if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
+
+            uint col_x = col_x0 + 4*tid;
+            const uint row_y = col_x;
+
+            // x is transposed and permuted
+            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+            const vec4 av4 = vec4(data_a_v4[ix / 4]);
+
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                // y is not transposed but permuted
+                const uint iy = (channel + c)*nrows_y + row_y;
+
+                vec4 bv4 = data_b_v4[iy / 4];
+                temp[c] += dot(av4, bv4);
+            }
+
+            col_x0 += 3*BLOCK_SIZE;
+        } else {
+            const uint col_x = col_x0 + tid;
+
+            if (col_x >= p.ncols_x) {
+                break;
+            }
+
+            // x is transposed and permuted
+            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+
+            const uint row_y = col_x;
+
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                // y is not transposed but permuted
+                const uint iy = (channel + c)*nrows_y + row_y;
+
+                temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
+            }
+        }
+    }
+
+#if USE_SUBGROUP_ADD
+    // reduce vec4 at a time
+    vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
+    t = subgroupAdd(t);
+    temp[0] = t[0];
+    temp[1] = t[1];
+    temp[2] = t[2];
+    temp[3] = t[3];
+    if (gqa_ratio > 4) {
+        t = vec4(temp[4], temp[5], temp[6], temp[7]);
+        t = subgroupAdd(t);
+        temp[4] = t[0];
+        temp[5] = t[1];
+        temp[6] = t[2];
+        temp[7] = t[3];
+    }
+#else
+    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+        tmp[c][tid] = temp[c];
+    }
     // sum up partial sums and write back result
     barrier();
     [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                temp[c] += tmp[c][tid + s];
+                tmp[c][tid] = temp[c];
+            }
        }
         barrier();
     }
+    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+        temp[c] = tmp[c][tid];
+    }
+#endif
 
     if (tid == 0) {
-        dst[idst] = tmp[0];
+        [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+            // dst is not transposed and not permuted
+            const uint idst = (channel + c)*nrows_dst + row_dst;
+            dst[idst] = temp[c];
+        }
     }
 }
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED

@@ -426,8 +426,9 @@ void process_shaders() {
         }
     }
 
-    string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
+    string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
+    string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
 
     // Norms
     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));