ggerganov committed
Commit ded0c68 · 1 Parent(s): 6124287

ggml : refactor rope norm/neox (llama/7634)

* ggml : unify rope norm/neox (CPU)

* ggml : fix compile warning

* ggml : remove GLM rope mode

ggml-ci

* metal : better rope implementation

ggml-ci

* cuda : better rope implementation

ggml-ci

* naming : n_orig_ctx -> n_ctx_orig

ggml-ci

* dev : add reminders to update backends

ggml-ci

* vulkan : fix ggml_rope_ext() usage

* cuda : fix array size + indents

ggml-ci
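
For readers skimming the diffs below: the refactor keeps two RoPE variants but gives them the same structure. Here is a minimal CPU sketch of the two layouts, assuming a single row and ignoring the YaRN extrapolation and frequency factors that the real kernels also handle — "norm" rotates adjacent element pairs, "neox" pairs element i with element i + n_dims/2. This sketch is illustrative only and is not part of the commit.

    #include <cmath>

    static void rope_norm_ref(const float * x, float * dst, int n_dims, int pos, float freq_base) {
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            // theta = pos * freq_base^(-i0/n_dims), same schedule as the kernels' theta_scale
            const float theta = pos*std::pow(freq_base, -(float)i0/n_dims);
            const float c = std::cos(theta);
            const float s = std::sin(theta);
            dst[i0 + 0] = x[i0 + 0]*c - x[i0 + 1]*s;
            dst[i0 + 1] = x[i0 + 0]*s + x[i0 + 1]*c;
        }
    }

    static void rope_neox_ref(const float * x, float * dst, int n_dims, int pos, float freq_base) {
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float theta = pos*std::pow(freq_base, -(float)i0/n_dims);
            const float c = std::cos(theta);
            const float s = std::sin(theta);
            const int i = i0/2; // rotate the pair (i, i + n_dims/2) instead of (i0, i0 + 1)
            const float x0 = x[i + 0];
            const float x1 = x[i + n_dims/2];
            dst[i + 0]        = x0*c - x1*s;
            dst[i + n_dims/2] = x0*s + x1*c;
        }
    }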

Files changed (8)
  1. ggml-cuda/rope.cu +107 -166
  2. ggml-kompute.cpp +8 -5
  3. ggml-metal.m +28 -24
  4. ggml-metal.metal +83 -59
  5. ggml-sycl.cpp +7 -67
  6. ggml-vulkan.cpp +8 -14
  7. ggml.c +97 -233
  8. ggml.h +9 -27
ggml-cuda/rope.cu CHANGED
@@ -1,7 +1,7 @@
 #include "rope.cuh"
 
 struct rope_corr_dims {
-    float v[4];
+    float v[2];
 };
 
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static __device__ void rope_yarn(
     float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
+    float * cos_theta, float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -29,27 +28,38 @@ static __device__ void rope_yarn(
     *sin_theta = sinf(theta) * mscale;
 }
 
-// rope == RoPE == rotary positional embedding
-template<typename T, bool has_pos>
-static __global__ void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
-) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+template<typename T, bool has_ff>
+static __global__ void rope_norm(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i = row*ncols + col;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i = row*ne0 + i0;
     const int i2 = row/p_delta_rows;
 
-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, -float(col)/ncols);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + 1];
@@ -58,23 +68,20 @@ static __global__ void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_pos, bool has_freq_facs>
+template<typename T, bool has_ff>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
-) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;
 
-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -82,16 +89,17 @@ static __global__ void rope_neox(
         return;
     }
 
-    const int i = row*ncols + ib*n_dims + ic/2;
+    const int i = row*ne0 + i0/2;
     const int i2 = row/p_delta_rows;
 
-    const int p = has_pos ? pos[i2] : 0;
-    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
 
-    const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];
@@ -100,144 +108,81 @@ static __global__ void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
-    // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
-    const float sin_theta = sinf(theta);
-    const float cos_theta = cosf(theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
-    const float sin_block_theta = sinf(block_theta);
-    const float cos_block_theta = cosf(block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
-
 template<typename T>
-static void rope_cuda(
-    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 2 == 0);
+static void rope_norm_cuda(
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
-    if (pos == nullptr) {
-        rope<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
     } else {
-        rope<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
    }
 }
 
 template<typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 2 == 0);
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
-    if (pos == nullptr) {
-        if (freq_factors == nullptr) {
-            rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-            );
-        } else {
-            rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-            );
-        }
+    if (freq_factors == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
     } else {
-        if (freq_factors == nullptr) {
-            rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-            );
-        } else {
-            rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-            );
-        }
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
     }
 }
 
-static void rope_glm_f32_cuda(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, int n_ctx, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
-    const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
-    const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
-}
-
-static void rope_cuda_f16(
-    const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f16(
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_cuda<half>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
-static void rope_cuda_f32(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f32(
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_cuda<float>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f16(
-    const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f32(
-    const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
 ) {
 
-    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+    rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -258,16 +203,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nr = ggml_nrows(src0);
 
-    //const int n_past     = ((int32_t *) dst->op_params)[0];
-    const int n_dims     = ((int32_t *) dst->op_params)[1];
-    const int mode       = ((int32_t *) dst->op_params)[2];
-    const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
     memcpy(&freq_base,  (int32_t *) dst->op_params + 5, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
     memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -275,38 +226,28 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 
-    const float * freq_factors = nullptr;
-    const int32_t * pos = nullptr;
-
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
 
-    pos = (const int32_t *) src1_d;
+    const int32_t * pos = (const int32_t *) src1_d;
 
-    if (is_neox) {
-        if (src2 != nullptr) {
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+    const float * freq_factors = nullptr;
+    if (src2 != nullptr) {
+        freq_factors = (const float *) src2->data;
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream);
-    } else if (is_neox) {
+    if (is_neox) {
        if (src0->type == GGML_TYPE_F32) {
            rope_neox_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
            rope_neox_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                attn_factor, corr_dims, freq_factors, stream
            );
        } else {
@@ -314,14 +255,14 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
        }
    } else {
        if (src0->type == GGML_TYPE_F32) {
-            rope_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
            );
        } else if (src0->type == GGML_TYPE_F16) {
-            rope_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
            );
        } else {
            GGML_ASSERT(false);
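
The host wrappers above resolve the nullable freq_factors pointer into a template argument once per launch, so the per-element kernels test a compile-time constant instead of a runtime pointer. A plain C++ sketch of the same dispatch pattern, with hypothetical simplified signatures (illustrative only — it mirrors the rope_norm_cuda / rope_neox_cuda structure, not their actual math):

    template <bool has_ff>
    static void apply_rope(const float * x, float * dst, int n, const float * freq_factors) {
        for (int i = 0; i < n; i += 2) {
            // when has_ff == false the branch is a compile-time constant and folds away
            const float ff = has_ff ? freq_factors[i/2] : 1.0f;
            dst[i + 0] = x[i + 0]/ff; // stand-in for the actual rotation
            dst[i + 1] = x[i + 1]/ff;
        }
    }

    static void apply_rope_dispatch(const float * x, float * dst, int n, const float * freq_factors) {
        if (freq_factors == nullptr) {
            apply_rope<false>(x, dst, n, freq_factors);
        } else {
            apply_rope<true>(x, dst, n, freq_factors);
        }
    }
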
ggml-kompute.cpp CHANGED
@@ -1192,7 +1192,7 @@ static void ggml_vk_rope(
         const std::shared_ptr<kp::Tensor>& inB,
         const std::shared_ptr<kp::Tensor>& out,
         uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-        ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
+        ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
         float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
         int32_t ne01, int32_t ne02, int32_t ne03,
         uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
@@ -1221,14 +1221,14 @@ static void ggml_vk_rope(
 
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
-        int32_t n_dims, mode, n_orig_ctx;
+        int32_t n_dims, mode, n_ctx_orig;
         float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
         uint32_t nb00, nb01, nb02, nb03;
         int32_t ne0;
         uint32_t nb0, nb1, nb2, nb3;
     } pushConsts {
         safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
-        n_dims, mode, n_orig_ctx,
+        n_dims, mode, n_ctx_orig,
         freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
         nb00, nb01, nb02, nb03,
         ne0,
@@ -1692,13 +1692,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
     GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
 
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
+
     GGML_ASSERT(ne10 == ne02);
     GGML_ASSERT(src0t == dstt);
     // const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
     // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
     memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -1708,7 +1711,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
     ggml_vk_rope(
-        seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
+        seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
         freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
         ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
     );
ggml-metal.m CHANGED
@@ -172,8 +172,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -2285,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
     // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     float freq_base;
     float freq_scale;
@@ -2302,22 +2306,23 @@ static enum ggml_status ggml_metal_graph_compute(
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
 
-    GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
+    id<MTLComputePipelineState> pipeline = nil;
 
     if (!is_neox) {
-        GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
+        switch (src0->type) {
+            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+            default: GGML_ASSERT(false);
+        };
+    } else {
+        switch (src0->type) {
+            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+            default: GGML_ASSERT(false);
+        };
     }
 
-    id<MTLComputePipelineState> pipeline = nil;
-
-    switch (src0->type) {
-        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
-        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
-        default: GGML_ASSERT(false);
-    };
-
     [encoder setComputePipelineState:pipeline];
     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2345,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
     [encoder setBytes:&nb3    length:sizeof(uint64_t) atIndex:19];
     [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
     [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
-    [encoder setBytes:&mode        length:sizeof( int) atIndex:22];
-    [encoder setBytes:&n_orig_ctx  length:sizeof( int) atIndex:23];
-    [encoder setBytes:&freq_base   length:sizeof( float) atIndex:24];
-    [encoder setBytes:&freq_scale  length:sizeof( float) atIndex:25];
-    [encoder setBytes:&ext_factor  length:sizeof( float) atIndex:26];
-    [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
-    [encoder setBytes:&beta_fast   length:sizeof( float) atIndex:28];
-    [encoder setBytes:&beta_slow   length:sizeof( float) atIndex:29];
+    [encoder setBytes:&n_ctx_orig  length:sizeof( int) atIndex:22];
+    [encoder setBytes:&freq_base   length:sizeof( float) atIndex:23];
+    [encoder setBytes:&freq_scale  length:sizeof( float) atIndex:24];
+    [encoder setBytes:&ext_factor  length:sizeof( float) atIndex:25];
+    [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+    [encoder setBytes:&beta_fast   length:sizeof( float) atIndex:27];
+    [encoder setBytes:&beta_slow   length:sizeof( float) atIndex:28];
 
     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
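
Note that dropping the mode and GLM-era buffer arguments shifts every subsequent [encoder setBytes:...] index down by one, so the host code above and the shaders in ggml-metal.metal have to stay in lockstep. The kernel choice itself moves to the host; an illustrative C++ restatement of the new selection logic (enum values mirror the diff, the function itself is a sketch, not the actual Objective-C code):

    enum rope_pipeline {
        ROPE_NORM_F32, ROPE_NORM_F16,
        ROPE_NEOX_F32, ROPE_NEOX_F16,
    };

    static rope_pipeline select_rope_pipeline(int mode, bool is_f16) {
        const bool is_neox = mode & 2; // same test as in ggml_metal_graph_compute
        if (!is_neox) {
            return is_f16 ? ROPE_NORM_F16 : ROPE_NORM_F32;
        }
        return is_f16 ? ROPE_NEOX_F16 : ROPE_NEOX_F32;
    }
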
ggml-metal.metal CHANGED
@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    thread float * cos_theta, thread float * sin_theta
-) {
+    thread float * cos_theta, thread float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -1672,19 +1671,20 @@ static void rope_yarn(
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
-    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
 }
 
 static void rope_yarn_corr_dims(
-    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
 ) {
     // start and end correction dims
-    dims[0] = max(0.0f,          floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+    dims[0] = max(0.0f,          floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
 }
 
-typedef void (rope_t)(
+template<typename T>
+kernel void kernel_rope_norm(
         device const void * src0,
         device const int32_t * src1,
         device const float * src2,
@@ -1707,8 +1707,7 @@ typedef void (rope_t)(
         constant uint64_t & nb3,
         constant int & n_past,
         constant int & n_dims,
-        constant int & mode,
-        constant int & n_orig_ctx,
+        constant int & n_ctx_orig,
         constant float & freq_base,
         constant float & freq_scale,
         constant float & ext_factor,
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
         constant float & beta_slow,
         uint tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
-        uint3 tgpig[[threadgroup_position_in_grid]]);
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
 
 template<typename T>
-kernel void kernel_rope(
+kernel void kernel_rope_neox(
         device const void * src0,
         device const int32_t * src1,
        device const float * src2,
@@ -1743,8 +1784,7 @@ kernel void kernel_rope(
         constant uint64_t & nb3,
         constant int & n_past,
         constant int & n_dims,
-        constant int & mode,
-        constant int & n_orig_ctx,
+        constant int & n_ctx_orig,
         constant float & freq_base,
         constant float & freq_scale,
         constant float & ext_factor,
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
     const int64_t i2 = tgpig[1];
     const int64_t i1 = tgpig[0];
 
-    const bool is_neox = mode & 2;
-
     float corr_dims[2];
-    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     device const int32_t * pos = src1;
 
-    const int64_t p = pos[i2];
-
-    const float theta_base = (float)p;
+    const float theta_base = (float) pos[i2];
     const float inv_ndims = -1.f/n_dims;
 
-    if (!is_neox) {
-        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
-
-            float cos_theta, sin_theta;
-            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-            const T x0 = src[0];
-            const T x1 = src[1];
+    float cos_theta;
+    float sin_theta;
 
-            dst_data[0] = x0*cos_theta - x1*sin_theta;
-            dst_data[1] = x0*sin_theta + x1*cos_theta;
-        }
-    } else {
-        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
-            if (ic < n_dims) {
-                const int64_t i0 = ic/2;
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
 
-                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
-
-                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
 
-                float cos_theta, sin_theta;
-                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
 
-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-                const float x0 = src[0];
-                const float x1 = src[n_dims/2];
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
 
-                dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-            } else {
-                const int64_t i0 = ic;
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
 
-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-                dst_data[0] = src[0];
-                dst_data[1] = src[1];
-            }
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
         }
     }
 }
 
-template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
-template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
 
 typedef void (im2col_t)(
         device const float * x,
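
The YaRN correction dims computed by rope_yarn_corr_dims bracket the rotary dimensions whose wavelengths outgrow the original training context. For reference, a standalone C++ restatement of the two Metal helpers above — same math, with std:: calls and a pi literal substituted for the Metal built-ins:

    #include <cmath>
    #include <algorithm>

    // Solving n_rot = 2*pi * x * base^((2 * max_pos_emb) / n_dims) for x, per the source comment.
    static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
        return n_dims * std::log(n_ctx_orig / (n_rot * 2 * 3.14159265f)) / (2 * std::log(base));
    }

    static void rope_yarn_corr_dims(
        int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]) {
        // start and end correction dims, clamped to [0, n_dims - 1]
        dims[0] = std::max(0.0f,          std::floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
        dims[1] = std::min(n_dims - 1.0f, std::ceil (rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
    }
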
ggml-sycl.cpp CHANGED
@@ -8928,49 +8928,6 @@ static void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-, const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
-    // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
-    const float sin_theta = sycl::sin((float)theta);
-    const float cos_theta = sycl::cos((float)theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta =
-        ((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
-    const float sin_block_theta = sycl::sin((float)block_theta);
-    const float cos_block_theta = sycl::cos((float)block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
     }
 }
 
-static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
-                              const int32_t *pos, float freq_scale,
-                              int p_delta_rows, float freq_base, int n_ctx,
-                              dpct::queue_ptr stream) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
-    const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, nrows, num_blocks_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             rope_glm_f32(x, dst, ncols, pos, freq_scale,
-                                          p_delta_rows, freq_base, n_ctx,
-                                          item_ct1);
-                         });
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, dpct::queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -14066,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -14087,7 +14028,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
+
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
 
     if (is_neox) {
         pos = (const int32_t *) src1_dd;
@@ -14100,13 +14043,10 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
ggml-vulkan.cpp CHANGED
@@ -3898,11 +3898,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
3898
  {
3899
  const int mode = ((const int32_t *) dst->op_params)[2];
3900
  const bool is_neox = mode & 2;
3901
- const bool is_glm = mode & 4;
3902
-
3903
- if (is_glm) {
3904
- return nullptr;
3905
- }
3906
 
3907
  if (is_neox) {
3908
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -4401,7 +4396,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
  // const int n_ctx = ((int32_t *) dst->op_params)[3];
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
  const float freq_base = ((float *) dst->op_params)[5];
  const float freq_scale = ((float *) dst->op_params)[6];
  const float ext_factor = ((float *) dst->op_params)[7];
@@ -4410,12 +4405,12 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
  const float beta_slow = ((float *) dst->op_params)[10];
 
  const bool is_neox = mode & 2;
- const bool is_glm = mode & 4;
 
- GGML_ASSERT(!is_glm);
+ #pragma message("TODO: update rope NORM mode to match NEOX mode")
+ #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
 
  float corr_dims[2];
- ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
  if (is_neox) {
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
@@ -6485,9 +6480,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
  case GGML_OP_ROPE:
  {
  const int mode = ((const int32_t *) op->op_params)[2];
- const bool is_glm = mode & 4;
 
- return !is_glm;
+ return true;
  } break;
  case GGML_OP_NONE:
  case GGML_OP_RESHAPE:
@@ -6992,15 +6986,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
  } else if (tensor->op == GGML_OP_ROPE) {
  const int n_dims = ((int32_t *) tensor->op_params)[1];
  const int mode = ((int32_t *) tensor->op_params)[2];
- const int n_ggml_ctx = ((int32_t *) tensor->op_params)[3];
- const int n_orig_ggml_ctx = ((int32_t *) tensor->op_params)[4];
+ //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
+ const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4];
  float freq_base = ((float *) tensor->op_params)[5];
  float freq_scale = ((float *) tensor->op_params)[6];
  float ext_factor = ((float *) tensor->op_params)[7];
  float attn_factor = ((float *) tensor->op_params)[8];
  float beta_fast = ((float *) tensor->op_params)[9];
  float beta_slow = ((float *) tensor->op_params)[10];
- tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ggml_ctx, n_orig_ggml_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
  } else if (tensor->op == GGML_OP_UNARY) {
  switch (ggml_get_unary_op(tensor)) {
  case GGML_UNARY_OP_SILU:
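The Vulkan changes track the new public entry points: n_ctx is gone (slot 3 of op_params is kept only as padding), n_orig_ctx is renamed n_ctx_orig, and with GLM mode removed the backend can accept every GGML_OP_ROPE node, hence the unconditional return true in ggml_backend_vk_supports_op; the #pragma message reminders flag the NORM-mode follow-up work. A minimal sketch of a caller after this change (tensor names and hyperparameter values here are illustrative, not from the patch; an initialized ggml_context is assumed):

    // hypothetical caller of the post-refactor API
    struct ggml_tensor * rotated = ggml_rope_ext(
        ctx,
        cur,       // a: activations to rotate
        inp_pos,   // b: int32 positions, one per row of a
        NULL,      // c: optional frequency factors (e.g. phi3-128k)
        128,       // n_dims: number of rotated dimensions
        2,         // mode: GPT-NeoX style (mode & 2)
        4096,      // n_ctx_orig: original training context, used by YaRN
        10000.0f,  // freq_base
        1.0f,      // freq_scale
        0.0f,      // ext_factor
        1.0f,      // attn_factor
        32.0f,     // beta_fast
        1.0f);     // beta_slow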
ggml.c CHANGED
@@ -6250,16 +6250,13 @@ static struct ggml_tensor * ggml_rope_impl(
  struct ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
  float attn_factor,
  float beta_fast,
  float beta_slow,
- float xpos_base,
- bool xpos_down,
  bool inplace) {
  GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
 
@@ -6280,15 +6277,13 @@ static struct ggml_tensor * ggml_rope_impl(
 
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
- int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+ int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
  memcpy(params + 5, &freq_base, sizeof(float));
  memcpy(params + 6, &freq_scale, sizeof(float));
  memcpy(params + 7, &ext_factor, sizeof(float));
  memcpy(params + 8, &attn_factor, sizeof(float));
  memcpy(params + 9, &beta_fast, sizeof(float));
  memcpy(params + 10, &beta_slow, sizeof(float));
- memcpy(params + 11, &xpos_base, sizeof(float));
- memcpy(params + 12, &xpos_down, sizeof(bool));
  ggml_set_op_params(result, params, sizeof(params));
 
  result->op = GGML_OP_ROPE;
@@ -6305,10 +6300,9 @@ struct ggml_tensor * ggml_rope(
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  int n_dims,
- int mode,
- int n_ctx) {
+ int mode) {
  return ggml_rope_impl(
- ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+ ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
  );
  }
 
@@ -6317,10 +6311,9 @@ struct ggml_tensor * ggml_rope_inplace(
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  int n_dims,
- int mode,
- int n_ctx) {
+ int mode) {
  return ggml_rope_impl(
- ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+ ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
  );
  }
 
@@ -6331,8 +6324,7 @@ struct ggml_tensor * ggml_rope_ext(
  struct ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -6340,8 +6332,8 @@ struct ggml_tensor * ggml_rope_ext(
  float beta_fast,
  float beta_slow) {
  return ggml_rope_impl(
- ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+ ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, false
  );
  }
 
@@ -6352,8 +6344,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
  struct ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -6361,8 +6352,8 @@ struct ggml_tensor * ggml_rope_ext_inplace(
  float beta_fast,
  float beta_slow) {
  return ggml_rope_impl(
- ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+ ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, true
  );
  }
 
@@ -6372,8 +6363,7 @@ struct ggml_tensor * ggml_rope_custom(
  struct ggml_tensor * b,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -6381,8 +6371,8 @@ struct ggml_tensor * ggml_rope_custom(
  float beta_fast,
  float beta_slow) {
  return ggml_rope_impl(
- ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+ ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, false
  );
  }
 
@@ -6392,8 +6382,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
  struct ggml_tensor * b,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -6401,21 +6390,11 @@ struct ggml_tensor * ggml_rope_custom_inplace(
  float beta_fast,
  float beta_slow) {
  return ggml_rope_impl(
- ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+ ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, true
  );
  }
 
- struct ggml_tensor * ggml_rope_xpos_inplace(
- struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- int n_dims,
- float base,
- bool down) {
- return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
- }
-
  // ggml_rope_back
 
  struct ggml_tensor * ggml_rope_back(
@@ -6425,16 +6404,13 @@ struct ggml_tensor * ggml_rope_back(
  struct ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
  float attn_factor,
  float beta_fast,
- float beta_slow,
- float xpos_base,
- bool xpos_down) {
+ float beta_slow) {
  GGML_ASSERT(ggml_is_vector(b));
  GGML_ASSERT(b->type == GGML_TYPE_I32);
  GGML_ASSERT(a->ne[2] == b->ne[0]);
@@ -6450,15 +6426,13 @@ struct ggml_tensor * ggml_rope_back(
 
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
- int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+ int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
  memcpy(params + 5, &freq_base, sizeof(float));
  memcpy(params + 6, &freq_scale, sizeof(float));
  memcpy(params + 7, &ext_factor, sizeof(float));
  memcpy(params + 8, &attn_factor, sizeof(float));
  memcpy(params + 9, &beta_fast, sizeof(float));
  memcpy(params + 10, &beta_slow, sizeof(float));
  ggml_set_op_params(result, params, sizeof(params));
 
  result->op = GGML_OP_ROPE_BACK;
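The params array shrinks from 13 to 11 int32 slots, but slot 3 (the old n_ctx) is kept as an always-zero placeholder so the float parameters stay at the same offsets and previously serialized graphs keep their layout. A sketch of how a backend reads the packed values back out (this mirrors the memcpy layout above; dst stands for any GGML_OP_ROPE tensor and string.h is assumed to be included):

    const int32_t * op_params = dst->op_params;

    const int n_dims     = op_params[1];
    const int mode       = op_params[2];
    //        op_params[3] is the retired n_ctx slot, now always 0
    const int n_ctx_orig = op_params[4];

    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
    memcpy(&freq_base,   op_params +  5, sizeof(float));
    memcpy(&freq_scale,  op_params +  6, sizeof(float));
    memcpy(&ext_factor,  op_params +  7, sizeof(float));
    memcpy(&attn_factor, op_params +  8, sizeof(float));
    memcpy(&beta_fast,   op_params +  9, sizeof(float));
    memcpy(&beta_slow,   op_params + 10, sizeof(float));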
@@ -14227,8 +14201,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
  static void rope_yarn(
  float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
- float * cos_theta, float * sin_theta
- ) {
+ float * cos_theta, float * sin_theta) {
  // Get n-d rotational scaling corrected for extrapolation
  float theta_interp = freq_scale * theta_extrap;
  float theta = theta_interp;
@@ -14245,18 +14218,19 @@ static void rope_yarn(
 
  // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
  // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
- static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
- return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+ static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
+ return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
  }
 
  static void ggml_rope_cache_init(
- float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
- float * cache, float sin_sign, float theta_scale
- ) {
+ float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+ float * cache, float sin_sign, float theta_scale) {
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
  float theta = theta_base;
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+ const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
  rope_yarn(
- theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+ theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
  );
  cache[i0 + 1] *= sin_sign;
 
@@ -14265,11 +14239,11 @@ static void ggml_rope_cache_init(
  }
 
  GGML_CALL void ggml_rope_yarn_corr_dims(
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
  ) {
  // start and end correction dims
- float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base));
- float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base));
+ float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
+ float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
  dims[0] = MAX(0, start);
  dims[1] = MIN(n_dims - 1, end);
  }
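ggml_rope_yarn_corr_dim inverts the relationship described in the comment above: it returns the dimension index at which a pair completes n_rot full rotations over n_ctx_orig positions, and the floor/ceil of it bound YaRN's interpolation ramp. A standalone numeric check with representative values (n_dims = 128, n_ctx_orig = 4096, freq_base = 10000, beta_fast = 32, beta_slow = 1 are illustrative defaults, not taken from this patch):

    // cc yarn_corr.c -lm && ./a.out
    #include <math.h>
    #include <stdio.h>

    static float corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
        return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
    }

    int main(void) {
        const int   n_dims     = 128;
        const int   n_ctx_orig = 4096;
        const float freq_base  = 10000.0f;

        // ramp start: dims spinning >= beta_fast times are pure extrapolation
        // ramp end:   dims spinning <= beta_slow times are pure interpolation
        printf("start = %.0f\n", floorf(corr_dim(n_dims, n_ctx_orig, 32.0f, freq_base))); // prints 20
        printf("end   = %.0f\n", ceilf (corr_dim(n_dims, n_ctx_orig,  1.0f, freq_base))); // prints 46
        return 0;
    }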
@@ -14289,15 +14263,11 @@ static void ggml_compute_forward_rope_f32(
 
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
- // these two only relevant for xPos RoPE:
- float xpos_base;
- bool xpos_down;
-
  //const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
- const int n_ctx = ((int32_t *) dst->op_params)[3];
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@@ -14305,8 +14275,6 @@ static void ggml_compute_forward_rope_f32(
  memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
- memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
- memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
 
  GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -14336,20 +14304,15 @@ static void ggml_compute_forward_rope_f32(
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
  float corr_dims[2];
- ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
  const bool is_neox = mode & 2;
- const bool is_glm = mode & 4;
 
  const float * freq_factors = NULL;
- if (is_neox) {
- if (src2 != NULL) {
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
- freq_factors = (const float *) src2->data;
- }
- } else {
- GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+ if (src2 != NULL) {
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+ freq_factors = (const float *) src2->data;
  }
 
  // backward process uses inverse rotation by cos and sin.
@@ -14364,94 +14327,50 @@ static void ggml_compute_forward_rope_f32(
  const int64_t p = pos[i2];
 
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
- if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
- ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
- }
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 
  for (int64_t i1 = 0; i1 < ne1; i1++) {
  if (ir++ < ir0) continue;
  if (ir > ir1) break;
 
- float theta_base = (float)p;
-
- if (is_glm) {
- theta_base = MIN(p, n_ctx - 2);
- float block_theta = MAX(p - (n_ctx - 2), 0);
- for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
- const float cos_theta = cosf(theta_base);
- const float sin_theta = sinf(theta_base) * sin_sign;
- const float cos_block_theta = cosf(block_theta);
- const float sin_block_theta = sinf(block_theta) * sin_sign;
-
- theta_base *= theta_scale;
- block_theta *= theta_scale;
-
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- const float x0 = src[0];
- const float x1 = src[n_dims/2];
- const float x2 = src[n_dims];
- const float x3 = src[n_dims/2*3];
-
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
- dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
- dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
- }
- } else if (!is_neox) {
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+ if (!is_neox) {
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  const float cos_theta = cache[i0 + 0];
  const float sin_theta = cache[i0 + 1];
 
- // zeta scaling for xPos only:
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
- if (xpos_down) zeta = 1.0f / zeta;
-
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
  const float x0 = src[0];
  const float x1 = src[1];
 
- dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
- dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
  }
  } else {
- // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
- for (int64_t ic = 0; ic < ne0; ic += 2) {
- if (ic < n_dims) {
- const int64_t i0 = ic/2;
-
- const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
-
- float cos_theta, sin_theta;
- rope_yarn(
- theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
- &cos_theta, &sin_theta
- );
-
- sin_theta *= sin_sign;
- theta_base *= theta_scale;
-
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- const float x0 = src[0];
- const float x1 = src[n_dims/2];
-
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
- } else {
- const int64_t i0 = ic;
-
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
- }
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ const int64_t ic = i0/2;
+
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
+
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[n_dims/2];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ }
+ }
+
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
  }
  }
  }
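With GLM and xPos gone, the forward pass reduces to the two remaining layouts, and both now draw sin/cos from the shared cache filled once per row batch by ggml_rope_cache_init: NORM mode rotates adjacent pairs (i0, i0 + 1), NEOX mode rotates split halves (ic, ic + n_dims/2), and everything past n_dims is copied through untouched (partial rotary). A minimal restatement of the three loops on one contiguous row of floats (illustrative only; the real code above walks strided tensor memory):

    // cache[] holds {cos, sin} pairs, as produced by ggml_rope_cache_init
    static void rope_row_ref(const float * x, float * y, const float * cache,
                             int ne0, int n_dims, int is_neox) {
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float c = cache[i0 + 0];
            const float s = cache[i0 + 1];
            if (!is_neox) {
                // NORM: rotate neighbouring elements (i0, i0 + 1)
                const float x0 = x[i0], x1 = x[i0 + 1];
                y[i0 + 0] = x0*c - x1*s;
                y[i0 + 1] = x0*s + x1*c;
            } else {
                // NEOX: rotate split halves (ic, ic + n_dims/2)
                const int ic = i0/2;
                const float x0 = x[ic], x1 = x[ic + n_dims/2];
                y[ic]            = x0*c - x1*s;
                y[ic + n_dims/2] = x0*s + x1*c;
            }
        }
        // pass-through for the non-rotary tail (partial rotary embeddings)
        for (int i0 = n_dims; i0 < ne0; i0++) {
            y[i0] = x[i0];
        }
    }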
@@ -14477,8 +14396,8 @@ static void ggml_compute_forward_rope_f16(
  //const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
- const int n_ctx = ((int32_t *) dst->op_params)[3];
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -14514,20 +14433,15 @@ static void ggml_compute_forward_rope_f16(
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
  float corr_dims[2];
- ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
  const bool is_neox = mode & 2;
- const bool is_glm = mode & 4;
 
  const float * freq_factors = NULL;
- if (is_neox) {
- if (src2 != NULL) {
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
- freq_factors = (const float *) src2->data;
- }
- } else {
- GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+ if (src2 != NULL) {
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+ freq_factors = (const float *) src2->data;
  }
 
  // backward process uses inverse rotation by cos and sin.
@@ -14542,43 +14456,14 @@ static void ggml_compute_forward_rope_f16(
  const int64_t p = pos[i2];
 
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
- if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
- ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
- }
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 
  for (int64_t i1 = 0; i1 < ne1; i1++) {
  if (ir++ < ir0) continue;
  if (ir > ir1) break;
 
- float theta_base = (float)p;
-
- if (is_glm) {
- theta_base = MIN(p, n_ctx - 2);
- float block_theta = MAX(p - (n_ctx - 2), 0);
- for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
- const float cos_theta = cosf(theta_base);
- const float sin_theta = sinf(theta_base) * sin_sign;
- const float cos_block_theta = cosf(block_theta);
- const float sin_block_theta = sinf(block_theta) * sin_sign;
-
- theta_base *= theta_scale;
- block_theta *= theta_scale;
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- const float x0 = GGML_FP16_TO_FP32(src[0]);
- const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
- const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
- const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
-
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
- dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
- }
- } else if (!is_neox) {
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+ if (!is_neox) {
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
  const float cos_theta = cache[i0 + 0];
  const float sin_theta = cache[i0 + 1];
 
@@ -14592,40 +14477,29 @@ static void ggml_compute_forward_rope_f16(
  dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  }
  } else {
- // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
- for (int64_t ic = 0; ic < ne0; ic += 2) {
- if (ic < n_dims) {
- const int64_t i0 = ic/2;
-
- const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
-
- float cos_theta, sin_theta;
- rope_yarn(
- theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
- &cos_theta, &sin_theta
- );
-
- sin_theta *= sin_sign;
- theta_base *= theta_scale;
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- const float x0 = GGML_FP16_TO_FP32(src[0]);
- const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
-
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
- dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
- } else {
- const int64_t i0 = ic;
-
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
- }
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ const int64_t ic = i0/2;
+
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
+
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ }
+ }
+
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
  }
  }
  }
@@ -18327,9 +18201,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  //const int n_past = ((int32_t *) tensor->op_params)[0];
  const int n_dims = ((int32_t *) tensor->op_params)[1];
  const int mode = ((int32_t *) tensor->op_params)[2];
- const int n_ctx = ((int32_t *) tensor->op_params)[3];
- const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+ //const int n_ctx = ((int32_t *) tensor->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
  memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
  memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
@@ -18337,8 +18211,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
  memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
  memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
- memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
- memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
 
  src0->grad = ggml_add_or_set(ctx,
  src0->grad,
@@ -18348,16 +18220,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  src2,
  n_dims,
  mode,
- n_ctx,
- n_orig_ctx,
+ n_ctx_orig,
  freq_base,
  freq_scale,
  ext_factor,
  attn_factor,
  beta_fast,
- beta_slow,
- xpos_base,
- xpos_down),
+ beta_slow),
  zero_table);
  }
  } break;
@@ -18367,9 +18236,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  //const int n_past = ((int32_t *) tensor->op_params)[0];
  const int n_dims = ((int32_t *) tensor->op_params)[1];
  const int mode = ((int32_t *) tensor->op_params)[2];
- const int n_ctx = ((int32_t *) tensor->op_params)[3];
- const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+ //const int n_ctx = ((int32_t *) tensor->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
  memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
  memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
@@ -18377,8 +18246,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
  memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
  memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
- memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
- memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
 
  src0->grad = ggml_add_or_set(ctx,
  src0->grad,
@@ -18388,16 +18255,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  src2,
  n_dims,
  mode,
- n_ctx,
- n_orig_ctx,
+ n_ctx_orig,
  freq_base,
  freq_scale,
  ext_factor,
  attn_factor,
  beta_fast,
  beta_slow,
- xpos_base,
- xpos_down,
  false),
  zero_table);
  }
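The backward cases stay structurally identical to the forward ones because a 2-D rotation is orthogonal: its inverse is the rotation by -theta, which the kernels obtain by flipping the sign of every cached sine (the "inverse rotation by cos and sin" noted in the comments above). A tiny sketch of the identity being relied on (a fragment: theta, x0, x1 stand for any cached angle and input pair, math.h assumed):

    const float c = cosf(theta), s = sinf(theta);

    const float y0 = x0*c - x1*s;   // forward rotation
    const float y1 = x0*s + x1*c;

    const float z0 =  y0*c + y1*s;  // same formula with s negated
    const float z1 = -y0*s + y1*c;  // (z0, z1) == (x0, x1) up to rounding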
 
ggml.h CHANGED
@@ -1465,7 +1465,6 @@ extern "C" {
  // rotary position embedding
  // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
  // if mode & 2 == 1, GPT-NeoX style
- // if mode & 4 == 1, ChatGLM style
  //
  // b is an int32 vector with size a->ne[2], it contains the positions
  // c is freq factors (e.g. phi3-128k), (optional)
@@ -1474,8 +1473,7 @@ extern "C" {
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  int n_dims,
- int mode,
- int n_ctx);
+ int mode);
 
  // in-place, returns view(a)
  GGML_API struct ggml_tensor * ggml_rope_inplace(
@@ -1483,8 +1481,7 @@ extern "C" {
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  int n_dims,
- int mode,
- int n_ctx);
+ int mode);
 
  // custom RoPE
  GGML_API struct ggml_tensor * ggml_rope_ext(
@@ -1494,8 +1491,7 @@ extern "C" {
  struct ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -1511,8 +1507,7 @@ extern "C" {
  struct ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -1526,8 +1521,7 @@ extern "C" {
  struct ggml_tensor * b,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -1542,8 +1536,7 @@ extern "C" {
  struct ggml_tensor * b,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -1552,17 +1545,9 @@ extern "C" {
  float beta_slow),
  "use ggml_rope_ext_inplace instead");
 
- struct ggml_tensor * ggml_rope_xpos_inplace(
- struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- int n_dims,
- float base,
- bool down);
-
  // compute correction dims for YaRN RoPE scaling
  GGML_CALL void ggml_rope_yarn_corr_dims(
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
  // rotary position embedding backward, i.e compute dx from dy
  // a - dy
@@ -1573,16 +1558,13 @@ extern "C" {
  struct ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
  float attn_factor,
  float beta_fast,
- float beta_slow,
- float xpos_base,
- bool xpos_down);
+ float beta_slow);
 
  // clamp
  // in-place, returns view(a)
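For code built on top of ggml the migration is mechanical: drop the n_ctx argument from every rope call, rename n_orig_ctx to n_ctx_orig, and port any ggml_rope_xpos_inplace users to the remaining entry points (xPos support is removed outright). A before/after sketch with placeholder variables:

    // before this commit:
    // cur = ggml_rope(ctx, cur, inp_pos, n_rot, rope_mode, n_ctx);
    // cur = ggml_rope_ext(ctx, cur, inp_pos, NULL, n_rot, rope_mode, n_ctx, n_orig_ctx,
    //                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

    // after:
    cur = ggml_rope(ctx, cur, inp_pos, n_rot, rope_mode);
    cur = ggml_rope_ext(ctx, cur, inp_pos, NULL, n_rot, rope_mode, n_ctx_orig,
                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);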