ggerganov committed
Commit 4c17fa1 · Parent: 49e3343

vulkan : sync (llama/0)

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -241,15 +241,19 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_gelu_f32;
     vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
+    vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_relu_f32;
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_tanh_f32;
+    vk_pipeline pipeline_sigmoid_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
+    vk_pipeline pipeline_soft_max_back_f32;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -504,6 +508,7 @@ struct vk_op_rope_push_constants {
     uint32_t s1;
     uint32_t s2;
     int32_t sections[4];
+    uint32_t is_back;
 };
 
 struct vk_op_soft_max_push_constants {
@@ -2122,6 +2127,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2181,9 +2187,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
@@ -2191,6 +2199,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -5284,6 +5293,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_CONT:
     case GGML_OP_DUP:
         return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+    case GGML_OP_SILU_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_silu_back_f32;
+        }
+        return nullptr;
     case GGML_OP_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_norm_f32;
@@ -5299,6 +5313,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rms_norm_f32;
         }
         return nullptr;
+    case GGML_OP_RMS_NORM_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rms_norm_back_f32;
+        }
+        return nullptr;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(dst)) {
         case GGML_UNARY_OP_SILU:
@@ -5326,6 +5345,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_tanh_f32;
             }
             break;
+        case GGML_UNARY_OP_SIGMOID:
+            if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+                return ctx->device->pipeline_sigmoid_f32;
+            }
+            break;
         default:
             break;
         }
@@ -5345,7 +5369,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SOFT_MAX_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_soft_max_back_f32;
+        }
+        return nullptr;
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
         {
             const int mode = ((const int32_t *) dst->op_params)[2];
             const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
@@ -5673,7 +5703,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     switch (op) {
     case GGML_OP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
         {
@@ -5697,6 +5729,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         } break;
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
         elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
         break;
     case GGML_OP_GET_ROWS:
@@ -5792,7 +5825,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-    } else if (op == GGML_OP_ROPE) {
+    } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
         vk_subbuffer subbuf_z;
         if (use_src2) {
@@ -6314,6 +6347,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
 
@@ -6336,6 +6373,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
 }
 
+static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+}
+
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
@@ -6371,7 +6413,12 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }
 
-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+}
+
+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
     // const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -6399,7 +6446,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
         src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
-        sections[0], sections[1], sections[2], sections[3],
+        sections[0], sections[1], sections[2], sections[3], backprop
     }, dryrun);
 }
 
@@ -7296,6 +7343,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             break;
         default:
             return false;
@@ -7320,12 +7368,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
    case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
     case GGML_OP_ARGSORT:
@@ -7378,13 +7430,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
@@ -7476,6 +7532,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_DUP:
         ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_SILU_BACK:
+        ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_NORM:
         ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7488,6 +7548,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM:
         ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_RMS_NORM_BACK:
+        ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
@@ -7496,6 +7560,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
             break;
         default:
@@ -7509,9 +7574,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SOFT_MAX:
         ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_SOFT_MAX_BACK:
+        ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_ROPE:
-        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
+
+        break;
+    case GGML_OP_ROPE_BACK:
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);
 
         break;
     case GGML_OP_ARGSORT:
@@ -7637,12 +7710,16 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_CPY:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
    case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_RESHAPE:
     case GGML_OP_VIEW:
     case GGML_OP_PERMUTE:
@@ -7671,6 +7748,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             buf = tensor->buffer;
             break;
         default:
@@ -8373,6 +8451,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         default:
             return false;
@@ -8562,6 +8641,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_REPEAT_BACK:
         return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_NONE:
     case GGML_OP_RESHAPE:
     case GGML_OP_VIEW:
@@ -8576,6 +8656,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_SILU_BACK:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_SQR:
     case GGML_OP_SIN:
     case GGML_OP_COS:
@@ -8588,6 +8670,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_PAD:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
@@ -8979,15 +9062,22 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
+        const float eps = ((float *) tensor->op_params)[0];
+        tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
+    } else if (tensor->op == GGML_OP_SILU_BACK) {
+        tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
             tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
         } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
         }
+    } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
+        tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
         tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
-    } else if (tensor->op == GGML_OP_ROPE) {
+    } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
         const int n_dims = ((int32_t *) tensor->op_params)[1];
         const int mode = ((int32_t *) tensor->op_params)[2];
         //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
@@ -9000,9 +9090,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const float beta_slow = ((float *) tensor->op_params)[10];
         if (mode & GGML_ROPE_TYPE_MROPE) {
             int32_t *sections = ((int32_t *) tensor->op_params) + 11;
-            tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
         } else {
-            tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
         }
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
@@ -9021,6 +9119,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         case GGML_UNARY_OP_TANH:
             tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
             break;
+        case GGML_UNARY_OP_SIGMOID:
+            tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
+            break;
         default:
            std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
            GGML_ABORT("fatal error");
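
A note on the pattern shared by the new `*_BACK` operations: each one is a vector-Jacobian product. Binding 0 (`G`) carries the incoming gradient of the loss with respect to the op's output, binding 1 carries the forward input (or, for softmax, the forward output), and the kernel writes the gradient with respect to the input:

$$ \bar{x}_i \;=\; \sum_j \frac{\partial y_j}{\partial x_i}\, g_j . $$

The closed forms each shader evaluates are spelled out after the respective files below.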
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp ADDED
@@ -0,0 +1,55 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer G {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer X {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum_xx[BLOCK_SIZE];
+shared FLOAT_TYPE sum_xg[BLOCK_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint tid = gl_LocalInvocationID.x;
+
+    // Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5
+
+    // partial sums for thread in warp
+    sum_xx[tid] = FLOAT_TYPE(0.0f);
+    sum_xg[tid] = FLOAT_TYPE(0.0f);
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]);
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]);
+        sum_xx[tid] += xi * xi;
+        sum_xg[tid] += xi * gi;
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sum_xx[tid] += sum_xx[tid + s];
+            sum_xg[tid] += sum_xg[tid + s];
+        }
+        barrier();
+    }
+
+    const FLOAT_TYPE eps = FLOAT_TYPE(p.param1);
+    const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX);
+    const FLOAT_TYPE scale_g = inversesqrt(mean + eps);
+    const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps);
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        data_d[row*p.KX + col] = D_TYPE(
+            scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) +
+            scale_x * FLOAT_TYPE(data_b[row*p.KX + col]));
+    }
+}
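
For reference, the closed form this shader evaluates: with $n = \texttt{p.KX}$ and $\mu = \frac{1}{n}\sum_j x_j^2$, the forward op is $y_i = x_i/\sqrt{\mu+\varepsilon}$, so for an upstream gradient $g$

$$ \frac{\partial L}{\partial x_i} \;=\; \frac{g_i}{\sqrt{\mu+\varepsilon}} \;-\; \frac{x_i \sum_j x_j g_j}{n\,(\mu+\varepsilon)^{3/2}} , $$

which matches `scale_g` $=(\mu+\varepsilon)^{-1/2}$ and `scale_x` $= -\sum_j x_j g_j \,/\, \big(n(\mu+\varepsilon)^{3/2}\big)$, using the identity `sum_xx[0] + p.KX * eps` $= n(\mu+\varepsilon)$.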
ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp CHANGED
@@ -29,6 +29,7 @@ layout (push_constant) uniform parameter {
     uint s1;
     uint s2;
     int sections[4];
+    uint is_back;
 } p;
 
 float rope_yarn_ramp(const float low, const float high, const uint i0) {
@@ -48,6 +49,10 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
         // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
     }
+    // Backpropagation uses inverted rotation
+    if (p.is_back != 0) {
+        theta = -theta;
+    }
     cos_theta = cos(theta) * mscale;
     sin_theta = sin(theta) * mscale;
 }
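
Why negating `theta` suffices for the backward pass: RoPE rotates each channel pair by

$$ R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}, $$

and since rotations are orthogonal, the transposed Jacobian needed for backpropagation is the inverse rotation: $R(\theta)^{\top} = R(\theta)^{-1} = R(-\theta)$. The scalar `mscale` passes through the transpose unchanged, so the forward kernel with $\theta \mapsto -\theta$ computes the exact gradient.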
ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp ADDED
@@ -0,0 +1,20 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+    data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i])));
+}
ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp ADDED
@@ -0,0 +1,26 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
+layout (binding = 1) readonly buffer X {B_TYPE data_x[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    // Compute derivative of SiLU(x): 1/(1+exp(-x)) - x*exp(-x)/(1+exp(-x))^2
+
+    const float xi = float(data_x[i]);
+    const float s = 1.0f / (1.0f + exp(-xi));
+    data_d[i] = D_TYPE(data_g[i] * (s + xi * s * (1 - s)));
+}
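
The derivative in the comment, spelled out: with $\sigma(x) = 1/(1+e^{-x})$, $\operatorname{SiLU}(x) = x\,\sigma(x)$ and

$$ \frac{d}{dx}\,x\,\sigma(x) \;=\; \sigma(x) + x\,\sigma(x)\big(1-\sigma(x)\big), $$

which is the `s + xi * s * (1 - s)` factor applied to the incoming gradient `data_g[i]`; it equals the comment's form because $\sigma'(x) = \sigma(x)(1-\sigma(x)) = e^{-x}/(1+e^{-x})^2$.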
ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp ADDED
@@ -0,0 +1,50 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "generic_head.comp"
+#include "types.comp"
+
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+// In this shader Y = softmax(X) and X is not provided as input.
+
+layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_y[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum_yg[BLOCK_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint tid = gl_LocalInvocationID.x;
+
+    FLOAT_TYPE scale = p.param1;
+
+    // partial sums for thread in warp
+    sum_yg[tid] = FLOAT_TYPE(0.0f);
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]);
+        const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]);
+        sum_yg[tid] += yi * gi;
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sum_yg[tid] += sum_yg[tid + s];
+        }
+        barrier();
+    }
+
+    const FLOAT_TYPE dot_yg = sum_yg[0];
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        data_d[row*p.KX + col] = D_TYPE(scale
+            * (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg)
+            * FLOAT_TYPE(data_y[row*p.KX + col]));
+    }
+}
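
The closed form behind the single reduction: for $y = \operatorname{softmax}(s\,x)$ with input scale $s$ (`p.param1`) and incoming gradient $g$,

$$ \frac{\partial L}{\partial x_i} \;=\; s\, y_i \Big( g_i - \sum_j y_j g_j \Big), $$

so one per-row dot product $\sum_j y_j g_j$ (`dot_yg`) is enough, and only the forward output $Y$ is needed, never the original input $X$.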
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -433,6 +433,7 @@ void process_shaders() {
     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
@@ -483,14 +484,17 @@ void process_shaders() {
     string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
     string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
     string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});