Akarshan Biswas commited on
Commit
e4b1812
·
1 Parent(s): a434936

SYCL: Add mrope kernel (llama/13755)

Browse files

* SYCL: Add mrope kernel

* feat: Optimize rope operations with vectorization

Uses `sycl::vec` to load and store two elements at a time,
significantly improving performance in `rope_norm`,
`rope_neox`, and `rope_multi`. This reduces the number of memory
accesses and leverages SIMD instructions for faster execution.

* Use ceil_div

ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -4257,14 +4257,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4257
  case GGML_OP_SOFT_MAX:
4258
  return true;
4259
  case GGML_OP_ROPE:
4260
- {
4261
- const int mode = ((const int32_t *) op->op_params)[2];
4262
- // mode is not used as a bitmask in practice, the various rope type modes are independent implementations
4263
- if (mode == GGML_ROPE_TYPE_MROPE) {
4264
- return false;
4265
- }
4266
- return true;
4267
- }
4268
  case GGML_OP_IM2COL:
4269
  return true;
4270
  case GGML_OP_UPSCALE:
 
4257
  case GGML_OP_SOFT_MAX:
4258
  return true;
4259
  case GGML_OP_ROPE:
 
 
 
 
 
 
 
 
4260
  case GGML_OP_IM2COL:
4261
  return true;
4262
  case GGML_OP_UPSCALE:
ggml/src/ggml-sycl/rope.cpp CHANGED
@@ -49,10 +49,7 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
49
 
50
  if (i0 >= n_dims) {
51
  const int i = row * ne0 + i0;
52
-
53
- dst[i + 0] = x[i + 0];
54
- dst[i + 1] = x[i + 1];
55
-
56
  return;
57
  }
58
 
@@ -93,10 +90,7 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
93
 
94
  if (i0 >= n_dims) {
95
  const int i = row * ne0 + i0;
96
-
97
- dst[i + 0] = x[i + 0];
98
- dst[i + 1] = x[i + 1];
99
-
100
  return;
101
  }
102
 
@@ -122,6 +116,63 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
122
  dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
123
  }
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  template <typename T, bool has_ff>
126
  static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
127
  const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
@@ -171,7 +222,7 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
171
  const float * freq_factors, queue_ptr stream) {
172
  GGML_ASSERT(ne0 % 2 == 0);
173
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
174
- const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
175
  const sycl::range<3> block_nums(1, num_blocks_x, nr);
176
 
177
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
@@ -208,7 +259,7 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
208
  const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
209
  GGML_ASSERT(ne0 % 2 == 0);
210
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
211
- const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
212
  const sycl::range<3> block_nums(1, num_blocks_x, nr);
213
 
214
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
@@ -228,6 +279,40 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
228
  }
229
  }
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  // rope vision
232
  template <typename T>
233
  static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
@@ -237,7 +322,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
237
  const mrope_sections sections, queue_ptr stream) {
238
  GGML_ASSERT(ne0 % 2 == 0);
239
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
240
- const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
241
  const sycl::range<3> grid_dims(1, n_blocks_y, nr);
242
  const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
243
 
@@ -298,8 +383,17 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
298
  memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
299
 
300
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
301
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
302
 
 
 
 
 
 
 
 
 
303
  const int32_t * pos = (const int32_t *) dst->src[1]->data;
304
 
305
  const float * freq_factors = nullptr;
@@ -326,6 +420,19 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
326
  } else {
327
  GGML_ABORT("fatal error");
328
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  } else if (is_vision) {
330
  GGML_SYCL_DEBUG("%s: vision path\n", __func__);
331
  if (dst->src[0]->type == GGML_TYPE_F16) {
 
49
 
50
  if (i0 >= n_dims) {
51
  const int i = row * ne0 + i0;
52
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
 
 
 
53
  return;
54
  }
55
 
 
90
 
91
  if (i0 >= n_dims) {
92
  const int i = row * ne0 + i0;
93
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
 
 
 
94
  return;
95
  }
96
 
 
116
  dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
117
  }
118
 
119
+ template <typename T, bool has_ff>
120
+ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
121
+ const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
122
+ const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
123
+ const float theta_scale, const float * freq_factors, const mrope_sections sections,
124
+ const sycl::nd_item<3> & item_ct1) {
125
+ // get index pos
126
+ const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
127
+ if (i0 >= ne0) {
128
+ return;
129
+ }
130
+ const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
131
+
132
+ if (i0 >= n_dims) {
133
+ const int i = row_dst*ne0 + i0;
134
+ *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
135
+ return;
136
+ }
137
+
138
+ const int row_x = row_dst % ne1;
139
+ const int channel_x = row_dst / ne1;
140
+ const int idst = (row_dst * ne0) + (i0 / 2);
141
+ const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
142
+
143
+ const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
144
+ const int sec_w = sections.v[1] + sections.v[0];
145
+ const int sector = (i0 / 2) % sect_dims;
146
+
147
+
148
+ float theta_base = 0.0;
149
+ if (sector < sections.v[0]) {
150
+ theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
151
+ }
152
+ else if (sector >= sections.v[0] && sector < sec_w) {
153
+ theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
154
+ }
155
+ else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
156
+ theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
157
+ }
158
+ else if (sector >= sec_w + sections.v[2]) {
159
+ theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
160
+ }
161
+
162
+ const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
163
+ float cos_theta;
164
+ float sin_theta;
165
+ rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
166
+ const float x0 = x[ix + 0];
167
+ const float x1 = x[ix + n_dims/2];
168
+
169
+ // store results in dst
170
+ dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
171
+ dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
172
+ }
173
+
174
+
175
+
176
  template <typename T, bool has_ff>
177
  static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
178
  const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
 
222
  const float * freq_factors, queue_ptr stream) {
223
  GGML_ASSERT(ne0 % 2 == 0);
224
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
225
+ const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
226
  const sycl::range<3> block_nums(1, num_blocks_x, nr);
227
 
228
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
259
  const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
260
  GGML_ASSERT(ne0 % 2 == 0);
261
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
262
+ const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
263
  const sycl::range<3> block_nums(1, num_blocks_x, nr);
264
 
265
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
279
  }
280
  }
281
 
282
+ template <typename T>
283
+ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
284
+ const size_t s2, const int n_dims, const int nr, const int32_t * pos,
285
+ const float freq_scale, const float freq_base, const float ext_factor,
286
+ const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
287
+ const mrope_sections sections, queue_ptr stream) {
288
+ GGML_ASSERT(ne0 % 2 == 0);
289
+ const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
290
+ const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
291
+ const sycl::range<3> grid_dims(1, n_blocks_y, nr);
292
+ const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
293
+
294
+ const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
295
+ // Add FP16 capability check if T could be sycl::half
296
+ if constexpr (std::is_same_v<T, sycl::half>) {
297
+ dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
298
+ }
299
+ // launch kernel
300
+ if (freq_factors == nullptr) {
301
+ stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
302
+ rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
303
+ corr_dims, theta_scale, freq_factors, sections, item_ct1);
304
+ });
305
+ } else {
306
+ stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
307
+ rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
308
+ corr_dims, theta_scale, freq_factors, sections, item_ct1);
309
+ });
310
+ }
311
+ }
312
+
313
+
314
+
315
+
316
  // rope vision
317
  template <typename T>
318
  static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
 
322
  const mrope_sections sections, queue_ptr stream) {
323
  GGML_ASSERT(ne0 % 2 == 0);
324
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
325
+ const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
326
  const sycl::range<3> grid_dims(1, n_blocks_y, nr);
327
  const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
328
 
 
383
  memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
384
 
385
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
386
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
387
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
388
 
389
+ if (is_mrope) {
390
+ GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
391
+ }
392
+
393
+ if (is_vision) {
394
+ GGML_ASSERT(n_dims == ne00/2);
395
+ }
396
+
397
  const int32_t * pos = (const int32_t *) dst->src[1]->data;
398
 
399
  const float * freq_factors = nullptr;
 
420
  } else {
421
  GGML_ABORT("fatal error");
422
  }
423
+ } else if (is_mrope && !is_vision) {
424
+ GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
425
+ if (dst->src[0]->type == GGML_TYPE_F16) {
426
+ rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
427
+ s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
428
+ freq_factors, sections, main_stream);
429
+ } else if (dst->src[0]->type == GGML_TYPE_F32) {
430
+ rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
431
+ nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
432
+ main_stream);
433
+ } else {
434
+ GGML_ABORT("Fatal error: Tensor type unsupported!");
435
+ }
436
  } else if (is_vision) {
437
  GGML_SYCL_DEBUG("%s: vision path\n", __func__);
438
  if (dst->src[0]->type == GGML_TYPE_F16) {