zhentaoyu committed
Commit 06acee2 · 1 Parent(s): ef3d018

Update SYCL-Rope op and Refactor (llama/8157)


* align with rope.cu and move sycl-op to a single file

Files changed (1)
  1. ggml/src/ggml-sycl.cpp +2 -303
ggml/src/ggml-sycl.cpp CHANGED
@@ -978,114 +978,6 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
-static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
-    return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
-}
-
-struct rope_corr_dims {
-    float v[4];
-};
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
-    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
-    }
-    *cos_theta = sycl::cos(theta) * mscale;
-    *sin_theta = sycl::sin(theta) * mscale;
-}
-
-// rope == RoPE == rotary positional embedding
-template<typename T, bool has_pos>
-static void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims,
-    const sycl::nd_item<3> &item_ct1) {
-    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p * dpct::pow(freq_base, -float(col) / ncols);
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + 1];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + 1] = x0*sin_theta + x1*cos_theta;
-}
-
-template<typename T, bool has_pos, bool has_freq_facs>
-static void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
-    const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
-    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;
-
-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
-    const int i = row*ncols + ib*n_dims + ic/2;
-    const int i2 = row/p_delta_rows;
-
-    float cur_rot = inv_ndims * ic - ib;
-
-    const int p = has_pos ? pos[i2] : 0;
-    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
-
-    const float theta_base =
-        p * freq_scale * dpct::pow(theta_scale, col / 2.0f) / freq_factor;
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
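Note on the removed kernels: the rope kernel above rotates consecutive element pairs (x0, x1) by an angle derived from the token position and column index, and rope_yarn corrects that angle (and its magnitude) for extended context when ext_factor is non-zero. For readers without a SYCL toolchain, here is a minimal host-side sketch of the same per-pair math, assuming plain C++ with <cmath> in place of the sycl:: device calls (all names are illustrative, not ggml API):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    struct corr_dims_t { float low, high; }; // mirrors rope_corr_dims.v[0..1]

    // Linear ramp used by YaRN to blend interpolation and extrapolation.
    static float yarn_ramp(float low, float high, int i0) {
        const float y = (i0 / 2 - low) / std::max(0.001f, high - low);
        return 1.0f - std::min(1.0f, std::max(0.0f, y));
    }

    // Rotate one (x0, x1) pair by the YaRN-corrected angle.
    static void rotate_pair(float x0, float x1, float theta_extrap, float freq_scale,
                            corr_dims_t cd, int i0, float ext_factor, float mscale,
                            float & y0, float & y1) {
        const float theta_interp = freq_scale * theta_extrap;
        float theta = theta_interp;
        if (ext_factor != 0.0f) {
            const float ramp_mix = yarn_ramp(cd.low, cd.high, i0) * ext_factor;
            theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
            mscale *= 1.0f + 0.1f * std::log(1.0f / freq_scale); // magnitude correction
        }
        const float c = std::cos(theta) * mscale;
        const float s = std::sin(theta) * mscale;
        y0 = x0 * c - x1 * s; // standard 2-D rotation
        y1 = x0 * s + x1 * c;
    }

    int main() {
        float y0, y1;
        // theta_extrap = 0.5, no YaRN correction (ext_factor = 0)
        rotate_pair(1.0f, 0.0f, 0.5f, 1.0f, {0.0f, 0.0f}, 0, 0.0f, 1.0f, y0, y1);
        std::printf("(%.4f, %.4f)\n", y0, y1); // (cos 0.5, sin 0.5) = (0.8776, 0.4794)
        return 0;
    }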
@@ -2241,110 +2133,6 @@ static void clamp_f32_sycl(const float *x, float *dst, const float min,
         });
 }
 
-template <typename T>
-static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
-                      const int32_t *pos, float freq_scale, int p_delta_rows,
-                      float freq_base, float ext_factor, float attn_factor,
-                      rope_corr_dims corr_dims, queue_ptr stream) {
-    GGML_ASSERT(ncols % 2 == 0);
-    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
-    const sycl::range<3> block_nums(1, num_blocks_x, nrows);
-    if (pos == nullptr) {
-        /*
-        DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope<T, false>(x, dst, ncols, pos, freq_scale, p_delta_rows,
-                               freq_base, ext_factor, attn_factor, corr_dims,
-                               item_ct1);
-            });
-    } else {
-        /*
-        DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope<T, true>(x, dst, ncols, pos, freq_scale, p_delta_rows,
-                              freq_base, ext_factor, attn_factor, corr_dims,
-                              item_ct1);
-            });
-    }
-}
-
-template <typename T>
-static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
-                           const int32_t *pos, float freq_scale,
-                           int p_delta_rows, float freq_base, float ext_factor,
-                           float attn_factor, rope_corr_dims corr_dims,
-                           const float * freq_factors, queue_ptr stream) {
-    GGML_ASSERT(ncols % 2 == 0);
-    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
-    const sycl::range<3> block_nums(1, num_blocks_x, nrows);
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.0f / n_dims;
-
-    if (pos == nullptr) {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-        if (freq_factors == nullptr) {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
-                                               p_delta_rows, ext_factor, attn_factor,
-                                               corr_dims, theta_scale, inv_ndims, freq_factors,
-                                               item_ct1);
-                });
-        } else {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
-                                              p_delta_rows, ext_factor, attn_factor,
-                                              corr_dims, theta_scale, inv_ndims, freq_factors,
-                                              item_ct1);
-                });
-        }
-    } else {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        if (freq_factors == nullptr) {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
-                                              p_delta_rows, ext_factor, attn_factor,
-                                              corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
-                });
-        } else {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
-                                             p_delta_rows, ext_factor, attn_factor,
-                                             corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
-                });
-        }
-    }
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
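Both removed launchers share the same grid arithmetic: each work-item rotates one column pair, so the column axis needs ceil(ncols / (2*SYCL_ROPE_BLOCK_SIZE)) work-groups, with one group slot per row on the row axis. A standalone check of that ceiling division, assuming an illustrative block size of 256 (the actual SYCL_ROPE_BLOCK_SIZE is defined elsewhere in the SYCL backend):

    #include <cstdio>

    int main() {
        const int ROPE_BLOCK_SIZE = 256; // assumed value, for illustration only
        const int ncols = 4096;          // e.g. one row of a hidden-state tensor
        // each work-item covers 2 columns -> divide by 2*block size, rounding up
        const int num_blocks_x = (ncols + 2*ROPE_BLOCK_SIZE - 1) / (2*ROPE_BLOCK_SIZE);
        std::printf("%d work-groups x %d items x 2 cols = %d columns covered\n",
                    num_blocks_x, ROPE_BLOCK_SIZE, num_blocks_x * ROPE_BLOCK_SIZE * 2);
        return 0; // prints: 8 work-groups x 256 items x 2 cols = 4096 columns covered
    }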
@@ -3461,97 +3249,6 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-    const ggml_tensor * src2 = dst->src[2];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
-    GGML_ASSERT(src0->type == dst->type);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t nrows = ggml_nrows(src0);
-
-    //const int n_past     = ((int32_t *) dst->op_params)[0];
-    const int n_dims     = ((int32_t *) dst->op_params)[1];
-    const int mode       = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
-    // RoPE alteration for extended context
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-
-    const float * freq_factors = nullptr;
-    const int32_t * pos = nullptr;
-    if ((mode & 1) == 0) {
-        GGML_ASSERT(src1->type == GGML_TYPE_I32);
-        GGML_ASSERT(src1->ne[0] == ne2);
-        pos = (const int32_t *) src1_dd;
-    }
-
-    const bool is_neox = mode & 2;
-
-#pragma message("TODO: update rope NORM mode to match NEOX mode")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
-
-    if (is_neox) {
-        pos = (const int32_t *) src1_dd;
-
-        if (src2 != nullptr) {
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
-    }
-
-    rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
-
-    // compute
-    if (is_neox) {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_neox_sycl(
-                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, main_stream
-            );
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
-                           ne00, n_dims, nrows, pos, freq_scale, ne01,
-                           freq_base, ext_factor, attn_factor, corr_dims,
-                           freq_factors, main_stream);
-        } else {
-            GGML_ASSERT(false);
-        }
-    } else {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_sycl(
-                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, main_stream
-            );
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
-                      nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                      attn_factor, corr_dims, main_stream);
-        } else {
-            GGML_ASSERT(false);
-        }
-    }
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
 static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                 const ggml_tensor *src1, ggml_tensor *dst,
                                 const float *src0_dd, const float *src1_dd,
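One detail worth noting in the removed ggml_sycl_op_rope: the float hyperparameters (freq_base, freq_scale, ext_factor, ...) are bit-copied out of the tensor's int32 op_params array with memcpy rather than read through a casted float pointer, which sidesteps type-punning undefined behavior. A self-contained sketch of that round-trip, with an illustrative array size (index 5 matches the freq_base slot used above):

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    int main() {
        int32_t op_params[16] = {0}; // stand-in for dst->op_params

        // write side (what the graph-building code does)
        const float freq_base = 10000.0f;
        std::memcpy(op_params + 5, &freq_base, sizeof(float));

        // read side (what the backend op does, as in the diff above)
        float freq_base_out;
        std::memcpy(&freq_base_out, op_params + 5, sizeof(float));
        std::printf("freq_base = %.1f\n", freq_base_out); // 10000.0
        return 0;
    }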
@@ -6241,7 +5938,9 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+            return true;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
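The new GGML_OP_ROPE case gates support on ggml_is_contiguous(op->src[0]): in ggml, a tensor is contiguous when its byte strides nb[] describe a densely packed row-major layout, i.e. nb[0] is the element size and each higher stride is the previous stride times the previous extent. A toy version of that test, assuming non-quantized element types (toy_tensor and toy_is_contiguous are illustrative stand-ins, not ggml API):

    #include <cstdint>
    #include <cstdio>

    struct toy_tensor {
        int64_t  ne[4];  // extent of each dimension
        uint64_t nb[4];  // stride of each dimension, in bytes
        uint64_t esize;  // bytes per element
    };

    static bool toy_is_contiguous(const toy_tensor & t) {
        if (t.nb[0] != t.esize) return false; // innermost dim must be packed
        for (int i = 1; i < 4; ++i) {
            if (t.nb[i] != t.nb[i-1] * (uint64_t) t.ne[i-1]) return false;
        }
        return true;
    }

    int main() {
        // a densely packed 8x4x2x1 f32 tensor: strides 4, 32, 128, 256 bytes
        toy_tensor t = {{8, 4, 2, 1}, {4, 32, 128, 256}, 4};
        std::printf("contiguous: %d\n", toy_is_contiguous(t)); // prints 1
        return 0;
    }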
 