Spaces:
Running
Running
vulkan : sync (llama/0)
Browse files- ggml/src/ggml-vulkan/ggml-vulkan.cpp +108 -7
- ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
- ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +5 -0
- ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
- ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
- ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +4 -0
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
|
@@ -241,15 +241,19 @@ struct vk_device_struct {
|
|
| 241 |
vk_pipeline pipeline_norm_f32;
|
| 242 |
vk_pipeline pipeline_group_norm_f32;
|
| 243 |
vk_pipeline pipeline_rms_norm_f32;
|
|
|
|
| 244 |
vk_pipeline pipeline_gelu_f32;
|
| 245 |
vk_pipeline pipeline_gelu_quick_f32;
|
| 246 |
vk_pipeline pipeline_silu_f32;
|
|
|
|
| 247 |
vk_pipeline pipeline_relu_f32;
|
| 248 |
vk_pipeline pipeline_leaky_relu_f32;
|
| 249 |
vk_pipeline pipeline_tanh_f32;
|
|
|
|
| 250 |
vk_pipeline pipeline_diag_mask_inf_f32;
|
| 251 |
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
| 252 |
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
|
|
|
|
| 253 |
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
|
| 254 |
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
|
| 255 |
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
|
|
@@ -504,6 +508,7 @@ struct vk_op_rope_push_constants {
|
|
| 504 |
uint32_t s1;
|
| 505 |
uint32_t s2;
|
| 506 |
int32_t sections[4];
|
|
|
|
| 507 |
};
|
| 508 |
|
| 509 |
struct vk_op_soft_max_push_constants {
|
|
@@ -2122,6 +2127,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 2122 |
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2123 |
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2124 |
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
|
|
| 2125 |
|
| 2126 |
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
| 2127 |
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
@@ -2181,9 +2187,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 2181 |
ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2182 |
ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2183 |
ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
|
|
| 2184 |
ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2185 |
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2186 |
ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
|
|
| 2187 |
|
| 2188 |
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
|
| 2189 |
|
|
@@ -2191,6 +2199,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 2191 |
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
|
| 2192 |
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
| 2193 |
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
|
|
|
|
| 2194 |
|
| 2195 |
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
| 2196 |
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
@@ -5284,6 +5293,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
| 5284 |
case GGML_OP_CONT:
|
| 5285 |
case GGML_OP_DUP:
|
| 5286 |
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5287 |
case GGML_OP_NORM:
|
| 5288 |
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 5289 |
return ctx->device->pipeline_norm_f32;
|
|
@@ -5299,6 +5313,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
| 5299 |
return ctx->device->pipeline_rms_norm_f32;
|
| 5300 |
}
|
| 5301 |
return nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5302 |
case GGML_OP_UNARY:
|
| 5303 |
switch (ggml_get_unary_op(dst)) {
|
| 5304 |
case GGML_UNARY_OP_SILU:
|
|
@@ -5326,6 +5345,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
| 5326 |
return ctx->device->pipeline_tanh_f32;
|
| 5327 |
}
|
| 5328 |
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5329 |
default:
|
| 5330 |
break;
|
| 5331 |
}
|
|
@@ -5345,7 +5369,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
| 5345 |
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
|
| 5346 |
}
|
| 5347 |
return nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5348 |
case GGML_OP_ROPE:
|
|
|
|
| 5349 |
{
|
| 5350 |
const int mode = ((const int32_t *) dst->op_params)[2];
|
| 5351 |
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
@@ -5673,7 +5703,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
| 5673 |
switch (op) {
|
| 5674 |
case GGML_OP_NORM:
|
| 5675 |
case GGML_OP_RMS_NORM:
|
|
|
|
| 5676 |
case GGML_OP_SOFT_MAX:
|
|
|
|
| 5677 |
case GGML_OP_SUM_ROWS:
|
| 5678 |
case GGML_OP_ARGMAX:
|
| 5679 |
{
|
|
@@ -5697,6 +5729,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
| 5697 |
} break;
|
| 5698 |
case GGML_OP_DIAG_MASK_INF:
|
| 5699 |
case GGML_OP_ROPE:
|
|
|
|
| 5700 |
elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
|
| 5701 |
break;
|
| 5702 |
case GGML_OP_GET_ROWS:
|
|
@@ -5792,7 +5825,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
| 5792 |
|
| 5793 |
ggml_vk_sync_buffers(subctx);
|
| 5794 |
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
| 5795 |
-
} else if (op == GGML_OP_ROPE) {
|
| 5796 |
// Empty src2 is possible in rope, but the shader needs a buffer
|
| 5797 |
vk_subbuffer subbuf_z;
|
| 5798 |
if (use_src2) {
|
|
@@ -6314,6 +6347,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
| 6314 |
}, dryrun);
|
| 6315 |
}
|
| 6316 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6317 |
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
| 6318 |
float * op_params = (float *)dst->op_params;
|
| 6319 |
|
|
@@ -6336,6 +6373,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
| 6336 |
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
| 6337 |
}
|
| 6338 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6339 |
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
| 6340 |
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
| 6341 |
}
|
|
@@ -6371,7 +6413,12 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
| 6371 |
}, dryrun);
|
| 6372 |
}
|
| 6373 |
|
| 6374 |
-
static void
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6375 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 6376 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 6377 |
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
|
@@ -6399,7 +6446,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
|
| 6399 |
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
|
| 6400 |
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
|
| 6401 |
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
|
| 6402 |
-
sections[0], sections[1], sections[2], sections[3],
|
| 6403 |
}, dryrun);
|
| 6404 |
}
|
| 6405 |
|
|
@@ -7296,6 +7343,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7296 |
case GGML_UNARY_OP_GELU_QUICK:
|
| 7297 |
case GGML_UNARY_OP_RELU:
|
| 7298 |
case GGML_UNARY_OP_TANH:
|
|
|
|
| 7299 |
break;
|
| 7300 |
default:
|
| 7301 |
return false;
|
|
@@ -7320,12 +7368,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7320 |
case GGML_OP_CPY:
|
| 7321 |
case GGML_OP_CONT:
|
| 7322 |
case GGML_OP_DUP:
|
|
|
|
| 7323 |
case GGML_OP_NORM:
|
| 7324 |
case GGML_OP_GROUP_NORM:
|
| 7325 |
case GGML_OP_RMS_NORM:
|
|
|
|
| 7326 |
case GGML_OP_DIAG_MASK_INF:
|
| 7327 |
case GGML_OP_SOFT_MAX:
|
|
|
|
| 7328 |
case GGML_OP_ROPE:
|
|
|
|
| 7329 |
case GGML_OP_MUL_MAT:
|
| 7330 |
case GGML_OP_MUL_MAT_ID:
|
| 7331 |
case GGML_OP_ARGSORT:
|
|
@@ -7378,13 +7430,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7378 |
case GGML_OP_CPY:
|
| 7379 |
case GGML_OP_CONT:
|
| 7380 |
case GGML_OP_DUP:
|
|
|
|
| 7381 |
case GGML_OP_NORM:
|
| 7382 |
case GGML_OP_GROUP_NORM:
|
| 7383 |
case GGML_OP_RMS_NORM:
|
|
|
|
| 7384 |
case GGML_OP_UNARY:
|
| 7385 |
case GGML_OP_DIAG_MASK_INF:
|
| 7386 |
case GGML_OP_SOFT_MAX:
|
|
|
|
| 7387 |
case GGML_OP_ROPE:
|
|
|
|
| 7388 |
case GGML_OP_ARGSORT:
|
| 7389 |
case GGML_OP_SUM:
|
| 7390 |
case GGML_OP_SUM_ROWS:
|
|
@@ -7476,6 +7532,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7476 |
case GGML_OP_DUP:
|
| 7477 |
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
|
| 7478 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7479 |
break;
|
| 7480 |
case GGML_OP_NORM:
|
| 7481 |
ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
|
|
@@ -7488,6 +7548,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7488 |
case GGML_OP_RMS_NORM:
|
| 7489 |
ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
|
| 7490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7491 |
break;
|
| 7492 |
case GGML_OP_UNARY:
|
| 7493 |
switch (ggml_get_unary_op(node)) {
|
|
@@ -7496,6 +7560,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7496 |
case GGML_UNARY_OP_GELU_QUICK:
|
| 7497 |
case GGML_UNARY_OP_RELU:
|
| 7498 |
case GGML_UNARY_OP_TANH:
|
|
|
|
| 7499 |
ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
|
| 7500 |
break;
|
| 7501 |
default:
|
|
@@ -7509,9 +7574,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
| 7509 |
case GGML_OP_SOFT_MAX:
|
| 7510 |
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
|
| 7511 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7512 |
break;
|
| 7513 |
case GGML_OP_ROPE:
|
| 7514 |
-
ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7515 |
|
| 7516 |
break;
|
| 7517 |
case GGML_OP_ARGSORT:
|
|
@@ -7637,12 +7710,16 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
| 7637 |
case GGML_OP_CPY:
|
| 7638 |
case GGML_OP_CONT:
|
| 7639 |
case GGML_OP_DUP:
|
|
|
|
| 7640 |
case GGML_OP_NORM:
|
| 7641 |
case GGML_OP_GROUP_NORM:
|
| 7642 |
case GGML_OP_RMS_NORM:
|
|
|
|
| 7643 |
case GGML_OP_DIAG_MASK_INF:
|
| 7644 |
case GGML_OP_SOFT_MAX:
|
|
|
|
| 7645 |
case GGML_OP_ROPE:
|
|
|
|
| 7646 |
case GGML_OP_RESHAPE:
|
| 7647 |
case GGML_OP_VIEW:
|
| 7648 |
case GGML_OP_PERMUTE:
|
|
@@ -7671,6 +7748,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
| 7671 |
case GGML_UNARY_OP_GELU_QUICK:
|
| 7672 |
case GGML_UNARY_OP_RELU:
|
| 7673 |
case GGML_UNARY_OP_TANH:
|
|
|
|
| 7674 |
buf = tensor->buffer;
|
| 7675 |
break;
|
| 7676 |
default:
|
|
@@ -8373,6 +8451,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
| 8373 |
case GGML_UNARY_OP_SILU:
|
| 8374 |
case GGML_UNARY_OP_RELU:
|
| 8375 |
case GGML_UNARY_OP_TANH:
|
|
|
|
| 8376 |
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
| 8377 |
default:
|
| 8378 |
return false;
|
|
@@ -8562,6 +8641,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
| 8562 |
case GGML_OP_REPEAT_BACK:
|
| 8563 |
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
| 8564 |
case GGML_OP_ROPE:
|
|
|
|
| 8565 |
case GGML_OP_NONE:
|
| 8566 |
case GGML_OP_RESHAPE:
|
| 8567 |
case GGML_OP_VIEW:
|
|
@@ -8576,6 +8656,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
| 8576 |
case GGML_OP_SUB:
|
| 8577 |
case GGML_OP_MUL:
|
| 8578 |
case GGML_OP_DIV:
|
|
|
|
|
|
|
| 8579 |
case GGML_OP_SQR:
|
| 8580 |
case GGML_OP_SIN:
|
| 8581 |
case GGML_OP_COS:
|
|
@@ -8588,6 +8670,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
| 8588 |
case GGML_OP_PAD:
|
| 8589 |
case GGML_OP_DIAG_MASK_INF:
|
| 8590 |
case GGML_OP_SOFT_MAX:
|
|
|
|
| 8591 |
case GGML_OP_ARGSORT:
|
| 8592 |
case GGML_OP_SUM:
|
| 8593 |
case GGML_OP_SUM_ROWS:
|
|
@@ -8979,15 +9062,22 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
| 8979 |
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
|
| 8980 |
} else if (tensor->op == GGML_OP_RMS_NORM) {
|
| 8981 |
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8982 |
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
| 8983 |
if (src1 != nullptr) {
|
| 8984 |
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
| 8985 |
} else {
|
| 8986 |
tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
|
| 8987 |
}
|
|
|
|
|
|
|
| 8988 |
} else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
|
| 8989 |
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
|
| 8990 |
-
} else if (tensor->op == GGML_OP_ROPE) {
|
| 8991 |
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
| 8992 |
const int mode = ((int32_t *) tensor->op_params)[2];
|
| 8993 |
//const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
|
|
@@ -9000,9 +9090,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
| 9000 |
const float beta_slow = ((float *) tensor->op_params)[10];
|
| 9001 |
if (mode & GGML_ROPE_TYPE_MROPE) {
|
| 9002 |
int32_t *sections = ((int32_t *) tensor->op_params) + 11;
|
| 9003 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9004 |
} else {
|
| 9005 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9006 |
}
|
| 9007 |
} else if (tensor->op == GGML_OP_UNARY) {
|
| 9008 |
switch (ggml_get_unary_op(tensor)) {
|
|
@@ -9021,6 +9119,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
| 9021 |
case GGML_UNARY_OP_TANH:
|
| 9022 |
tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
|
| 9023 |
break;
|
|
|
|
|
|
|
|
|
|
| 9024 |
default:
|
| 9025 |
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
| 9026 |
GGML_ABORT("fatal error");
|
|
|
|
| 241 |
vk_pipeline pipeline_norm_f32;
|
| 242 |
vk_pipeline pipeline_group_norm_f32;
|
| 243 |
vk_pipeline pipeline_rms_norm_f32;
|
| 244 |
+
vk_pipeline pipeline_rms_norm_back_f32;
|
| 245 |
vk_pipeline pipeline_gelu_f32;
|
| 246 |
vk_pipeline pipeline_gelu_quick_f32;
|
| 247 |
vk_pipeline pipeline_silu_f32;
|
| 248 |
+
vk_pipeline pipeline_silu_back_f32;
|
| 249 |
vk_pipeline pipeline_relu_f32;
|
| 250 |
vk_pipeline pipeline_leaky_relu_f32;
|
| 251 |
vk_pipeline pipeline_tanh_f32;
|
| 252 |
+
vk_pipeline pipeline_sigmoid_f32;
|
| 253 |
vk_pipeline pipeline_diag_mask_inf_f32;
|
| 254 |
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
| 255 |
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
|
| 256 |
+
vk_pipeline pipeline_soft_max_back_f32;
|
| 257 |
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
|
| 258 |
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
|
| 259 |
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
|
|
|
|
| 508 |
uint32_t s1;
|
| 509 |
uint32_t s2;
|
| 510 |
int32_t sections[4];
|
| 511 |
+
uint32_t is_back;
|
| 512 |
};
|
| 513 |
|
| 514 |
struct vk_op_soft_max_push_constants {
|
|
|
|
| 2127 |
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2128 |
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2129 |
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2130 |
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
| 2131 |
|
| 2132 |
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
| 2133 |
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
|
|
| 2187 |
ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2188 |
ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2189 |
ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2190 |
+
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2191 |
ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2192 |
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2193 |
ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2194 |
+
ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
| 2195 |
|
| 2196 |
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
|
| 2197 |
|
|
|
|
| 2199 |
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
|
| 2200 |
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
| 2201 |
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
|
| 2202 |
+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
| 2203 |
|
| 2204 |
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
| 2205 |
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
|
|
| 5293 |
case GGML_OP_CONT:
|
| 5294 |
case GGML_OP_DUP:
|
| 5295 |
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
|
| 5296 |
+
case GGML_OP_SILU_BACK:
|
| 5297 |
+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 5298 |
+
return ctx->device->pipeline_silu_back_f32;
|
| 5299 |
+
}
|
| 5300 |
+
return nullptr;
|
| 5301 |
case GGML_OP_NORM:
|
| 5302 |
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 5303 |
return ctx->device->pipeline_norm_f32;
|
|
|
|
| 5313 |
return ctx->device->pipeline_rms_norm_f32;
|
| 5314 |
}
|
| 5315 |
return nullptr;
|
| 5316 |
+
case GGML_OP_RMS_NORM_BACK:
|
| 5317 |
+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 5318 |
+
return ctx->device->pipeline_rms_norm_back_f32;
|
| 5319 |
+
}
|
| 5320 |
+
return nullptr;
|
| 5321 |
case GGML_OP_UNARY:
|
| 5322 |
switch (ggml_get_unary_op(dst)) {
|
| 5323 |
case GGML_UNARY_OP_SILU:
|
|
|
|
| 5345 |
return ctx->device->pipeline_tanh_f32;
|
| 5346 |
}
|
| 5347 |
break;
|
| 5348 |
+
case GGML_UNARY_OP_SIGMOID:
|
| 5349 |
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 5350 |
+
return ctx->device->pipeline_sigmoid_f32;
|
| 5351 |
+
}
|
| 5352 |
+
break;
|
| 5353 |
default:
|
| 5354 |
break;
|
| 5355 |
}
|
|
|
|
| 5369 |
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
|
| 5370 |
}
|
| 5371 |
return nullptr;
|
| 5372 |
+
case GGML_OP_SOFT_MAX_BACK:
|
| 5373 |
+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 5374 |
+
return ctx->device->pipeline_soft_max_back_f32;
|
| 5375 |
+
}
|
| 5376 |
+
return nullptr;
|
| 5377 |
case GGML_OP_ROPE:
|
| 5378 |
+
case GGML_OP_ROPE_BACK:
|
| 5379 |
{
|
| 5380 |
const int mode = ((const int32_t *) dst->op_params)[2];
|
| 5381 |
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
|
|
| 5703 |
switch (op) {
|
| 5704 |
case GGML_OP_NORM:
|
| 5705 |
case GGML_OP_RMS_NORM:
|
| 5706 |
+
case GGML_OP_RMS_NORM_BACK:
|
| 5707 |
case GGML_OP_SOFT_MAX:
|
| 5708 |
+
case GGML_OP_SOFT_MAX_BACK:
|
| 5709 |
case GGML_OP_SUM_ROWS:
|
| 5710 |
case GGML_OP_ARGMAX:
|
| 5711 |
{
|
|
|
|
| 5729 |
} break;
|
| 5730 |
case GGML_OP_DIAG_MASK_INF:
|
| 5731 |
case GGML_OP_ROPE:
|
| 5732 |
+
case GGML_OP_ROPE_BACK:
|
| 5733 |
elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
|
| 5734 |
break;
|
| 5735 |
case GGML_OP_GET_ROWS:
|
|
|
|
| 5825 |
|
| 5826 |
ggml_vk_sync_buffers(subctx);
|
| 5827 |
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
|
| 5828 |
+
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
|
| 5829 |
// Empty src2 is possible in rope, but the shader needs a buffer
|
| 5830 |
vk_subbuffer subbuf_z;
|
| 5831 |
if (use_src2) {
|
|
|
|
| 6347 |
}, dryrun);
|
| 6348 |
}
|
| 6349 |
|
| 6350 |
+
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
| 6351 |
+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
| 6352 |
+
}
|
| 6353 |
+
|
| 6354 |
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
| 6355 |
float * op_params = (float *)dst->op_params;
|
| 6356 |
|
|
|
|
| 6373 |
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
| 6374 |
}
|
| 6375 |
|
| 6376 |
+
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
| 6377 |
+
float * op_params = (float *)dst->op_params;
|
| 6378 |
+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
| 6379 |
+
}
|
| 6380 |
+
|
| 6381 |
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
| 6382 |
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
| 6383 |
}
|
|
|
|
| 6413 |
}, dryrun);
|
| 6414 |
}
|
| 6415 |
|
| 6416 |
+
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
| 6417 |
+
float * op_params = (float *)dst->op_params;
|
| 6418 |
+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
|
| 6419 |
+
}
|
| 6420 |
+
|
| 6421 |
+
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
|
| 6422 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 6423 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 6424 |
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
|
|
|
| 6446 |
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
|
| 6447 |
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
|
| 6448 |
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
|
| 6449 |
+
sections[0], sections[1], sections[2], sections[3], backprop
|
| 6450 |
}, dryrun);
|
| 6451 |
}
|
| 6452 |
|
|
|
|
| 7343 |
case GGML_UNARY_OP_GELU_QUICK:
|
| 7344 |
case GGML_UNARY_OP_RELU:
|
| 7345 |
case GGML_UNARY_OP_TANH:
|
| 7346 |
+
case GGML_UNARY_OP_SIGMOID:
|
| 7347 |
break;
|
| 7348 |
default:
|
| 7349 |
return false;
|
|
|
|
| 7368 |
case GGML_OP_CPY:
|
| 7369 |
case GGML_OP_CONT:
|
| 7370 |
case GGML_OP_DUP:
|
| 7371 |
+
case GGML_OP_SILU_BACK:
|
| 7372 |
case GGML_OP_NORM:
|
| 7373 |
case GGML_OP_GROUP_NORM:
|
| 7374 |
case GGML_OP_RMS_NORM:
|
| 7375 |
+
case GGML_OP_RMS_NORM_BACK:
|
| 7376 |
case GGML_OP_DIAG_MASK_INF:
|
| 7377 |
case GGML_OP_SOFT_MAX:
|
| 7378 |
+
case GGML_OP_SOFT_MAX_BACK:
|
| 7379 |
case GGML_OP_ROPE:
|
| 7380 |
+
case GGML_OP_ROPE_BACK:
|
| 7381 |
case GGML_OP_MUL_MAT:
|
| 7382 |
case GGML_OP_MUL_MAT_ID:
|
| 7383 |
case GGML_OP_ARGSORT:
|
|
|
|
| 7430 |
case GGML_OP_CPY:
|
| 7431 |
case GGML_OP_CONT:
|
| 7432 |
case GGML_OP_DUP:
|
| 7433 |
+
case GGML_OP_SILU_BACK:
|
| 7434 |
case GGML_OP_NORM:
|
| 7435 |
case GGML_OP_GROUP_NORM:
|
| 7436 |
case GGML_OP_RMS_NORM:
|
| 7437 |
+
case GGML_OP_RMS_NORM_BACK:
|
| 7438 |
case GGML_OP_UNARY:
|
| 7439 |
case GGML_OP_DIAG_MASK_INF:
|
| 7440 |
case GGML_OP_SOFT_MAX:
|
| 7441 |
+
case GGML_OP_SOFT_MAX_BACK:
|
| 7442 |
case GGML_OP_ROPE:
|
| 7443 |
+
case GGML_OP_ROPE_BACK:
|
| 7444 |
case GGML_OP_ARGSORT:
|
| 7445 |
case GGML_OP_SUM:
|
| 7446 |
case GGML_OP_SUM_ROWS:
|
|
|
|
| 7532 |
case GGML_OP_DUP:
|
| 7533 |
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
|
| 7534 |
|
| 7535 |
+
break;
|
| 7536 |
+
case GGML_OP_SILU_BACK:
|
| 7537 |
+
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
| 7538 |
+
|
| 7539 |
break;
|
| 7540 |
case GGML_OP_NORM:
|
| 7541 |
ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
|
|
|
|
| 7548 |
case GGML_OP_RMS_NORM:
|
| 7549 |
ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
|
| 7550 |
|
| 7551 |
+
break;
|
| 7552 |
+
case GGML_OP_RMS_NORM_BACK:
|
| 7553 |
+
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
| 7554 |
+
|
| 7555 |
break;
|
| 7556 |
case GGML_OP_UNARY:
|
| 7557 |
switch (ggml_get_unary_op(node)) {
|
|
|
|
| 7560 |
case GGML_UNARY_OP_GELU_QUICK:
|
| 7561 |
case GGML_UNARY_OP_RELU:
|
| 7562 |
case GGML_UNARY_OP_TANH:
|
| 7563 |
+
case GGML_UNARY_OP_SIGMOID:
|
| 7564 |
ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
|
| 7565 |
break;
|
| 7566 |
default:
|
|
|
|
| 7574 |
case GGML_OP_SOFT_MAX:
|
| 7575 |
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
|
| 7576 |
|
| 7577 |
+
break;
|
| 7578 |
+
case GGML_OP_SOFT_MAX_BACK:
|
| 7579 |
+
ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
| 7580 |
+
|
| 7581 |
break;
|
| 7582 |
case GGML_OP_ROPE:
|
| 7583 |
+
ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
|
| 7584 |
+
|
| 7585 |
+
break;
|
| 7586 |
+
case GGML_OP_ROPE_BACK:
|
| 7587 |
+
ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);
|
| 7588 |
|
| 7589 |
break;
|
| 7590 |
case GGML_OP_ARGSORT:
|
|
|
|
| 7710 |
case GGML_OP_CPY:
|
| 7711 |
case GGML_OP_CONT:
|
| 7712 |
case GGML_OP_DUP:
|
| 7713 |
+
case GGML_OP_SILU_BACK:
|
| 7714 |
case GGML_OP_NORM:
|
| 7715 |
case GGML_OP_GROUP_NORM:
|
| 7716 |
case GGML_OP_RMS_NORM:
|
| 7717 |
+
case GGML_OP_RMS_NORM_BACK:
|
| 7718 |
case GGML_OP_DIAG_MASK_INF:
|
| 7719 |
case GGML_OP_SOFT_MAX:
|
| 7720 |
+
case GGML_OP_SOFT_MAX_BACK:
|
| 7721 |
case GGML_OP_ROPE:
|
| 7722 |
+
case GGML_OP_ROPE_BACK:
|
| 7723 |
case GGML_OP_RESHAPE:
|
| 7724 |
case GGML_OP_VIEW:
|
| 7725 |
case GGML_OP_PERMUTE:
|
|
|
|
| 7748 |
case GGML_UNARY_OP_GELU_QUICK:
|
| 7749 |
case GGML_UNARY_OP_RELU:
|
| 7750 |
case GGML_UNARY_OP_TANH:
|
| 7751 |
+
case GGML_UNARY_OP_SIGMOID:
|
| 7752 |
buf = tensor->buffer;
|
| 7753 |
break;
|
| 7754 |
default:
|
|
|
|
| 8451 |
case GGML_UNARY_OP_SILU:
|
| 8452 |
case GGML_UNARY_OP_RELU:
|
| 8453 |
case GGML_UNARY_OP_TANH:
|
| 8454 |
+
case GGML_UNARY_OP_SIGMOID:
|
| 8455 |
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
| 8456 |
default:
|
| 8457 |
return false;
|
|
|
|
| 8641 |
case GGML_OP_REPEAT_BACK:
|
| 8642 |
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
| 8643 |
case GGML_OP_ROPE:
|
| 8644 |
+
case GGML_OP_ROPE_BACK:
|
| 8645 |
case GGML_OP_NONE:
|
| 8646 |
case GGML_OP_RESHAPE:
|
| 8647 |
case GGML_OP_VIEW:
|
|
|
|
| 8656 |
case GGML_OP_SUB:
|
| 8657 |
case GGML_OP_MUL:
|
| 8658 |
case GGML_OP_DIV:
|
| 8659 |
+
case GGML_OP_SILU_BACK:
|
| 8660 |
+
case GGML_OP_RMS_NORM_BACK:
|
| 8661 |
case GGML_OP_SQR:
|
| 8662 |
case GGML_OP_SIN:
|
| 8663 |
case GGML_OP_COS:
|
|
|
|
| 8670 |
case GGML_OP_PAD:
|
| 8671 |
case GGML_OP_DIAG_MASK_INF:
|
| 8672 |
case GGML_OP_SOFT_MAX:
|
| 8673 |
+
case GGML_OP_SOFT_MAX_BACK:
|
| 8674 |
case GGML_OP_ARGSORT:
|
| 8675 |
case GGML_OP_SUM:
|
| 8676 |
case GGML_OP_SUM_ROWS:
|
|
|
|
| 9062 |
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
|
| 9063 |
} else if (tensor->op == GGML_OP_RMS_NORM) {
|
| 9064 |
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
|
| 9065 |
+
} else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
|
| 9066 |
+
const float eps = ((float *) tensor->op_params)[0];
|
| 9067 |
+
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
|
| 9068 |
+
} else if (tensor->op == GGML_OP_SILU_BACK) {
|
| 9069 |
+
tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
|
| 9070 |
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
| 9071 |
if (src1 != nullptr) {
|
| 9072 |
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
| 9073 |
} else {
|
| 9074 |
tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
|
| 9075 |
}
|
| 9076 |
+
} else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
|
| 9077 |
+
tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
| 9078 |
} else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
|
| 9079 |
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
|
| 9080 |
+
} else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
|
| 9081 |
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
| 9082 |
const int mode = ((int32_t *) tensor->op_params)[2];
|
| 9083 |
//const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
|
|
|
|
| 9090 |
const float beta_slow = ((float *) tensor->op_params)[10];
|
| 9091 |
if (mode & GGML_ROPE_TYPE_MROPE) {
|
| 9092 |
int32_t *sections = ((int32_t *) tensor->op_params) + 11;
|
| 9093 |
+
if (tensor->op == GGML_OP_ROPE) {
|
| 9094 |
+
tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
| 9095 |
+
} else {
|
| 9096 |
+
tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
| 9097 |
+
}
|
| 9098 |
} else {
|
| 9099 |
+
if (tensor->op == GGML_OP_ROPE) {
|
| 9100 |
+
tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
| 9101 |
+
} else {
|
| 9102 |
+
tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
| 9103 |
+
}
|
| 9104 |
}
|
| 9105 |
} else if (tensor->op == GGML_OP_UNARY) {
|
| 9106 |
switch (ggml_get_unary_op(tensor)) {
|
|
|
|
| 9119 |
case GGML_UNARY_OP_TANH:
|
| 9120 |
tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
|
| 9121 |
break;
|
| 9122 |
+
case GGML_UNARY_OP_SIGMOID:
|
| 9123 |
+
tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
|
| 9124 |
+
break;
|
| 9125 |
default:
|
| 9126 |
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
| 9127 |
GGML_ABORT("fatal error");
|
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#version 450
|
| 2 |
+
|
| 3 |
+
#include "generic_head.comp"
|
| 4 |
+
#include "types.comp"
|
| 5 |
+
|
| 6 |
+
#extension GL_EXT_control_flow_attributes : enable
|
| 7 |
+
#define BLOCK_SIZE 512
|
| 8 |
+
|
| 9 |
+
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
| 10 |
+
|
| 11 |
+
layout (binding = 0) readonly buffer G {A_TYPE data_a[];};
|
| 12 |
+
layout (binding = 1) readonly buffer X {B_TYPE data_b[];};
|
| 13 |
+
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
| 14 |
+
|
| 15 |
+
shared FLOAT_TYPE sum_xx[BLOCK_SIZE];
|
| 16 |
+
shared FLOAT_TYPE sum_xg[BLOCK_SIZE];
|
| 17 |
+
|
| 18 |
+
void main() {
|
| 19 |
+
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
| 20 |
+
const uint tid = gl_LocalInvocationID.x;
|
| 21 |
+
|
| 22 |
+
// Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5
|
| 23 |
+
|
| 24 |
+
// partial sums for thread in warp
|
| 25 |
+
sum_xx[tid] = FLOAT_TYPE(0.0f);
|
| 26 |
+
sum_xg[tid] = FLOAT_TYPE(0.0f);
|
| 27 |
+
|
| 28 |
+
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
| 29 |
+
const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]);
|
| 30 |
+
const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]);
|
| 31 |
+
sum_xx[tid] += xi * xi;
|
| 32 |
+
sum_xg[tid] += xi * gi;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
// sum up partial sums and write back result
|
| 36 |
+
barrier();
|
| 37 |
+
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
| 38 |
+
if (tid < s) {
|
| 39 |
+
sum_xx[tid] += sum_xx[tid + s];
|
| 40 |
+
sum_xg[tid] += sum_xg[tid + s];
|
| 41 |
+
}
|
| 42 |
+
barrier();
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
const FLOAT_TYPE eps = FLOAT_TYPE(p.param1);
|
| 46 |
+
const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX);
|
| 47 |
+
const FLOAT_TYPE scale_g = inversesqrt(mean + eps);
|
| 48 |
+
const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps);
|
| 49 |
+
|
| 50 |
+
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
| 51 |
+
data_d[row*p.KX + col] = D_TYPE(
|
| 52 |
+
scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) +
|
| 53 |
+
scale_x * FLOAT_TYPE(data_b[row*p.KX + col]));
|
| 54 |
+
}
|
| 55 |
+
}
|
ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp
CHANGED
|
@@ -29,6 +29,7 @@ layout (push_constant) uniform parameter {
|
|
| 29 |
uint s1;
|
| 30 |
uint s2;
|
| 31 |
int sections[4];
|
|
|
|
| 32 |
} p;
|
| 33 |
|
| 34 |
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
|
@@ -48,6 +49,10 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
|
|
| 48 |
// Get n-d magnitude scaling corrected for interpolation
|
| 49 |
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
|
| 50 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
cos_theta = cos(theta) * mscale;
|
| 52 |
sin_theta = sin(theta) * mscale;
|
| 53 |
}
|
|
|
|
| 29 |
uint s1;
|
| 30 |
uint s2;
|
| 31 |
int sections[4];
|
| 32 |
+
uint is_back;
|
| 33 |
} p;
|
| 34 |
|
| 35 |
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
|
|
|
| 49 |
// Get n-d magnitude scaling corrected for interpolation
|
| 50 |
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
|
| 51 |
}
|
| 52 |
+
// Backprogagation uses inverted rotation
|
| 53 |
+
if (p.is_back != 0) {
|
| 54 |
+
theta = -theta;
|
| 55 |
+
}
|
| 56 |
cos_theta = cos(theta) * mscale;
|
| 57 |
sin_theta = sin(theta) * mscale;
|
| 58 |
}
|
ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#version 450
|
| 2 |
+
|
| 3 |
+
#include "generic_head.comp"
|
| 4 |
+
#include "types.comp"
|
| 5 |
+
|
| 6 |
+
#extension GL_EXT_control_flow_attributes : enable
|
| 7 |
+
|
| 8 |
+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
| 9 |
+
|
| 10 |
+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
| 11 |
+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
| 12 |
+
|
| 13 |
+
void main() {
|
| 14 |
+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
| 15 |
+
|
| 16 |
+
if (i >= p.KX) {
|
| 17 |
+
return;
|
| 18 |
+
}
|
| 19 |
+
data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i])));
|
| 20 |
+
}
|
ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#version 450
|
| 2 |
+
|
| 3 |
+
#include "generic_head.comp"
|
| 4 |
+
#include "types.comp"
|
| 5 |
+
|
| 6 |
+
#extension GL_EXT_control_flow_attributes : enable
|
| 7 |
+
|
| 8 |
+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
| 9 |
+
|
| 10 |
+
layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
|
| 11 |
+
layout (binding = 1) readonly buffer X {B_TYPE data_x[];};
|
| 12 |
+
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
| 13 |
+
|
| 14 |
+
void main() {
|
| 15 |
+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
| 16 |
+
|
| 17 |
+
if (i >= p.KX) {
|
| 18 |
+
return;
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
// Compute derivative of SiLU(x): 1/(1+exp(-x)) - x*exp(-x)/(1+exp(-x))^2
|
| 22 |
+
|
| 23 |
+
const float xi = float(data_x[i]);
|
| 24 |
+
const float s = 1.0f / (1.0f + exp(-xi));
|
| 25 |
+
data_d[i] = D_TYPE(data_g[i] * (s + xi * s * (1 - s)));
|
| 26 |
+
}
|
ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#version 450
|
| 2 |
+
|
| 3 |
+
#extension GL_EXT_control_flow_attributes : enable
|
| 4 |
+
|
| 5 |
+
#include "generic_head.comp"
|
| 6 |
+
#include "types.comp"
|
| 7 |
+
|
| 8 |
+
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
| 9 |
+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
| 10 |
+
|
| 11 |
+
// In this shader Y = softmax(X) and X is not provided as input.
|
| 12 |
+
|
| 13 |
+
layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
|
| 14 |
+
layout (binding = 1) readonly buffer Y {B_TYPE data_y[];};
|
| 15 |
+
layout (binding = 2) buffer D {D_TYPE data_d[];};
|
| 16 |
+
|
| 17 |
+
shared FLOAT_TYPE sum_yg[BLOCK_SIZE];
|
| 18 |
+
|
| 19 |
+
void main() {
|
| 20 |
+
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
| 21 |
+
const uint tid = gl_LocalInvocationID.x;
|
| 22 |
+
|
| 23 |
+
FLOAT_TYPE scale = p.param1;
|
| 24 |
+
|
| 25 |
+
// partial sums for thread in warp
|
| 26 |
+
sum_yg[tid] = FLOAT_TYPE(0.0f);
|
| 27 |
+
|
| 28 |
+
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
| 29 |
+
const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]);
|
| 30 |
+
const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]);
|
| 31 |
+
sum_yg[tid] += yi * gi;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// sum up partial sums and write back result
|
| 35 |
+
barrier();
|
| 36 |
+
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
| 37 |
+
if (tid < s) {
|
| 38 |
+
sum_yg[tid] += sum_yg[tid + s];
|
| 39 |
+
}
|
| 40 |
+
barrier();
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
const FLOAT_TYPE dot_yg = sum_yg[0];
|
| 44 |
+
|
| 45 |
+
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
| 46 |
+
data_d[row*p.KX + col] = D_TYPE(scale
|
| 47 |
+
* (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg)
|
| 48 |
+
* FLOAT_TYPE(data_y[row*p.KX + col]));
|
| 49 |
+
}
|
| 50 |
+
}
|
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
CHANGED
|
@@ -433,6 +433,7 @@ void process_shaders() {
|
|
| 433 |
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
| 434 |
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
| 435 |
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
|
|
|
| 436 |
|
| 437 |
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 438 |
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
|
@@ -483,14 +484,17 @@ void process_shaders() {
|
|
| 483 |
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 484 |
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 485 |
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
|
|
| 486 |
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 487 |
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 488 |
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
|
|
|
| 489 |
|
| 490 |
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 491 |
|
| 492 |
string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
| 493 |
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
|
|
|
| 494 |
|
| 495 |
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 496 |
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
|
|
|
| 433 |
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
| 434 |
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
| 435 |
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
| 436 |
+
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
| 437 |
|
| 438 |
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 439 |
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
|
|
|
| 484 |
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 485 |
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 486 |
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 487 |
+
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 488 |
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 489 |
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 490 |
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 491 |
+
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 492 |
|
| 493 |
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 494 |
|
| 495 |
string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
| 496 |
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
| 497 |
+
string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
| 498 |
|
| 499 |
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
| 500 |
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|