Commit · e4d1f59
Parent(s): 42283e1
vulkan: support noncontiguous rms_norm (llama/13031)
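This commit lets the Vulkan backend run GGML_OP_RMS_NORM directly on noncontiguous inputs: the shader now walks the source tensor through its per-dimension strides instead of assuming densely packed rows. As a minimal host-side sketch of what such an input looks like (illustration only, not part of the commit; it assumes the public ggml API from ggml.h):

// Illustration: a permuted view is noncontiguous in its higher-dimension
// strides (nb[1], nb[2]) even though each row stays densely packed.
#include "ggml.h"
#include <stdio.h>

int main() {
    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // 8 cols x 4 rows x 2 channels, contiguous
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 8, 4, 2);

    // swap the row and channel axes; rows stay intact, outer strides do not
    struct ggml_tensor * v = ggml_permute(ctx, a, 0, 2, 1, 3);
    printf("contiguous: %d\n", ggml_is_contiguous(v)); // prints 0

    // before this commit the Vulkan backend required a contiguous src0 for
    // RMS_NORM; with it, views like v are accepted as-is
    struct ggml_tensor * out = ggml_rms_norm(ctx, v, 1e-6f);
    (void) out;

    ggml_free(ctx);
    return 0;
}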
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
@@ -2397,7 +2397,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -6006,6 +6006,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_RMS_NORM:
         return true;
     default:
         return false;
@@ -6216,7 +6217,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     switch (op) {
     case GGML_OP_NORM:
-    case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_L2_NORM:
     case GGML_OP_SOFT_MAX:
@@ -6233,6 +6233,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
                 elements = { nr, 1, 1 };
             }
         } break;
+    case GGML_OP_RMS_NORM:
+        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        break;
+
     case GGML_OP_SUM:
         // We use GGML_OP_SUM_ROWS with 1 row.
         elements = { 1, 1, 1 };
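Together with the hunk above it, this takes RMS_NORM out of the generic flattened-row dispatch (which packs ggml_nrows(src0) workgroups into a 512 x 512 x N grid) and gives it a 3D grid with one workgroup per row: x spans rows (ne01), y channels (ne02), z samples (ne03). That is what lets the shader derive per-row offsets from the workgroup ID alone. A host-side mirror of the mapping (hypothetical helper, not from the commit):

#include <cstdint>
#include <cstdio>

// Hypothetical mirror of the dispatch math above: the column dimension
// (ne[0]) is reduced inside a workgroup; rows, channels, and samples each
// get a grid dimension, so every workgroup owns exactly one row.
static void rms_norm_grid(const int64_t ne[4], uint32_t grid[3]) {
    grid[0] = (uint32_t) ne[1]; // rows     -> gl_WorkGroupID.x
    grid[1] = (uint32_t) ne[2]; // channels -> gl_WorkGroupID.y
    grid[2] = (uint32_t) ne[3]; // samples  -> gl_WorkGroupID.z
}

int main() {
    const int64_t ne[4] = { 4096, 32, 8, 1 }; // cols, rows, channels, samples
    uint32_t grid[3];
    rms_norm_grid(ne, grid);
    printf("dispatch %u x %u x %u workgroups\n", grid[0], grid[1], grid[2]); // 32 x 8 x 1
    return 0;
}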
@@ -6883,7 +6887,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
 
 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        op_params[0], 0.0f,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    }, dryrun);
 }
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
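The brace-initializer now fills a vk_op_unary_push_constants block instead of the old four-field vk_op_push_constants, passing full shapes and element-granular strides for both tensors. The struct below is a sketch of how those values line up; the field names are an assumption based on the vk_op_unary_push_constants definition in ggml-vulkan.cpp, which this diff does not show:

#include <cstdint>

// Assumed layout (sketch, not a quotation): strides are expressed in
// elements, i.e. nb[i] / ggml_type_size(type), matching the host code above.
struct vk_op_unary_push_constants_sketch {
    uint32_t ne;                                              // ggml_nelements(src0)
    uint32_t ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03;  // src0 shape + strides
    uint32_t ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13;  // dst  shape + strides
    uint32_t misalign_offsets;                                // 0 here
    float    param1, param2;                                  // param1 carries the rms_norm eps
    uint32_t fastdiv[12];                                     // division helpers, zeroed; unused by rms_norm
};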
@@ -9388,10 +9402,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
+        case GGML_OP_RMS_NORM:
            return true;
        case GGML_OP_NORM:
        case GGML_OP_GROUP_NORM:
-        case GGML_OP_RMS_NORM:
        case GGML_OP_L2_NORM:
            return ggml_is_contiguous(op->src[0]);
        case GGML_OP_ADD:
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
CHANGED
@@ -1,6 +1,6 @@
 #version 450
 
-#include "generic_head.comp"
+#include "generic_unary_head.comp"
 #include "types.comp"
 
 #extension GL_EXT_control_flow_attributes : enable
@@ -8,19 +8,29 @@
 
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
 shared FLOAT_TYPE sum[BLOCK_SIZE];
 
 void main() {
-    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint tid = gl_LocalInvocationID.x;
+    const uint ncols = p.ne00;
+    const uint nrows = gl_NumWorkGroups.x;
+    const uint nchannels = gl_NumWorkGroups.y;
+
+    const uint row = gl_WorkGroupID.x;
+    const uint channel = gl_WorkGroupID.y;
+    const uint samp = gl_WorkGroupID.z;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint stride_row = p.nb01;
+    const uint stride_channel = p.nb02;
+    const uint stride_sample = p.nb03;
+
+    uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
 
     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
         sum[tid] += xi * xi;
     }
@@ -33,10 +43,10 @@ void main() {
         barrier();
     }
 
-    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
+    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
     const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
     }
 }
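The heart of the shader change is the offset computation: reads now go through the strides delivered in the push constants (nb01/nb02/nb03, in elements), while writes stay dense, with get_aoffset()/get_doffset() from generic_unary_head.comp adding any buffer-base misalignment on top. A host-side mirror of that arithmetic with made-up example strides (illustration only):

#include <cstdint>
#include <cstdio>

int main() {
    // example shape/strides in elements: 8 cols, 4 rows, 2 channels,
    // 1 sample, rows padded to a stride of 16 (hence noncontiguous)
    const uint32_t ncols = 8, nrows = 4, nchannels = 2;
    const uint32_t nb01 = 16, nb02 = 64, nb03 = 128;

    const uint32_t row = 3, channel = 1, samp = 0;

    // strided read offset, as in the shader (minus get_aoffset())
    const uint32_t a_offset = samp*nb03 + channel*nb02 + row*nb01;
    // dense write offset, as in the shader (minus get_doffset())
    const uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols;

    printf("a_offset = %u, d_offset = %u\n", a_offset, d_offset); // 112, 56
    return 0;
}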