jeffbolznv committed
Commit e4d1f59 · 1 Parent(s): 42283e1

vulkan: support noncontiguous rms_norm (llama/13031)

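For context: a ggml tensor is "noncontiguous" when its byte strides nb[] no longer describe a densely packed layout of its sizes ne[], as happens for transposed or permuted views. A minimal sketch of that distinction, assuming an f32 2D tensor (the variable names below are illustrative, not from the patch):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // contiguous f32 tensor of shape [ne0, ne1]: rows are packed back-to-back
        const uint32_t ne0 = 4;
        const uint32_t nb0 = sizeof(float);      // stride between columns, in bytes
        const uint32_t nb1 = ne0 * nb0;          // stride between rows, in bytes
        std::printf("contiguous view: nb0=%u nb1=%u\n", (unsigned)nb0, (unsigned)nb1);

        // transposed view of the same buffer: strides are swapped, so consecutive
        // row elements are no longer adjacent in memory -> noncontiguous
        std::printf("transposed view: nb0=%u nb1=%u\n", (unsigned)nb1, (unsigned)nb0);
        return 0;
    }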
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -2397,7 +2397,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -6006,6 +6006,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_RMS_NORM:
         return true;
     default:
         return false;
@@ -6216,7 +6217,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     switch (op) {
     case GGML_OP_NORM:
-    case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_L2_NORM:
     case GGML_OP_SOFT_MAX:
@@ -6233,6 +6233,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
                 elements = { nr, 1, 1 };
             }
         } break;
+    case GGML_OP_RMS_NORM:
+        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        break;
+
     case GGML_OP_SUM:
         // We use GGML_OP_SUM_ROWS with 1 row.
         elements = { 1, 1, 1 };
@@ -6883,7 +6887,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
 
 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        op_params[0], 0.0f,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    }, dryrun);
 }
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9388,10 +9402,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_VIEW:
     case GGML_OP_PERMUTE:
     case GGML_OP_TRANSPOSE:
+    case GGML_OP_RMS_NORM:
         return true;
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
-    case GGML_OP_RMS_NORM:
     case GGML_OP_L2_NORM:
         return ggml_is_contiguous(op->src[0]);
     case GGML_OP_ADD:
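The host-side change switches rms_norm to the vk_op_unary_push_constants layout: instead of passing only ne00/ne01 and eps, it packs the full ne/nb of src0 and dst (byte strides divided by the type size, so the shader receives element strides) and dispatches one workgroup per row via elements = { ne01, ne02, ne03 }. A minimal sketch of how such element strides address one source row; the helper names below are hypothetical, only the stride math mirrors the patch and the shader's a_offset computation:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Element strides for src0: ggml stores nb[] in bytes, the host divides by
    // ggml_type_size() before packing them into the push constants.
    struct src_strides {
        uint32_t nb01, nb02, nb03; // row, channel, sample strides (in elements)
    };

    // Matches a_offset = samp*p.nb03 + channel*p.nb02 + row*p.nb01 in rms_norm.comp
    // (ignoring get_aoffset(), which adds the buffer base offset).
    static size_t src_row_offset(const src_strides & s, uint32_t row, uint32_t channel, uint32_t sample) {
        return (size_t)sample * s.nb03 + (size_t)channel * s.nb02 + (size_t)row * s.nb01;
    }

    int main() {
        // e.g. a view whose rows of 8 elements are padded out to a stride of 10
        const src_strides s = { 10, 40, 80 };
        std::printf("offset of row 3, channel 1, sample 0: %zu elements\n",
                    src_row_offset(s, 3, 1, 0));
        return 0;
    }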
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp CHANGED
@@ -1,6 +1,6 @@
 #version 450
 
-#include "generic_head.comp"
+#include "generic_unary_head.comp"
 #include "types.comp"
 
 #extension GL_EXT_control_flow_attributes : enable
@@ -8,19 +8,29 @@
 
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
 shared FLOAT_TYPE sum[BLOCK_SIZE];
 
 void main() {
-    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint tid = gl_LocalInvocationID.x;
+    const uint ncols = p.ne00;
+    const uint nrows = gl_NumWorkGroups.x;
+    const uint nchannels = gl_NumWorkGroups.y;
+
+    const uint row = gl_WorkGroupID.x;
+    const uint channel = gl_WorkGroupID.y;
+    const uint samp = gl_WorkGroupID.z;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint stride_row = p.nb01;
+    const uint stride_channel = p.nb02;
+    const uint stride_sample = p.nb03;
+
+    uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
 
     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
         sum[tid] += xi * xi;
     }
 
@@ -33,10 +43,10 @@ void main() {
         barrier();
     }
 
-    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
+    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
     const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
     }
 }
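In the updated shader each workgroup reduces one (row, channel, sample) slice: it sums squares over ncols strided source elements, derives scale = inversesqrt(mean + p.param1), and writes a contiguous output row at d_offset. A minimal CPU sketch of that per-row math, assuming f32 data; rms_norm_row_ref is an illustrative name, not part of the patch:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // eps corresponds to op_params[0] on the host and p.param1 in the shader.
    static void rms_norm_row_ref(const float * src, float * dst, size_t ncols, float eps) {
        float sumsq = 0.0f;
        for (size_t col = 0; col < ncols; ++col) {
            sumsq += src[col] * src[col];                 // sum of squares (the sum[] reduction)
        }
        const float mean  = sumsq / (float)ncols;         // mean of squares
        const float scale = 1.0f / std::sqrt(mean + eps); // inversesqrt(mean + param1)
        for (size_t col = 0; col < ncols; ++col) {
            dst[col] = scale * src[col];
        }
    }

    int main() {
        std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
        std::vector<float> y(x.size());
        rms_norm_row_ref(x.data(), y.data(), x.size(), 1e-6f);
        std::printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);
        return 0;
    }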