liuwei-git ggerganov committed on
Commit
ef68527
·
1 Parent(s): 8d153a7

llama : add phi3 128K model support (llama/7225)

Browse files

* add phi3 128k support in convert-hf-to-gguf

* add phi3 128k support in cuda

* address build warnings on llama.cpp

* adjust index value in cuda long rope freq factors

* add long rope support in ggml cpu backend

* make freq factors only depend on ctx size

* remove unused rope scaling type 'su' from gguf converter

* fix lint warnings on convert-hf-to-gguf.py

* set to the short freq factor when context size is smaller than trained context size

* add one line of comments

* metal : support rope freq_factors

* ggml : update ggml_rope_ext API to support freq. factors

* backends : add dev messages to support rope freq. factors

* minor : style

* tests : update to use new rope API

* backends : fix pragma semicolons

* minor : cleanup

* llama : move rope factors from KV header to tensors

* llama : remove tmp assert

* cuda : fix compile warning

* convert : read/write n_head_kv

* llama : fix uninitialized tensors

---------

Co-authored-by: Georgi Gerganov <[email protected]>

Files changed (8) hide show
  1. ggml-cuda/rope.cu +48 -24
  2. ggml-kompute.cpp +4 -0
  3. ggml-metal.m +68 -53
  4. ggml-metal.metal +5 -1
  5. ggml-sycl.cpp +3 -0
  6. ggml-vulkan.cpp +4 -0
  7. ggml.c +68 -12
  8. ggml.h +36 -9
ggml-cuda/rope.cu CHANGED
@@ -58,10 +58,10 @@ static __global__ void rope(
58
  dst[i + 1] = x0*sin_theta + x1*cos_theta;
59
  }
60
 
61
- template<typename T, bool has_pos>
62
  static __global__ void rope_neox(
63
  const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
64
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
65
  ) {
66
  const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
67
 
@@ -88,7 +88,9 @@ static __global__ void rope_neox(
88
  float cur_rot = inv_ndims * ic - ib;
89
 
90
  const int p = has_pos ? pos[i2] : 0;
91
- const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
 
 
92
 
93
  float cos_theta, sin_theta;
94
  rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -164,7 +166,7 @@ static void rope_cuda(
164
  template<typename T>
165
  static void rope_neox_cuda(
166
  const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
167
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
168
  ) {
169
  GGML_ASSERT(ncols % 2 == 0);
170
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
@@ -175,15 +177,29 @@ static void rope_neox_cuda(
175
  const float inv_ndims = -1.0f / n_dims;
176
 
177
  if (pos == nullptr) {
178
- rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
179
- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
180
- theta_scale, inv_ndims
181
- );
 
 
 
 
 
 
 
182
  } else {
183
- rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
184
- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
185
- theta_scale, inv_ndims
186
- );
 
 
 
 
 
 
 
187
  }
188
  }
189
 
@@ -214,24 +230,27 @@ static void rope_cuda_f32(
214
 
215
  static void rope_neox_cuda_f16(
216
  const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
217
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
218
 
219
- rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
220
  }
221
 
222
  static void rope_neox_cuda_f32(
223
  const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
224
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
225
  ) {
226
 
227
- rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
228
  }
229
 
230
  void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
231
  const ggml_tensor * src0 = dst->src[0];
232
  const ggml_tensor * src1 = dst->src[1];
 
 
233
  const float * src0_d = (const float *)src0->data;
234
  const float * src1_d = (const float *)src1->data;
 
235
  float * dst_d = (float *)dst->data;
236
  cudaStream_t stream = ctx.stream();
237
 
@@ -241,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
241
 
242
  const int64_t ne00 = src0->ne[0];
243
  const int64_t ne01 = src0->ne[1];
244
- const int64_t ne2 = dst->ne[2];
245
  const int64_t nrows = ggml_nrows(src0);
246
 
247
  //const int n_past = ((int32_t *) dst->op_params)[0];
@@ -259,16 +277,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
259
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
260
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
261
 
 
262
  const int32_t * pos = nullptr;
263
- if ((mode & 1) == 0) {
264
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
265
- GGML_ASSERT(src1->ne[0] == ne2);
266
- pos = (const int32_t *) src1_d;
267
- }
268
 
269
  const bool is_neox = mode & 2;
270
  const bool is_glm = mode & 4;
271
 
 
 
 
 
 
 
 
 
 
 
272
  rope_corr_dims corr_dims;
273
  ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
274
 
@@ -280,12 +304,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
280
  if (src0->type == GGML_TYPE_F32) {
281
  rope_neox_cuda_f32(
282
  (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
283
- attn_factor, corr_dims, stream
284
  );
285
  } else if (src0->type == GGML_TYPE_F16) {
286
  rope_neox_cuda_f16(
287
  (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
288
- attn_factor, corr_dims, stream
289
  );
290
  } else {
291
  GGML_ASSERT(false);
 
58
  dst[i + 1] = x0*sin_theta + x1*cos_theta;
59
  }
60
 
61
+ template<typename T, bool has_pos, bool has_freq_facs>
62
  static __global__ void rope_neox(
63
  const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
64
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
65
  ) {
66
  const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
67
 
 
88
  float cur_rot = inv_ndims * ic - ib;
89
 
90
  const int p = has_pos ? pos[i2] : 0;
91
+ const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
92
+
93
+ const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
94
 
95
  float cos_theta, sin_theta;
96
  rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
166
  template<typename T>
167
  static void rope_neox_cuda(
168
  const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
169
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
170
  ) {
171
  GGML_ASSERT(ncols % 2 == 0);
172
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
 
177
  const float inv_ndims = -1.0f / n_dims;
178
 
179
  if (pos == nullptr) {
180
+ if (freq_factors == nullptr) {
181
+ rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
182
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
183
+ theta_scale, inv_ndims, freq_factors
184
+ );
185
+ } else {
186
+ rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
187
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
188
+ theta_scale, inv_ndims, freq_factors
189
+ );
190
+ }
191
  } else {
192
+ if (freq_factors == nullptr) {
193
+ rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
194
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
195
+ theta_scale, inv_ndims, freq_factors
196
+ );
197
+ } else {
198
+ rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
199
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
200
+ theta_scale, inv_ndims, freq_factors
201
+ );
202
+ }
203
  }
204
  }
205
 
 
230
 
231
  static void rope_neox_cuda_f16(
232
  const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
233
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
234
 
235
+ rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
236
  }
237
 
238
  static void rope_neox_cuda_f32(
239
  const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
240
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
241
  ) {
242
 
243
+ rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
244
  }
245
 
246
  void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
247
  const ggml_tensor * src0 = dst->src[0];
248
  const ggml_tensor * src1 = dst->src[1];
249
+ const ggml_tensor * src2 = dst->src[2];
250
+
251
  const float * src0_d = (const float *)src0->data;
252
  const float * src1_d = (const float *)src1->data;
253
+
254
  float * dst_d = (float *)dst->data;
255
  cudaStream_t stream = ctx.stream();
256
 
 
260
 
261
  const int64_t ne00 = src0->ne[0];
262
  const int64_t ne01 = src0->ne[1];
 
263
  const int64_t nrows = ggml_nrows(src0);
264
 
265
  //const int n_past = ((int32_t *) dst->op_params)[0];
 
277
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
278
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
279
 
280
+ const float * freq_factors = nullptr;
281
  const int32_t * pos = nullptr;
 
 
 
 
 
282
 
283
  const bool is_neox = mode & 2;
284
  const bool is_glm = mode & 4;
285
 
286
+ if (is_neox) {
287
+ pos = (const int32_t *) src1_d;
288
+
289
+ if (src2 != nullptr) {
290
+ freq_factors = (const float *) src2->data;
291
+ }
292
+ } else {
293
+ GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
294
+ }
295
+
296
  rope_corr_dims corr_dims;
297
  ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
298
 
 
304
  if (src0->type == GGML_TYPE_F32) {
305
  rope_neox_cuda_f32(
306
  (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
307
+ attn_factor, corr_dims, freq_factors, stream
308
  );
309
  } else if (src0->type == GGML_TYPE_F16) {
310
  rope_neox_cuda_f16(
311
  (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
312
+ attn_factor, corr_dims, freq_factors, stream
313
  );
314
  } else {
315
  GGML_ASSERT(false);
ggml-kompute.cpp CHANGED
@@ -1677,6 +1677,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
1677
  } break;
1678
  case GGML_OP_ROPE:
1679
  {
 
 
 
 
1680
  GGML_ASSERT(ne10 == ne02);
1681
  GGML_ASSERT(src0t == dstt);
1682
  // const int n_past = ((int32_t *) dst->op_params)[0];
 
1677
  } break;
1678
  case GGML_OP_ROPE:
1679
  {
1680
+ #pragma message("TODO: implement phi3 frequency factors support")
1681
+ #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
1682
+ GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
1683
+
1684
  GGML_ASSERT(ne10 == ne02);
1685
  GGML_ASSERT(src0t == dstt);
1686
  // const int n_past = ((int32_t *) dst->op_params)[0];
ggml-metal.m CHANGED
@@ -927,22 +927,32 @@ static enum ggml_status ggml_metal_graph_compute(
927
  const int64_t ne10 = src1 ? src1->ne[0] : 0;
928
  const int64_t ne11 = src1 ? src1->ne[1] : 0;
929
  const int64_t ne12 = src1 ? src1->ne[2] : 0;
930
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
931
 
932
  const uint64_t nb10 = src1 ? src1->nb[0] : 0;
933
  const uint64_t nb11 = src1 ? src1->nb[1] : 0;
934
  const uint64_t nb12 = src1 ? src1->nb[2] : 0;
935
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
936
 
937
- const int64_t ne0 = dst ? dst->ne[0] : 0;
938
- const int64_t ne1 = dst ? dst->ne[1] : 0;
939
- const int64_t ne2 = dst ? dst->ne[2] : 0;
940
- const int64_t ne3 = dst ? dst->ne[3] : 0;
941
 
942
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
943
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
944
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
945
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
 
 
 
 
 
 
 
 
 
 
946
 
947
  const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
948
  const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
@@ -1785,16 +1795,6 @@ static enum ggml_status ggml_metal_graph_compute(
1785
  const int n_as = src0->ne[2];
1786
 
1787
  // src2 = ids
1788
- const int64_t ne20 = src2->ne[0];
1789
- const int64_t ne21 = src2->ne[1];
1790
- const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1791
- const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
1792
-
1793
- const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
1794
- const uint64_t nb21 = src2->nb[1];
1795
- const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
1796
- const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
1797
-
1798
  const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
1799
 
1800
  GGML_ASSERT(src2t == GGML_TYPE_I32);
@@ -2244,7 +2244,13 @@ static enum ggml_status ggml_metal_graph_compute(
2244
  // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2245
  const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2246
 
2247
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
 
 
 
 
 
2248
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2249
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2250
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -2252,6 +2258,15 @@ static enum ggml_status ggml_metal_graph_compute(
2252
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2253
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2254
 
 
 
 
 
 
 
 
 
 
2255
  id<MTLComputePipelineState> pipeline = nil;
2256
 
2257
  switch (src0->type) {
@@ -2263,33 +2278,38 @@ static enum ggml_status ggml_metal_graph_compute(
2263
  [encoder setComputePipelineState:pipeline];
2264
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2265
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2266
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2267
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2268
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
2269
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
2270
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
2271
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
2272
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2273
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2274
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2275
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
2276
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
2277
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
2278
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
2279
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
2280
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
2281
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
2282
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
2283
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
2284
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
2285
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
2286
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
2287
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2288
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2289
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2290
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2291
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2292
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
 
 
 
 
 
2293
 
2294
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2295
  } break;
@@ -2535,11 +2555,6 @@ static enum ggml_status ggml_metal_graph_compute(
2535
  GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
2536
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2537
 
2538
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
2539
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
2540
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
2541
- const uint64_t nb23 = src2 ? src2->nb[3] : 0;
2542
-
2543
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2544
  //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2545
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
 
927
  const int64_t ne10 = src1 ? src1->ne[0] : 0;
928
  const int64_t ne11 = src1 ? src1->ne[1] : 0;
929
  const int64_t ne12 = src1 ? src1->ne[2] : 0;
930
+ const int64_t ne13 = src1 ? src1->ne[3] : 0;
931
 
932
  const uint64_t nb10 = src1 ? src1->nb[0] : 0;
933
  const uint64_t nb11 = src1 ? src1->nb[1] : 0;
934
  const uint64_t nb12 = src1 ? src1->nb[2] : 0;
935
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0;
936
 
937
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
938
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
939
+ const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
940
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
941
 
942
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
943
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
944
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
945
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
946
+
947
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
948
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
949
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
950
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
951
+
952
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
953
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
954
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
955
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
956
 
957
  const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
958
  const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
 
1795
  const int n_as = src0->ne[2];
1796
 
1797
  // src2 = ids
 
 
 
 
 
 
 
 
 
 
1798
  const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
1799
 
1800
  GGML_ASSERT(src2t == GGML_TYPE_I32);
 
2244
  // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2245
  const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2246
 
2247
+ float freq_base;
2248
+ float freq_scale;
2249
+ float ext_factor;
2250
+ float attn_factor;
2251
+ float beta_fast;
2252
+ float beta_slow;
2253
+
2254
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2255
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2256
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
 
2258
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2259
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2260
 
2261
+ const bool is_neox = mode & 2;
2262
+ const bool is_glm = mode & 4;
2263
+
2264
+ GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
2265
+
2266
+ if (!is_neox) {
2267
+ GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
2268
+ }
2269
+
2270
  id<MTLComputePipelineState> pipeline = nil;
2271
 
2272
  switch (src0->type) {
 
2278
  [encoder setComputePipelineState:pipeline];
2279
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2280
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2281
+ if (id_src2 != nil) {
2282
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2283
+ } else {
2284
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
2285
+ }
2286
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2287
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
2288
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2289
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2290
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2291
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
2292
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
2293
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
2294
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
2295
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
2296
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
2297
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
2298
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
2299
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
2300
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
2301
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
2302
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2303
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2304
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2305
+ [encoder setBytes:&mode length:sizeof( int) atIndex:22];
2306
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
2307
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
2308
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
2309
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
2310
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
2311
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
2312
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
2313
 
2314
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2315
  } break;
 
2555
  GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
2556
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2557
 
 
 
 
 
 
2558
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2559
  //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2560
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
ggml-metal.metal CHANGED
@@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims(
1640
  typedef void (rope_t)(
1641
  device const void * src0,
1642
  device const int32_t * src1,
 
1643
  device float * dst,
1644
  constant int64_t & ne00,
1645
  constant int64_t & ne01,
@@ -1675,6 +1676,7 @@ template<typename T>
1675
  kernel void kernel_rope(
1676
  device const void * src0,
1677
  device const int32_t * src1,
 
1678
  device float * dst,
1679
  constant int64_t & ne00,
1680
  constant int64_t & ne01,
@@ -1744,8 +1746,10 @@ kernel void kernel_rope(
1744
 
1745
  // simplified from `(ib * n_dims + ic) * inv_ndims`
1746
  const float cur_rot = inv_ndims*ic - ib;
 
 
 
1747
 
1748
- const float theta = theta_0 * pow(freq_base, cur_rot);
1749
  float cos_theta, sin_theta;
1750
  rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1751
 
 
1640
  typedef void (rope_t)(
1641
  device const void * src0,
1642
  device const int32_t * src1,
1643
+ device const float * src2,
1644
  device float * dst,
1645
  constant int64_t & ne00,
1646
  constant int64_t & ne01,
 
1676
  kernel void kernel_rope(
1677
  device const void * src0,
1678
  device const int32_t * src1,
1679
+ device const float * src2,
1680
  device float * dst,
1681
  constant int64_t & ne00,
1682
  constant int64_t & ne01,
 
1746
 
1747
  // simplified from `(ib * n_dims + ic) * inv_ndims`
1748
  const float cur_rot = inv_ndims*ic - ib;
1749
+ const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
1750
+
1751
+ const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
1752
 
 
1753
  float cos_theta, sin_theta;
1754
  rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1755
 
ggml-sycl.cpp CHANGED
@@ -14454,6 +14454,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14454
  ggml_tensor *dst, const float *src0_dd,
14455
  const float *src1_dd, float *dst_dd,
14456
  const dpct::queue_ptr &main_stream) {
 
 
 
14457
 
14458
  GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
14459
  GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
 
14454
  ggml_tensor *dst, const float *src0_dd,
14455
  const float *src1_dd, float *dst_dd,
14456
  const dpct::queue_ptr &main_stream) {
14457
+ #pragma message("TODO: implement phi3 frequency factors support")
14458
+ #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
14459
+ GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
14460
 
14461
  GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
14462
  GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
ggml-vulkan.cpp CHANGED
@@ -4238,6 +4238,10 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
4238
  }
4239
 
4240
  static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
 
 
 
4241
  const int n_dims = ((int32_t *) dst->op_params)[1];
4242
  const int mode = ((int32_t *) dst->op_params)[2];
4243
  // const int n_ctx = ((int32_t *) dst->op_params)[3];
 
4238
  }
4239
 
4240
  static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4241
+ #pragma message("TODO: implement phi3 frequency factors support")
4242
+ #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
4243
+ GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
4244
+
4245
  const int n_dims = ((int32_t *) dst->op_params)[1];
4246
  const int mode = ((int32_t *) dst->op_params)[2];
4247
  // const int n_ctx = ((int32_t *) dst->op_params)[3];
ggml.c CHANGED
@@ -6231,6 +6231,7 @@ static struct ggml_tensor * ggml_rope_impl(
6231
  struct ggml_context * ctx,
6232
  struct ggml_tensor * a,
6233
  struct ggml_tensor * b,
 
6234
  int n_dims,
6235
  int mode,
6236
  int n_ctx,
@@ -6248,6 +6249,11 @@ static struct ggml_tensor * ggml_rope_impl(
6248
  GGML_ASSERT(b->type == GGML_TYPE_I32);
6249
  GGML_ASSERT(a->ne[2] == b->ne[0]);
6250
 
 
 
 
 
 
6251
  bool is_node = false;
6252
 
6253
  if (a->grad) {
@@ -6271,6 +6277,7 @@ static struct ggml_tensor * ggml_rope_impl(
6271
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6272
  result->src[0] = a;
6273
  result->src[1] = b;
 
6274
 
6275
  return result;
6276
  }
@@ -6283,7 +6290,7 @@ struct ggml_tensor * ggml_rope(
6283
  int mode,
6284
  int n_ctx) {
6285
  return ggml_rope_impl(
6286
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6287
  );
6288
  }
6289
 
@@ -6295,14 +6302,15 @@ struct ggml_tensor * ggml_rope_inplace(
6295
  int mode,
6296
  int n_ctx) {
6297
  return ggml_rope_impl(
6298
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
6299
  );
6300
  }
6301
 
6302
- struct ggml_tensor * ggml_rope_custom(
6303
  struct ggml_context * ctx,
6304
  struct ggml_tensor * a,
6305
  struct ggml_tensor * b,
 
6306
  int n_dims,
6307
  int mode,
6308
  int n_ctx,
@@ -6314,15 +6322,16 @@ struct ggml_tensor * ggml_rope_custom(
6314
  float beta_fast,
6315
  float beta_slow) {
6316
  return ggml_rope_impl(
6317
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6318
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6319
  );
6320
  }
6321
 
6322
- struct ggml_tensor * ggml_rope_custom_inplace(
6323
  struct ggml_context * ctx,
6324
  struct ggml_tensor * a,
6325
  struct ggml_tensor * b,
 
6326
  int n_dims,
6327
  int mode,
6328
  int n_ctx,
@@ -6334,19 +6343,49 @@ struct ggml_tensor * ggml_rope_custom_inplace(
6334
  float beta_fast,
6335
  float beta_slow) {
6336
  return ggml_rope_impl(
6337
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6338
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6339
  );
6340
  }
6341
 
6342
- struct ggml_tensor * ggml_rope_xpos_inplace(
6343
  struct ggml_context * ctx,
6344
  struct ggml_tensor * a,
6345
  struct ggml_tensor * b,
6346
  int n_dims,
6347
- float base,
6348
- bool down) {
6349
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6350
  }
6351
 
6352
  // ggml_rope_back
@@ -6355,6 +6394,7 @@ struct ggml_tensor * ggml_rope_back(
6355
  struct ggml_context * ctx,
6356
  struct ggml_tensor * a,
6357
  struct ggml_tensor * b,
 
6358
  int n_dims,
6359
  int mode,
6360
  int n_ctx,
@@ -6370,6 +6410,7 @@ struct ggml_tensor * ggml_rope_back(
6370
  GGML_ASSERT(ggml_is_vector(b));
6371
  GGML_ASSERT(b->type == GGML_TYPE_I32);
6372
  GGML_ASSERT(a->ne[2] == b->ne[0]);
 
6373
 
6374
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
6375
 
@@ -14304,6 +14345,7 @@ static void ggml_compute_forward_rope_f32(
14304
 
14305
  const struct ggml_tensor * src0 = dst->src[0];
14306
  const struct ggml_tensor * src1 = dst->src[1];
 
14307
 
14308
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
14309
  return;
@@ -14363,6 +14405,17 @@ static void ggml_compute_forward_rope_f32(
14363
  const bool is_neox = mode & 2;
14364
  const bool is_glm = mode & 4;
14365
 
 
 
 
 
 
 
 
 
 
 
 
14366
  // backward process uses inverse rotation by cos and sin.
14367
  // cos and sin build a rotation matrix, where the inverse is the transpose.
14368
  // this essentially just switches the sign of sin.
@@ -14439,10 +14492,11 @@ static void ggml_compute_forward_rope_f32(
14439
 
14440
  // simplified from `(ib * n_dims + ic) * inv_ndims`
14441
  float cur_rot = inv_ndims * ic - ib;
 
14442
 
14443
  float cos_theta, sin_theta;
14444
  rope_yarn(
14445
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14446
  &cos_theta, &sin_theta
14447
  );
14448
  sin_theta *= sin_sign;
@@ -18387,6 +18441,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
18387
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
18388
  struct ggml_tensor * src0 = tensor->src[0];
18389
  struct ggml_tensor * src1 = tensor->src[1];
 
18390
 
18391
  switch (tensor->op) {
18392
  case GGML_OP_DUP:
@@ -18918,6 +18973,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18918
  ggml_rope_back(ctx,
18919
  tensor->grad,
18920
  src1,
 
18921
  n_dims,
18922
  mode,
18923
  n_ctx,
@@ -18957,6 +19013,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18957
  ggml_rope_impl(ctx,
18958
  tensor->grad,
18959
  src1,
 
18960
  n_dims,
18961
  mode,
18962
  n_ctx,
@@ -19038,7 +19095,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
19038
  masked);
19039
  }
19040
 
19041
- struct ggml_tensor * src2 = tensor->src[2];
19042
  const int64_t elem_q = ggml_nelements(src0);
19043
  const int64_t elem_k = ggml_nelements(src1);
19044
  const int64_t elem_v = ggml_nelements(src2);
 
6231
  struct ggml_context * ctx,
6232
  struct ggml_tensor * a,
6233
  struct ggml_tensor * b,
6234
+ struct ggml_tensor * c,
6235
  int n_dims,
6236
  int mode,
6237
  int n_ctx,
 
6249
  GGML_ASSERT(b->type == GGML_TYPE_I32);
6250
  GGML_ASSERT(a->ne[2] == b->ne[0]);
6251
 
6252
+ if (c) {
6253
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
6254
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
6255
+ }
6256
+
6257
  bool is_node = false;
6258
 
6259
  if (a->grad) {
 
6277
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6278
  result->src[0] = a;
6279
  result->src[1] = b;
6280
+ result->src[2] = c;
6281
 
6282
  return result;
6283
  }
 
6290
  int mode,
6291
  int n_ctx) {
6292
  return ggml_rope_impl(
6293
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6294
  );
6295
  }
6296
 
 
6302
  int mode,
6303
  int n_ctx) {
6304
  return ggml_rope_impl(
6305
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
6306
  );
6307
  }
6308
 
6309
+ struct ggml_tensor * ggml_rope_ext(
6310
  struct ggml_context * ctx,
6311
  struct ggml_tensor * a,
6312
  struct ggml_tensor * b,
6313
+ struct ggml_tensor * c,
6314
  int n_dims,
6315
  int mode,
6316
  int n_ctx,
 
6322
  float beta_fast,
6323
  float beta_slow) {
6324
  return ggml_rope_impl(
6325
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6326
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6327
  );
6328
  }
6329
 
6330
+ struct ggml_tensor * ggml_rope_ext_inplace(
6331
  struct ggml_context * ctx,
6332
  struct ggml_tensor * a,
6333
  struct ggml_tensor * b,
6334
+ struct ggml_tensor * c,
6335
  int n_dims,
6336
  int mode,
6337
  int n_ctx,
 
6343
  float beta_fast,
6344
  float beta_slow) {
6345
  return ggml_rope_impl(
6346
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6347
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6348
  );
6349
  }
6350
 
6351
+ struct ggml_tensor * ggml_rope_custom(
6352
  struct ggml_context * ctx,
6353
  struct ggml_tensor * a,
6354
  struct ggml_tensor * b,
6355
  int n_dims,
6356
+ int mode,
6357
+ int n_ctx,
6358
+ int n_orig_ctx,
6359
+ float freq_base,
6360
+ float freq_scale,
6361
+ float ext_factor,
6362
+ float attn_factor,
6363
+ float beta_fast,
6364
+ float beta_slow) {
6365
+ return ggml_rope_impl(
6366
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6367
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6368
+ );
6369
+ }
6370
+
6371
+ struct ggml_tensor * ggml_rope_custom_inplace(
6372
+ struct ggml_context * ctx,
6373
+ struct ggml_tensor * a,
6374
+ struct ggml_tensor * b,
6375
+ int n_dims,
6376
+ int mode,
6377
+ int n_ctx,
6378
+ int n_orig_ctx,
6379
+ float freq_base,
6380
+ float freq_scale,
6381
+ float ext_factor,
6382
+ float attn_factor,
6383
+ float beta_fast,
6384
+ float beta_slow) {
6385
+ return ggml_rope_impl(
6386
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6387
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6388
+ );
6389
  }
6390
 
6391
  // ggml_rope_back
 
6394
  struct ggml_context * ctx,
6395
  struct ggml_tensor * a,
6396
  struct ggml_tensor * b,
6397
+ struct ggml_tensor * c,
6398
  int n_dims,
6399
  int mode,
6400
  int n_ctx,
 
6410
  GGML_ASSERT(ggml_is_vector(b));
6411
  GGML_ASSERT(b->type == GGML_TYPE_I32);
6412
  GGML_ASSERT(a->ne[2] == b->ne[0]);
6413
+ GGML_ASSERT(c == NULL && "freq factors not implemented yet");
6414
 
6415
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
6416
 
 
14345
 
14346
  const struct ggml_tensor * src0 = dst->src[0];
14347
  const struct ggml_tensor * src1 = dst->src[1];
14348
+ const struct ggml_tensor * src2 = dst->src[2];
14349
 
14350
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
14351
  return;
 
14405
  const bool is_neox = mode & 2;
14406
  const bool is_glm = mode & 4;
14407
 
14408
+ const float * freq_factors = NULL;
14409
+ if (is_neox) {
14410
+ if (src2 != NULL) {
14411
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14412
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14413
+ freq_factors = (const float *) src2->data;
14414
+ }
14415
+ } else {
14416
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
14417
+ }
14418
+
14419
  // backward process uses inverse rotation by cos and sin.
14420
  // cos and sin build a rotation matrix, where the inverse is the transpose.
14421
  // this essentially just switches the sign of sin.
 
14492
 
14493
  // simplified from `(ib * n_dims + ic) * inv_ndims`
14494
  float cur_rot = inv_ndims * ic - ib;
14495
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14496
 
14497
  float cos_theta, sin_theta;
14498
  rope_yarn(
14499
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14500
  &cos_theta, &sin_theta
14501
  );
14502
  sin_theta *= sin_sign;
 
18441
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
18442
  struct ggml_tensor * src0 = tensor->src[0];
18443
  struct ggml_tensor * src1 = tensor->src[1];
18444
+ struct ggml_tensor * src2 = tensor->src[2];
18445
 
18446
  switch (tensor->op) {
18447
  case GGML_OP_DUP:
 
18973
  ggml_rope_back(ctx,
18974
  tensor->grad,
18975
  src1,
18976
+ src2,
18977
  n_dims,
18978
  mode,
18979
  n_ctx,
 
19013
  ggml_rope_impl(ctx,
19014
  tensor->grad,
19015
  src1,
19016
+ src2,
19017
  n_dims,
19018
  mode,
19019
  n_ctx,
 
19095
  masked);
19096
  }
19097
 
 
19098
  const int64_t elem_q = ggml_nelements(src0);
19099
  const int64_t elem_k = ggml_nelements(src1);
19100
  const int64_t elem_v = ggml_nelements(src2);
ggml.h CHANGED
@@ -1465,6 +1465,7 @@ extern "C" {
1465
  // if mode & 4 == 1, ChatGLM style
1466
  //
1467
  // b is an int32 vector with size a->ne[2], it contains the positions
 
1468
  GGML_API struct ggml_tensor * ggml_rope(
1469
  struct ggml_context * ctx,
1470
  struct ggml_tensor * a,
@@ -1483,10 +1484,11 @@ extern "C" {
1483
  int n_ctx);
1484
 
1485
  // custom RoPE
1486
- GGML_API struct ggml_tensor * ggml_rope_custom(
1487
  struct ggml_context * ctx,
1488
  struct ggml_tensor * a,
1489
  struct ggml_tensor * b,
 
1490
  int n_dims,
1491
  int mode,
1492
  int n_ctx,
@@ -1499,10 +1501,11 @@ extern "C" {
1499
  float beta_slow);
1500
 
1501
  // in-place, returns view(a)
1502
- GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
1503
  struct ggml_context * ctx,
1504
  struct ggml_tensor * a,
1505
  struct ggml_tensor * b,
 
1506
  int n_dims,
1507
  int mode,
1508
  int n_ctx,
@@ -1514,18 +1517,41 @@ extern "C" {
1514
  float beta_fast,
1515
  float beta_slow);
1516
 
1517
- // compute correction dims for YaRN RoPE scaling
1518
- GGML_CALL void ggml_rope_yarn_corr_dims(
1519
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
 
 
 
 
 
 
 
 
 
 
 
1520
 
1521
- // xPos RoPE, in-place, returns view(a)
1522
- GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
1523
  struct ggml_context * ctx,
1524
  struct ggml_tensor * a,
1525
  struct ggml_tensor * b,
1526
  int n_dims,
1527
- float base,
1528
- bool down);
 
 
 
 
 
 
 
 
 
 
 
 
1529
 
1530
  // rotary position embedding backward, i.e compute dx from dy
1531
  // a - dy
@@ -1533,6 +1559,7 @@ extern "C" {
1533
  struct ggml_context * ctx,
1534
  struct ggml_tensor * a,
1535
  struct ggml_tensor * b,
 
1536
  int n_dims,
1537
  int mode,
1538
  int n_ctx,
 
1465
  // if mode & 4 == 1, ChatGLM style
1466
  //
1467
  // b is an int32 vector with size a->ne[2], it contains the positions
1468
+ // c is freq factors (e.g. phi3-128k), (optional)
1469
  GGML_API struct ggml_tensor * ggml_rope(
1470
  struct ggml_context * ctx,
1471
  struct ggml_tensor * a,
 
1484
  int n_ctx);
1485
 
1486
  // custom RoPE
1487
+ GGML_API struct ggml_tensor * ggml_rope_ext(
1488
  struct ggml_context * ctx,
1489
  struct ggml_tensor * a,
1490
  struct ggml_tensor * b,
1491
+ struct ggml_tensor * c,
1492
  int n_dims,
1493
  int mode,
1494
  int n_ctx,
 
1501
  float beta_slow);
1502
 
1503
  // in-place, returns view(a)
1504
+ GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
1505
  struct ggml_context * ctx,
1506
  struct ggml_tensor * a,
1507
  struct ggml_tensor * b,
1508
+ struct ggml_tensor * c,
1509
  int n_dims,
1510
  int mode,
1511
  int n_ctx,
 
1517
  float beta_fast,
1518
  float beta_slow);
1519
 
1520
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
1521
+ struct ggml_context * ctx,
1522
+ struct ggml_tensor * a,
1523
+ struct ggml_tensor * b,
1524
+ int n_dims,
1525
+ int mode,
1526
+ int n_ctx,
1527
+ int n_orig_ctx,
1528
+ float freq_base,
1529
+ float freq_scale,
1530
+ float ext_factor,
1531
+ float attn_factor,
1532
+ float beta_fast,
1533
+ float beta_slow),
1534
+ "use ggml_rope_ext instead");
1535
 
1536
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
 
1537
  struct ggml_context * ctx,
1538
  struct ggml_tensor * a,
1539
  struct ggml_tensor * b,
1540
  int n_dims,
1541
+ int mode,
1542
+ int n_ctx,
1543
+ int n_orig_ctx,
1544
+ float freq_base,
1545
+ float freq_scale,
1546
+ float ext_factor,
1547
+ float attn_factor,
1548
+ float beta_fast,
1549
+ float beta_slow),
1550
+ "use ggml_rope_ext_inplace instead");
1551
+
1552
+ // compute correction dims for YaRN RoPE scaling
1553
+ GGML_CALL void ggml_rope_yarn_corr_dims(
1554
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
1555
 
1556
  // rotary position embedding backward, i.e compute dx from dy
1557
  // a - dy
 
1559
  struct ggml_context * ctx,
1560
  struct ggml_tensor * a,
1561
  struct ggml_tensor * b,
1562
+ struct ggml_tensor * c,
1563
  int n_dims,
1564
  int mode,
1565
  int n_ctx,