Spaces:
Sleeping
Sleeping
ggml : refactor rope norm/neox (llama/7634)
Browse files* ggml : unify rope norm/neox (CPU)
* ggml : fix compile warning
* ggml : remove GLM rope mode
ggml-ci
* metal : better rope implementation
ggml-ci
* cuda : better rope implementation
ggml-ci
* naming : n_orig_ctx -> n_ctx_orig
ggml-ci
* dev : add reminders to update backends
ggml-ci
* vulkan : fix ggml_rope_ext() usage
* cuda : fix array size + indents
ggml-ci
- ggml-cuda/rope.cu +107 -166
- ggml-kompute.cpp +8 -5
- ggml-metal.m +28 -24
- ggml-metal.metal +83 -59
- ggml-sycl.cpp +7 -67
- ggml-vulkan.cpp +8 -14
- ggml.c +97 -233
- ggml.h +9 -27
ggml-cuda/rope.cu
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
#include "rope.cuh"
|
| 2 |
|
| 3 |
struct rope_corr_dims {
|
| 4 |
-
float v[
|
| 5 |
};
|
| 6 |
|
| 7 |
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
@@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
|
|
| 13 |
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
| 14 |
static __device__ void rope_yarn(
|
| 15 |
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
|
| 16 |
-
float * cos_theta, float * sin_theta
|
| 17 |
-
) {
|
| 18 |
// Get n-d rotational scaling corrected for extrapolation
|
| 19 |
float theta_interp = freq_scale * theta_extrap;
|
| 20 |
float theta = theta_interp;
|
|
@@ -29,27 +28,38 @@ static __device__ void rope_yarn(
|
|
| 29 |
*sin_theta = sinf(theta) * mscale;
|
| 30 |
}
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
) {
|
| 38 |
-
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
| 39 |
|
| 40 |
-
if (
|
| 41 |
return;
|
| 42 |
}
|
| 43 |
|
| 44 |
const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
const int i2 = row/p_delta_rows;
|
| 47 |
|
| 48 |
-
const
|
| 49 |
-
|
|
|
|
| 50 |
|
| 51 |
-
float cos_theta
|
| 52 |
-
|
|
|
|
|
|
|
| 53 |
|
| 54 |
const float x0 = x[i + 0];
|
| 55 |
const float x1 = x[i + 1];
|
|
@@ -58,23 +68,20 @@ static __global__ void rope(
|
|
| 58 |
dst[i + 1] = x0*sin_theta + x1*cos_theta;
|
| 59 |
}
|
| 60 |
|
| 61 |
-
template<typename T, bool
|
| 62 |
static __global__ void rope_neox(
|
| 63 |
-
const T * x, T * dst, int
|
| 64 |
-
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
|
| 65 |
-
)
|
| 66 |
-
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
| 67 |
|
| 68 |
-
if (
|
| 69 |
return;
|
| 70 |
}
|
| 71 |
|
| 72 |
const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
| 73 |
-
const int ib = col / n_dims;
|
| 74 |
-
const int ic = col % n_dims;
|
| 75 |
|
| 76 |
-
if (
|
| 77 |
-
const int i = row*
|
| 78 |
|
| 79 |
dst[i + 0] = x[i + 0];
|
| 80 |
dst[i + 1] = x[i + 1];
|
|
@@ -82,16 +89,17 @@ static __global__ void rope_neox(
|
|
| 82 |
return;
|
| 83 |
}
|
| 84 |
|
| 85 |
-
const int i = row*
|
| 86 |
const int i2 = row/p_delta_rows;
|
| 87 |
|
| 88 |
-
const
|
| 89 |
-
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
|
| 90 |
|
| 91 |
-
const float
|
| 92 |
|
| 93 |
-
float cos_theta
|
| 94 |
-
|
|
|
|
|
|
|
| 95 |
|
| 96 |
const float x0 = x[i + 0];
|
| 97 |
const float x1 = x[i + n_dims/2];
|
|
@@ -100,144 +108,81 @@ static __global__ void rope_neox(
|
|
| 100 |
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
| 101 |
}
|
| 102 |
|
| 103 |
-
static __global__ void rope_glm_f32(
|
| 104 |
-
const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
| 105 |
-
int n_ctx
|
| 106 |
-
) {
|
| 107 |
-
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
| 108 |
-
const int half_n_dims = ncols/4;
|
| 109 |
-
|
| 110 |
-
if (col >= half_n_dims) {
|
| 111 |
-
return;
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
const int row = blockDim.y*blockIdx.y + threadIdx.y;
|
| 115 |
-
const int i = row*ncols + col;
|
| 116 |
-
const int i2 = row/p_delta_rows;
|
| 117 |
-
|
| 118 |
-
const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
|
| 119 |
-
// FIXME: this is likely wrong
|
| 120 |
-
const int p = pos != nullptr ? pos[i2] : 0;
|
| 121 |
-
|
| 122 |
-
const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
|
| 123 |
-
const float sin_theta = sinf(theta);
|
| 124 |
-
const float cos_theta = cosf(theta);
|
| 125 |
-
|
| 126 |
-
const float x0 = x[i + 0];
|
| 127 |
-
const float x1 = x[i + half_n_dims];
|
| 128 |
-
|
| 129 |
-
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
| 130 |
-
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
|
| 131 |
-
|
| 132 |
-
const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
|
| 133 |
-
const float sin_block_theta = sinf(block_theta);
|
| 134 |
-
const float cos_block_theta = cosf(block_theta);
|
| 135 |
-
|
| 136 |
-
const float x2 = x[i + half_n_dims * 2];
|
| 137 |
-
const float x3 = x[i + half_n_dims * 3];
|
| 138 |
-
|
| 139 |
-
dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
|
| 140 |
-
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
|
| 141 |
-
}
|
| 142 |
-
|
| 143 |
-
|
| 144 |
template<typename T>
|
| 145 |
-
static void
|
| 146 |
-
const T * x, T * dst, int
|
| 147 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
|
| 148 |
-
)
|
| 149 |
-
GGML_ASSERT(ncols % 2 == 0);
|
| 150 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 151 |
-
const int
|
| 152 |
-
const dim3 block_nums(
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
} else {
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
| 161 |
}
|
| 162 |
}
|
| 163 |
|
| 164 |
template<typename T>
|
| 165 |
static void rope_neox_cuda(
|
| 166 |
-
const T * x, T * dst, int
|
| 167 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
|
| 168 |
-
)
|
| 169 |
-
GGML_ASSERT(ncols % 2 == 0);
|
| 170 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 171 |
-
const int
|
| 172 |
-
const dim3 block_nums(
|
| 173 |
|
| 174 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 175 |
|
| 176 |
-
if (
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
| 180 |
-
theta_scale, freq_factors
|
| 181 |
-
);
|
| 182 |
-
} else {
|
| 183 |
-
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
|
| 184 |
-
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
| 185 |
theta_scale, freq_factors
|
| 186 |
);
|
| 187 |
-
}
|
| 188 |
} else {
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
| 192 |
-
theta_scale, freq_factors
|
| 193 |
-
);
|
| 194 |
-
} else {
|
| 195 |
-
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
|
| 196 |
-
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
| 197 |
theta_scale, freq_factors
|
| 198 |
);
|
| 199 |
-
}
|
| 200 |
}
|
| 201 |
}
|
| 202 |
|
| 203 |
-
static void
|
| 204 |
-
const
|
| 205 |
-
float freq_base,
|
| 206 |
-
) {
|
| 207 |
-
GGML_ASSERT(ncols % 4 == 0);
|
| 208 |
-
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
|
| 209 |
-
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
|
| 210 |
-
const dim3 block_nums(num_blocks_x, nrows, 1);
|
| 211 |
-
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
|
| 212 |
-
}
|
| 213 |
-
|
| 214 |
-
static void rope_cuda_f16(
|
| 215 |
-
const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 216 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
|
| 217 |
|
| 218 |
-
|
| 219 |
}
|
| 220 |
|
| 221 |
-
static void
|
| 222 |
-
const float * x, float * dst, int
|
| 223 |
-
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
|
| 224 |
|
| 225 |
-
|
| 226 |
}
|
| 227 |
|
| 228 |
static void rope_neox_cuda_f16(
|
| 229 |
-
const half * x, half * dst, int
|
| 230 |
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
| 231 |
|
| 232 |
-
rope_neox_cuda<half>(x, dst,
|
| 233 |
}
|
| 234 |
|
| 235 |
static void rope_neox_cuda_f32(
|
| 236 |
-
const float * x, float * dst, int
|
| 237 |
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
|
| 238 |
) {
|
| 239 |
|
| 240 |
-
rope_neox_cuda<float>(x, dst,
|
| 241 |
}
|
| 242 |
|
| 243 |
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
@@ -258,16 +203,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
| 258 |
|
| 259 |
const int64_t ne00 = src0->ne[0];
|
| 260 |
const int64_t ne01 = src0->ne[1];
|
| 261 |
-
const int64_t
|
| 262 |
|
| 263 |
-
//const int n_past
|
| 264 |
-
const int n_dims
|
| 265 |
-
const int mode
|
| 266 |
-
const int n_ctx
|
| 267 |
-
const int
|
| 268 |
|
| 269 |
// RoPE alteration for extended context
|
| 270 |
-
float freq_base
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
| 272 |
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
| 273 |
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
@@ -275,38 +226,28 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
| 275 |
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
| 276 |
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
| 277 |
|
| 278 |
-
const float * freq_factors = nullptr;
|
| 279 |
-
const int32_t * pos = nullptr;
|
| 280 |
-
|
| 281 |
const bool is_neox = mode & 2;
|
| 282 |
-
const bool is_glm = mode & 4;
|
| 283 |
|
| 284 |
-
pos = (const int32_t *) src1_d;
|
| 285 |
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
}
|
| 290 |
-
} else {
|
| 291 |
-
GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
|
| 292 |
}
|
| 293 |
|
| 294 |
rope_corr_dims corr_dims;
|
| 295 |
-
ggml_rope_yarn_corr_dims(n_dims,
|
| 296 |
|
| 297 |
// compute
|
| 298 |
-
if (
|
| 299 |
-
GGML_ASSERT(false);
|
| 300 |
-
rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream);
|
| 301 |
-
} else if (is_neox) {
|
| 302 |
if (src0->type == GGML_TYPE_F32) {
|
| 303 |
rope_neox_cuda_f32(
|
| 304 |
-
(const float *)src0_d, (float *)dst_d, ne00, n_dims,
|
| 305 |
attn_factor, corr_dims, freq_factors, stream
|
| 306 |
);
|
| 307 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 308 |
rope_neox_cuda_f16(
|
| 309 |
-
(const half *)src0_d, (half *)dst_d, ne00, n_dims,
|
| 310 |
attn_factor, corr_dims, freq_factors, stream
|
| 311 |
);
|
| 312 |
} else {
|
|
@@ -314,14 +255,14 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
| 314 |
}
|
| 315 |
} else {
|
| 316 |
if (src0->type == GGML_TYPE_F32) {
|
| 317 |
-
|
| 318 |
-
(const float *)src0_d, (float *)dst_d, ne00,
|
| 319 |
-
attn_factor, corr_dims, stream
|
| 320 |
);
|
| 321 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 322 |
-
|
| 323 |
-
(const half *)src0_d, (half *)dst_d, ne00,
|
| 324 |
-
attn_factor, corr_dims, stream
|
| 325 |
);
|
| 326 |
} else {
|
| 327 |
GGML_ASSERT(false);
|
|
|
|
| 1 |
#include "rope.cuh"
|
| 2 |
|
| 3 |
struct rope_corr_dims {
|
| 4 |
+
float v[2];
|
| 5 |
};
|
| 6 |
|
| 7 |
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
|
|
| 13 |
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
| 14 |
static __device__ void rope_yarn(
|
| 15 |
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
|
| 16 |
+
float * cos_theta, float * sin_theta) {
|
|
|
|
| 17 |
// Get n-d rotational scaling corrected for extrapolation
|
| 18 |
float theta_interp = freq_scale * theta_extrap;
|
| 19 |
float theta = theta_interp;
|
|
|
|
| 28 |
*sin_theta = sinf(theta) * mscale;
|
| 29 |
}
|
| 30 |
|
| 31 |
+
template<typename T, bool has_ff>
|
| 32 |
+
static __global__ void rope_norm(
|
| 33 |
+
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 34 |
+
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
|
| 35 |
+
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
if (i0 >= ne0) {
|
| 38 |
return;
|
| 39 |
}
|
| 40 |
|
| 41 |
const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
| 42 |
+
|
| 43 |
+
if (i0 >= n_dims) {
|
| 44 |
+
const int i = row*ne0 + i0;
|
| 45 |
+
|
| 46 |
+
dst[i + 0] = x[i + 0];
|
| 47 |
+
dst[i + 1] = x[i + 1];
|
| 48 |
+
|
| 49 |
+
return;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
const int i = row*ne0 + i0;
|
| 53 |
const int i2 = row/p_delta_rows;
|
| 54 |
|
| 55 |
+
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
| 56 |
+
|
| 57 |
+
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
| 58 |
|
| 59 |
+
float cos_theta;
|
| 60 |
+
float sin_theta;
|
| 61 |
+
|
| 62 |
+
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
| 63 |
|
| 64 |
const float x0 = x[i + 0];
|
| 65 |
const float x1 = x[i + 1];
|
|
|
|
| 68 |
dst[i + 1] = x0*sin_theta + x1*cos_theta;
|
| 69 |
}
|
| 70 |
|
| 71 |
+
template<typename T, bool has_ff>
|
| 72 |
static __global__ void rope_neox(
|
| 73 |
+
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 74 |
+
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
|
| 75 |
+
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
|
|
| 76 |
|
| 77 |
+
if (i0 >= ne0) {
|
| 78 |
return;
|
| 79 |
}
|
| 80 |
|
| 81 |
const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
if (i0 >= n_dims) {
|
| 84 |
+
const int i = row*ne0 + i0;
|
| 85 |
|
| 86 |
dst[i + 0] = x[i + 0];
|
| 87 |
dst[i + 1] = x[i + 1];
|
|
|
|
| 89 |
return;
|
| 90 |
}
|
| 91 |
|
| 92 |
+
const int i = row*ne0 + i0/2;
|
| 93 |
const int i2 = row/p_delta_rows;
|
| 94 |
|
| 95 |
+
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
|
|
|
| 96 |
|
| 97 |
+
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
| 98 |
|
| 99 |
+
float cos_theta;
|
| 100 |
+
float sin_theta;
|
| 101 |
+
|
| 102 |
+
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
| 103 |
|
| 104 |
const float x0 = x[i + 0];
|
| 105 |
const float x1 = x[i + n_dims/2];
|
|
|
|
| 108 |
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
| 109 |
}
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
template<typename T>
|
| 112 |
+
static void rope_norm_cuda(
|
| 113 |
+
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 114 |
+
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
| 115 |
+
GGML_ASSERT(ne0 % 2 == 0);
|
|
|
|
| 116 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 117 |
+
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
| 118 |
+
const dim3 block_nums(nr, n_blocks_x, 1);
|
| 119 |
+
|
| 120 |
+
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 121 |
+
|
| 122 |
+
if (freq_factors == nullptr) {
|
| 123 |
+
rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
|
| 124 |
+
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
| 125 |
+
theta_scale, freq_factors
|
| 126 |
+
);
|
| 127 |
} else {
|
| 128 |
+
rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
|
| 129 |
+
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
| 130 |
+
theta_scale, freq_factors
|
| 131 |
+
);
|
| 132 |
}
|
| 133 |
}
|
| 134 |
|
| 135 |
template<typename T>
|
| 136 |
static void rope_neox_cuda(
|
| 137 |
+
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 138 |
+
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
| 139 |
+
GGML_ASSERT(ne0 % 2 == 0);
|
|
|
|
| 140 |
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
| 141 |
+
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
| 142 |
+
const dim3 block_nums(nr, n_blocks_x, 1);
|
| 143 |
|
| 144 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 145 |
|
| 146 |
+
if (freq_factors == nullptr) {
|
| 147 |
+
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
|
| 148 |
+
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
theta_scale, freq_factors
|
| 150 |
);
|
|
|
|
| 151 |
} else {
|
| 152 |
+
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
|
| 153 |
+
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
theta_scale, freq_factors
|
| 155 |
);
|
|
|
|
| 156 |
}
|
| 157 |
}
|
| 158 |
|
| 159 |
+
static void rope_norm_cuda_f16(
|
| 160 |
+
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 161 |
+
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
| 164 |
}
|
| 165 |
|
| 166 |
+
static void rope_norm_cuda_f32(
|
| 167 |
+
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 168 |
+
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
| 169 |
|
| 170 |
+
rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
| 171 |
}
|
| 172 |
|
| 173 |
static void rope_neox_cuda_f16(
|
| 174 |
+
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 175 |
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
| 176 |
|
| 177 |
+
rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
| 178 |
}
|
| 179 |
|
| 180 |
static void rope_neox_cuda_f32(
|
| 181 |
+
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
| 182 |
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
|
| 183 |
) {
|
| 184 |
|
| 185 |
+
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
| 186 |
}
|
| 187 |
|
| 188 |
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
| 203 |
|
| 204 |
const int64_t ne00 = src0->ne[0];
|
| 205 |
const int64_t ne01 = src0->ne[1];
|
| 206 |
+
const int64_t nr = ggml_nrows(src0);
|
| 207 |
|
| 208 |
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 209 |
+
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 210 |
+
const int mode = ((int32_t *) dst->op_params)[2];
|
| 211 |
+
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
| 212 |
+
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
| 213 |
|
| 214 |
// RoPE alteration for extended context
|
| 215 |
+
float freq_base;
|
| 216 |
+
float freq_scale;
|
| 217 |
+
float ext_factor;
|
| 218 |
+
float attn_factor;
|
| 219 |
+
float beta_fast;
|
| 220 |
+
float beta_slow;
|
| 221 |
+
|
| 222 |
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
| 223 |
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
| 224 |
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
|
|
| 226 |
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
| 227 |
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
| 228 |
|
|
|
|
|
|
|
|
|
|
| 229 |
const bool is_neox = mode & 2;
|
|
|
|
| 230 |
|
| 231 |
+
const int32_t * pos = (const int32_t *) src1_d;
|
| 232 |
|
| 233 |
+
const float * freq_factors = nullptr;
|
| 234 |
+
if (src2 != nullptr) {
|
| 235 |
+
freq_factors = (const float *) src2->data;
|
|
|
|
|
|
|
|
|
|
| 236 |
}
|
| 237 |
|
| 238 |
rope_corr_dims corr_dims;
|
| 239 |
+
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
| 240 |
|
| 241 |
// compute
|
| 242 |
+
if (is_neox) {
|
|
|
|
|
|
|
|
|
|
| 243 |
if (src0->type == GGML_TYPE_F32) {
|
| 244 |
rope_neox_cuda_f32(
|
| 245 |
+
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
| 246 |
attn_factor, corr_dims, freq_factors, stream
|
| 247 |
);
|
| 248 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 249 |
rope_neox_cuda_f16(
|
| 250 |
+
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
| 251 |
attn_factor, corr_dims, freq_factors, stream
|
| 252 |
);
|
| 253 |
} else {
|
|
|
|
| 255 |
}
|
| 256 |
} else {
|
| 257 |
if (src0->type == GGML_TYPE_F32) {
|
| 258 |
+
rope_norm_cuda_f32(
|
| 259 |
+
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
| 260 |
+
attn_factor, corr_dims, freq_factors, stream
|
| 261 |
);
|
| 262 |
} else if (src0->type == GGML_TYPE_F16) {
|
| 263 |
+
rope_norm_cuda_f16(
|
| 264 |
+
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
| 265 |
+
attn_factor, corr_dims, freq_factors, stream
|
| 266 |
);
|
| 267 |
} else {
|
| 268 |
GGML_ASSERT(false);
|
ggml-kompute.cpp
CHANGED
|
@@ -1192,7 +1192,7 @@ static void ggml_vk_rope(
|
|
| 1192 |
const std::shared_ptr<kp::Tensor>& inB,
|
| 1193 |
const std::shared_ptr<kp::Tensor>& out,
|
| 1194 |
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
| 1195 |
-
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t
|
| 1196 |
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
|
| 1197 |
int32_t ne01, int32_t ne02, int32_t ne03,
|
| 1198 |
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
|
@@ -1221,14 +1221,14 @@ static void ggml_vk_rope(
|
|
| 1221 |
|
| 1222 |
struct PushConstants {
|
| 1223 |
uint32_t inAOff, inBOff, outOff;
|
| 1224 |
-
int32_t n_dims, mode,
|
| 1225 |
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
| 1226 |
uint32_t nb00, nb01, nb02, nb03;
|
| 1227 |
int32_t ne0;
|
| 1228 |
uint32_t nb0, nb1, nb2, nb3;
|
| 1229 |
} pushConsts {
|
| 1230 |
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
|
| 1231 |
-
n_dims, mode,
|
| 1232 |
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
| 1233 |
nb00, nb01, nb02, nb03,
|
| 1234 |
ne0,
|
|
@@ -1692,13 +1692,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|
| 1692 |
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
|
| 1693 |
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
|
| 1694 |
|
|
|
|
|
|
|
|
|
|
| 1695 |
GGML_ASSERT(ne10 == ne02);
|
| 1696 |
GGML_ASSERT(src0t == dstt);
|
| 1697 |
// const int n_past = ((int32_t *) dst->op_params)[0];
|
| 1698 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 1699 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 1700 |
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
|
| 1701 |
-
const int
|
| 1702 |
|
| 1703 |
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
| 1704 |
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
@@ -1708,7 +1711,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|
| 1708 |
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
| 1709 |
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
| 1710 |
ggml_vk_rope(
|
| 1711 |
-
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode,
|
| 1712 |
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
| 1713 |
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
|
| 1714 |
);
|
|
|
|
| 1192 |
const std::shared_ptr<kp::Tensor>& inB,
|
| 1193 |
const std::shared_ptr<kp::Tensor>& out,
|
| 1194 |
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
| 1195 |
+
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
|
| 1196 |
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
|
| 1197 |
int32_t ne01, int32_t ne02, int32_t ne03,
|
| 1198 |
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
|
|
|
| 1221 |
|
| 1222 |
struct PushConstants {
|
| 1223 |
uint32_t inAOff, inBOff, outOff;
|
| 1224 |
+
int32_t n_dims, mode, n_ctx_orig;
|
| 1225 |
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
| 1226 |
uint32_t nb00, nb01, nb02, nb03;
|
| 1227 |
int32_t ne0;
|
| 1228 |
uint32_t nb0, nb1, nb2, nb3;
|
| 1229 |
} pushConsts {
|
| 1230 |
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
|
| 1231 |
+
n_dims, mode, n_ctx_orig,
|
| 1232 |
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
| 1233 |
nb00, nb01, nb02, nb03,
|
| 1234 |
ne0,
|
|
|
|
| 1692 |
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
|
| 1693 |
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
|
| 1694 |
|
| 1695 |
+
#pragma message("TODO: update rope NORM mode to match NEOX mode")
|
| 1696 |
+
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
|
| 1697 |
+
|
| 1698 |
GGML_ASSERT(ne10 == ne02);
|
| 1699 |
GGML_ASSERT(src0t == dstt);
|
| 1700 |
// const int n_past = ((int32_t *) dst->op_params)[0];
|
| 1701 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 1702 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 1703 |
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
|
| 1704 |
+
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
| 1705 |
|
| 1706 |
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
| 1707 |
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
|
|
| 1711 |
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
| 1712 |
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
| 1713 |
ggml_vk_rope(
|
| 1714 |
+
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
|
| 1715 |
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
| 1716 |
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
|
| 1717 |
);
|
ggml-metal.m
CHANGED
|
@@ -172,8 +172,10 @@ enum ggml_metal_kernel_type {
|
|
| 172 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
|
| 173 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
| 174 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
| 177 |
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
| 178 |
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
| 179 |
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
|
@@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 626 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
| 627 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
| 628 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
| 629 |
-
GGML_METAL_ADD_KERNEL(
|
| 630 |
-
GGML_METAL_ADD_KERNEL(
|
|
|
|
|
|
|
| 631 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
| 632 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
| 633 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
|
@@ -2285,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 2285 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 2286 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 2287 |
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
| 2288 |
-
const int
|
| 2289 |
|
| 2290 |
float freq_base;
|
| 2291 |
float freq_scale;
|
|
@@ -2302,22 +2306,23 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 2302 |
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
| 2303 |
|
| 2304 |
const bool is_neox = mode & 2;
|
| 2305 |
-
const bool is_glm = mode & 4;
|
| 2306 |
|
| 2307 |
-
|
| 2308 |
|
| 2309 |
if (!is_neox) {
|
| 2310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2311 |
}
|
| 2312 |
|
| 2313 |
-
id<MTLComputePipelineState> pipeline = nil;
|
| 2314 |
-
|
| 2315 |
-
switch (src0->type) {
|
| 2316 |
-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
|
| 2317 |
-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
|
| 2318 |
-
default: GGML_ASSERT(false);
|
| 2319 |
-
};
|
| 2320 |
-
|
| 2321 |
[encoder setComputePipelineState:pipeline];
|
| 2322 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2323 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
@@ -2345,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 2345 |
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
| 2346 |
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
| 2347 |
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
| 2348 |
-
[encoder setBytes:&
|
| 2349 |
-
[encoder setBytes:&
|
| 2350 |
-
[encoder setBytes:&
|
| 2351 |
-
[encoder setBytes:&
|
| 2352 |
-
[encoder setBytes:&
|
| 2353 |
-
[encoder setBytes:&
|
| 2354 |
-
[encoder setBytes:&
|
| 2355 |
-
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
|
| 2356 |
|
| 2357 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 2358 |
} break;
|
|
|
|
| 172 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
|
| 173 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
| 174 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
| 175 |
+
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
| 176 |
+
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
| 177 |
+
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
| 178 |
+
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
| 179 |
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
| 180 |
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
| 181 |
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
|
|
|
| 628 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
| 629 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
| 630 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
| 631 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
| 632 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
| 633 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
| 634 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
| 635 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
| 636 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
| 637 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
|
|
|
| 2289 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 2290 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 2291 |
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
| 2292 |
+
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
| 2293 |
|
| 2294 |
float freq_base;
|
| 2295 |
float freq_scale;
|
|
|
|
| 2306 |
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
| 2307 |
|
| 2308 |
const bool is_neox = mode & 2;
|
|
|
|
| 2309 |
|
| 2310 |
+
id<MTLComputePipelineState> pipeline = nil;
|
| 2311 |
|
| 2312 |
if (!is_neox) {
|
| 2313 |
+
switch (src0->type) {
|
| 2314 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
| 2315 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
| 2316 |
+
default: GGML_ASSERT(false);
|
| 2317 |
+
};
|
| 2318 |
+
} else {
|
| 2319 |
+
switch (src0->type) {
|
| 2320 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
| 2321 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
| 2322 |
+
default: GGML_ASSERT(false);
|
| 2323 |
+
};
|
| 2324 |
}
|
| 2325 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2326 |
[encoder setComputePipelineState:pipeline];
|
| 2327 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2328 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
|
| 2350 |
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
| 2351 |
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
| 2352 |
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
| 2353 |
+
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
|
| 2354 |
+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
| 2355 |
+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
| 2356 |
+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
| 2357 |
+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
| 2358 |
+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
| 2359 |
+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
|
|
|
| 2360 |
|
| 2361 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 2362 |
} break;
|
ggml-metal.metal
CHANGED
|
@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
| 1654 |
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
| 1655 |
static void rope_yarn(
|
| 1656 |
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
| 1657 |
-
thread float * cos_theta, thread float * sin_theta
|
| 1658 |
-
) {
|
| 1659 |
// Get n-d rotational scaling corrected for extrapolation
|
| 1660 |
float theta_interp = freq_scale * theta_extrap;
|
| 1661 |
float theta = theta_interp;
|
|
@@ -1672,19 +1671,20 @@ static void rope_yarn(
|
|
| 1672 |
|
| 1673 |
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
| 1674 |
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
| 1675 |
-
static float rope_yarn_corr_factor(int n_dims, int
|
| 1676 |
-
return n_dims * log(
|
| 1677 |
}
|
| 1678 |
|
| 1679 |
static void rope_yarn_corr_dims(
|
| 1680 |
-
int n_dims, int
|
| 1681 |
) {
|
| 1682 |
// start and end correction dims
|
| 1683 |
-
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims,
|
| 1684 |
-
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims,
|
| 1685 |
}
|
| 1686 |
|
| 1687 |
-
|
|
|
|
| 1688 |
device const void * src0,
|
| 1689 |
device const int32_t * src1,
|
| 1690 |
device const float * src2,
|
|
@@ -1707,8 +1707,7 @@ typedef void (rope_t)(
|
|
| 1707 |
constant uint64_t & nb3,
|
| 1708 |
constant int & n_past,
|
| 1709 |
constant int & n_dims,
|
| 1710 |
-
constant int &
|
| 1711 |
-
constant int & n_orig_ctx,
|
| 1712 |
constant float & freq_base,
|
| 1713 |
constant float & freq_scale,
|
| 1714 |
constant float & ext_factor,
|
|
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
|
|
| 1717 |
constant float & beta_slow,
|
| 1718 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 1719 |
uint3 tptg[[threads_per_threadgroup]],
|
| 1720 |
-
uint3 tgpig[[threadgroup_position_in_grid]])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1721 |
|
| 1722 |
template<typename T>
|
| 1723 |
-
kernel void
|
| 1724 |
device const void * src0,
|
| 1725 |
device const int32_t * src1,
|
| 1726 |
device const float * src2,
|
|
@@ -1743,8 +1784,7 @@ kernel void kernel_rope(
|
|
| 1743 |
constant uint64_t & nb3,
|
| 1744 |
constant int & n_past,
|
| 1745 |
constant int & n_dims,
|
| 1746 |
-
constant int &
|
| 1747 |
-
constant int & n_orig_ctx,
|
| 1748 |
constant float & freq_base,
|
| 1749 |
constant float & freq_scale,
|
| 1750 |
constant float & ext_factor,
|
|
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
|
|
| 1758 |
const int64_t i2 = tgpig[1];
|
| 1759 |
const int64_t i1 = tgpig[0];
|
| 1760 |
|
| 1761 |
-
const bool is_neox = mode & 2;
|
| 1762 |
-
|
| 1763 |
float corr_dims[2];
|
| 1764 |
-
rope_yarn_corr_dims(n_dims,
|
| 1765 |
|
| 1766 |
device const int32_t * pos = src1;
|
| 1767 |
|
| 1768 |
-
const
|
| 1769 |
-
|
| 1770 |
-
const float theta_base = (float)p;
|
| 1771 |
const float inv_ndims = -1.f/n_dims;
|
| 1772 |
|
| 1773 |
-
|
| 1774 |
-
|
| 1775 |
-
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
| 1776 |
-
|
| 1777 |
-
float cos_theta, sin_theta;
|
| 1778 |
-
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
| 1779 |
-
|
| 1780 |
-
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 1781 |
-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 1782 |
-
|
| 1783 |
-
const T x0 = src[0];
|
| 1784 |
-
const T x1 = src[1];
|
| 1785 |
|
| 1786 |
-
|
| 1787 |
-
|
| 1788 |
-
|
| 1789 |
-
} else {
|
| 1790 |
-
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
|
| 1791 |
-
if (ic < n_dims) {
|
| 1792 |
-
const int64_t i0 = ic/2;
|
| 1793 |
|
| 1794 |
-
|
| 1795 |
-
|
| 1796 |
-
const float theta = theta_base * pow(freq_base, inv_ndims*ic);
|
| 1797 |
|
| 1798 |
-
|
| 1799 |
-
rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
| 1800 |
|
| 1801 |
-
|
| 1802 |
-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 1803 |
|
| 1804 |
-
|
| 1805 |
-
|
| 1806 |
|
| 1807 |
-
|
| 1808 |
-
|
| 1809 |
-
} else {
|
| 1810 |
-
const int64_t i0 = ic;
|
| 1811 |
|
| 1812 |
-
|
| 1813 |
-
|
|
|
|
|
|
|
|
|
|
| 1814 |
|
| 1815 |
-
|
| 1816 |
-
|
| 1817 |
-
}
|
| 1818 |
}
|
| 1819 |
}
|
| 1820 |
}
|
| 1821 |
|
| 1822 |
-
|
| 1823 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1824 |
|
| 1825 |
typedef void (im2col_t)(
|
| 1826 |
device const float * x,
|
|
|
|
| 1654 |
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
| 1655 |
static void rope_yarn(
|
| 1656 |
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
| 1657 |
+
thread float * cos_theta, thread float * sin_theta) {
|
|
|
|
| 1658 |
// Get n-d rotational scaling corrected for extrapolation
|
| 1659 |
float theta_interp = freq_scale * theta_extrap;
|
| 1660 |
float theta = theta_interp;
|
|
|
|
| 1671 |
|
| 1672 |
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
| 1673 |
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
| 1674 |
+
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
| 1675 |
+
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
| 1676 |
}
|
| 1677 |
|
| 1678 |
static void rope_yarn_corr_dims(
|
| 1679 |
+
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
| 1680 |
) {
|
| 1681 |
// start and end correction dims
|
| 1682 |
+
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
|
| 1683 |
+
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
|
| 1684 |
}
|
| 1685 |
|
| 1686 |
+
template<typename T>
|
| 1687 |
+
kernel void kernel_rope_norm(
|
| 1688 |
device const void * src0,
|
| 1689 |
device const int32_t * src1,
|
| 1690 |
device const float * src2,
|
|
|
|
| 1707 |
constant uint64_t & nb3,
|
| 1708 |
constant int & n_past,
|
| 1709 |
constant int & n_dims,
|
| 1710 |
+
constant int & n_ctx_orig,
|
|
|
|
| 1711 |
constant float & freq_base,
|
| 1712 |
constant float & freq_scale,
|
| 1713 |
constant float & ext_factor,
|
|
|
|
| 1716 |
constant float & beta_slow,
|
| 1717 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 1718 |
uint3 tptg[[threads_per_threadgroup]],
|
| 1719 |
+
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
| 1720 |
+
const int64_t i3 = tgpig[2];
|
| 1721 |
+
const int64_t i2 = tgpig[1];
|
| 1722 |
+
const int64_t i1 = tgpig[0];
|
| 1723 |
+
|
| 1724 |
+
float corr_dims[2];
|
| 1725 |
+
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
| 1726 |
+
|
| 1727 |
+
device const int32_t * pos = src1;
|
| 1728 |
+
|
| 1729 |
+
const float theta_base = (float) pos[i2];
|
| 1730 |
+
const float inv_ndims = -1.f/n_dims;
|
| 1731 |
+
|
| 1732 |
+
float cos_theta;
|
| 1733 |
+
float sin_theta;
|
| 1734 |
+
|
| 1735 |
+
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
| 1736 |
+
if (i0 < n_dims) {
|
| 1737 |
+
const int64_t ic = i0/2;
|
| 1738 |
+
|
| 1739 |
+
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
| 1740 |
+
|
| 1741 |
+
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
| 1742 |
+
|
| 1743 |
+
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
| 1744 |
+
|
| 1745 |
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 1746 |
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 1747 |
+
|
| 1748 |
+
const float x0 = src[0];
|
| 1749 |
+
const float x1 = src[1];
|
| 1750 |
+
|
| 1751 |
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
| 1752 |
+
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
| 1753 |
+
} else {
|
| 1754 |
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 1755 |
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 1756 |
+
|
| 1757 |
+
dst_data[0] = src[0];
|
| 1758 |
+
dst_data[1] = src[1];
|
| 1759 |
+
}
|
| 1760 |
+
}
|
| 1761 |
+
}
|
| 1762 |
|
| 1763 |
template<typename T>
|
| 1764 |
+
kernel void kernel_rope_neox(
|
| 1765 |
device const void * src0,
|
| 1766 |
device const int32_t * src1,
|
| 1767 |
device const float * src2,
|
|
|
|
| 1784 |
constant uint64_t & nb3,
|
| 1785 |
constant int & n_past,
|
| 1786 |
constant int & n_dims,
|
| 1787 |
+
constant int & n_ctx_orig,
|
|
|
|
| 1788 |
constant float & freq_base,
|
| 1789 |
constant float & freq_scale,
|
| 1790 |
constant float & ext_factor,
|
|
|
|
| 1798 |
const int64_t i2 = tgpig[1];
|
| 1799 |
const int64_t i1 = tgpig[0];
|
| 1800 |
|
|
|
|
|
|
|
| 1801 |
float corr_dims[2];
|
| 1802 |
+
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
| 1803 |
|
| 1804 |
device const int32_t * pos = src1;
|
| 1805 |
|
| 1806 |
+
const float theta_base = (float) pos[i2];
|
|
|
|
|
|
|
| 1807 |
const float inv_ndims = -1.f/n_dims;
|
| 1808 |
|
| 1809 |
+
float cos_theta;
|
| 1810 |
+
float sin_theta;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1811 |
|
| 1812 |
+
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
| 1813 |
+
if (i0 < n_dims) {
|
| 1814 |
+
const int64_t ic = i0/2;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1815 |
|
| 1816 |
+
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
|
|
|
|
|
|
| 1817 |
|
| 1818 |
+
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
|
|
|
| 1819 |
|
| 1820 |
+
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
|
|
| 1821 |
|
| 1822 |
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
| 1823 |
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
| 1824 |
|
| 1825 |
+
const float x0 = src[0];
|
| 1826 |
+
const float x1 = src[n_dims/2];
|
|
|
|
|
|
|
| 1827 |
|
| 1828 |
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
| 1829 |
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
| 1830 |
+
} else {
|
| 1831 |
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 1832 |
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 1833 |
|
| 1834 |
+
dst_data[0] = src[0];
|
| 1835 |
+
dst_data[1] = src[1];
|
|
|
|
| 1836 |
}
|
| 1837 |
}
|
| 1838 |
}
|
| 1839 |
|
| 1840 |
+
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
| 1841 |
+
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
| 1842 |
+
|
| 1843 |
+
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
| 1844 |
+
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
| 1845 |
+
|
| 1846 |
+
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
| 1847 |
+
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
| 1848 |
|
| 1849 |
typedef void (im2col_t)(
|
| 1850 |
device const float * x,
|
ggml-sycl.cpp
CHANGED
|
@@ -8928,49 +8928,6 @@ static void rope_neox(
|
|
| 8928 |
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
| 8929 |
}
|
| 8930 |
|
| 8931 |
-
static void rope_glm_f32(
|
| 8932 |
-
const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
| 8933 |
-
int n_ctx
|
| 8934 |
-
, const sycl::nd_item<3> &item_ct1) {
|
| 8935 |
-
const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 8936 |
-
item_ct1.get_local_id(2);
|
| 8937 |
-
const int half_n_dims = ncols/4;
|
| 8938 |
-
|
| 8939 |
-
if (col >= half_n_dims) {
|
| 8940 |
-
return;
|
| 8941 |
-
}
|
| 8942 |
-
|
| 8943 |
-
const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
| 8944 |
-
item_ct1.get_local_id(1);
|
| 8945 |
-
const int i = row*ncols + col;
|
| 8946 |
-
const int i2 = row/p_delta_rows;
|
| 8947 |
-
|
| 8948 |
-
const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
|
| 8949 |
-
// FIXME: this is likely wrong
|
| 8950 |
-
const int p = pos != nullptr ? pos[i2] : 0;
|
| 8951 |
-
|
| 8952 |
-
const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
|
| 8953 |
-
const float sin_theta = sycl::sin((float)theta);
|
| 8954 |
-
const float cos_theta = sycl::cos((float)theta);
|
| 8955 |
-
|
| 8956 |
-
const float x0 = x[i + 0];
|
| 8957 |
-
const float x1 = x[i + half_n_dims];
|
| 8958 |
-
|
| 8959 |
-
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
| 8960 |
-
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
|
| 8961 |
-
|
| 8962 |
-
const float block_theta =
|
| 8963 |
-
((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
|
| 8964 |
-
const float sin_block_theta = sycl::sin((float)block_theta);
|
| 8965 |
-
const float cos_block_theta = sycl::cos((float)block_theta);
|
| 8966 |
-
|
| 8967 |
-
const float x2 = x[i + half_n_dims * 2];
|
| 8968 |
-
const float x3 = x[i + half_n_dims * 3];
|
| 8969 |
-
|
| 8970 |
-
dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
|
| 8971 |
-
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
|
| 8972 |
-
}
|
| 8973 |
-
|
| 8974 |
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
| 8975 |
const sycl::nd_item<3> &item_ct1) {
|
| 8976 |
const int row = item_ct1.get_group(1);
|
|
@@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
|
|
| 12520 |
}
|
| 12521 |
}
|
| 12522 |
|
| 12523 |
-
static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
|
| 12524 |
-
const int32_t *pos, float freq_scale,
|
| 12525 |
-
int p_delta_rows, float freq_base, int n_ctx,
|
| 12526 |
-
dpct::queue_ptr stream) {
|
| 12527 |
-
GGML_ASSERT(ncols % 4 == 0);
|
| 12528 |
-
const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
|
| 12529 |
-
const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
|
| 12530 |
-
const sycl::range<3> block_nums(1, nrows, num_blocks_x);
|
| 12531 |
-
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
| 12532 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 12533 |
-
rope_glm_f32(x, dst, ncols, pos, freq_scale,
|
| 12534 |
-
p_delta_rows, freq_base, n_ctx,
|
| 12535 |
-
item_ct1);
|
| 12536 |
-
});
|
| 12537 |
-
}
|
| 12538 |
-
|
| 12539 |
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
| 12540 |
const int nrows, dpct::queue_ptr stream) {
|
| 12541 |
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
|
@@ -14066,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
|
| 14066 |
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 14067 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 14068 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 14069 |
-
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
| 14070 |
-
const int
|
| 14071 |
|
| 14072 |
// RoPE alteration for extended context
|
| 14073 |
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
@@ -14087,7 +14028,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
|
| 14087 |
}
|
| 14088 |
|
| 14089 |
const bool is_neox = mode & 2;
|
| 14090 |
-
|
|
|
|
|
|
|
| 14091 |
|
| 14092 |
if (is_neox) {
|
| 14093 |
pos = (const int32_t *) src1_dd;
|
|
@@ -14100,13 +14043,10 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
|
| 14100 |
}
|
| 14101 |
|
| 14102 |
rope_corr_dims corr_dims;
|
| 14103 |
-
ggml_rope_yarn_corr_dims(n_dims,
|
| 14104 |
|
| 14105 |
// compute
|
| 14106 |
-
if (
|
| 14107 |
-
GGML_ASSERT(false);
|
| 14108 |
-
rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
|
| 14109 |
-
} else if (is_neox) {
|
| 14110 |
if (src0->type == GGML_TYPE_F32) {
|
| 14111 |
rope_neox_sycl(
|
| 14112 |
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
|
|
|
| 8928 |
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
| 8929 |
}
|
| 8930 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8931 |
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
| 8932 |
const sycl::nd_item<3> &item_ct1) {
|
| 8933 |
const int row = item_ct1.get_group(1);
|
|
|
|
| 12477 |
}
|
| 12478 |
}
|
| 12479 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12480 |
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
| 12481 |
const int nrows, dpct::queue_ptr stream) {
|
| 12482 |
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
|
|
|
| 14007 |
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 14008 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 14009 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 14010 |
+
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
| 14011 |
+
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
| 14012 |
|
| 14013 |
// RoPE alteration for extended context
|
| 14014 |
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
|
|
| 14028 |
}
|
| 14029 |
|
| 14030 |
const bool is_neox = mode & 2;
|
| 14031 |
+
|
| 14032 |
+
#pragma message("TODO: update rope NORM mode to match NEOX mode")
|
| 14033 |
+
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
|
| 14034 |
|
| 14035 |
if (is_neox) {
|
| 14036 |
pos = (const int32_t *) src1_dd;
|
|
|
|
| 14043 |
}
|
| 14044 |
|
| 14045 |
rope_corr_dims corr_dims;
|
| 14046 |
+
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
| 14047 |
|
| 14048 |
// compute
|
| 14049 |
+
if (is_neox) {
|
|
|
|
|
|
|
|
|
|
| 14050 |
if (src0->type == GGML_TYPE_F32) {
|
| 14051 |
rope_neox_sycl(
|
| 14052 |
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
ggml-vulkan.cpp
CHANGED
|
@@ -3898,11 +3898,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
| 3898 |
{
|
| 3899 |
const int mode = ((const int32_t *) dst->op_params)[2];
|
| 3900 |
const bool is_neox = mode & 2;
|
| 3901 |
-
const bool is_glm = mode & 4;
|
| 3902 |
-
|
| 3903 |
-
if (is_glm) {
|
| 3904 |
-
return nullptr;
|
| 3905 |
-
}
|
| 3906 |
|
| 3907 |
if (is_neox) {
|
| 3908 |
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
@@ -4401,7 +4396,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
|
|
| 4401 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 4402 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 4403 |
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
| 4404 |
-
const int
|
| 4405 |
const float freq_base = ((float *) dst->op_params)[5];
|
| 4406 |
const float freq_scale = ((float *) dst->op_params)[6];
|
| 4407 |
const float ext_factor = ((float *) dst->op_params)[7];
|
|
@@ -4410,12 +4405,12 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
|
|
| 4410 |
const float beta_slow = ((float *) dst->op_params)[10];
|
| 4411 |
|
| 4412 |
const bool is_neox = mode & 2;
|
| 4413 |
-
const bool is_glm = mode & 4;
|
| 4414 |
|
| 4415 |
-
|
|
|
|
| 4416 |
|
| 4417 |
float corr_dims[2];
|
| 4418 |
-
ggml_rope_yarn_corr_dims(n_dims,
|
| 4419 |
|
| 4420 |
if (is_neox) {
|
| 4421 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
@@ -6485,9 +6480,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
|
|
| 6485 |
case GGML_OP_ROPE:
|
| 6486 |
{
|
| 6487 |
const int mode = ((const int32_t *) op->op_params)[2];
|
| 6488 |
-
const bool is_glm = mode & 4;
|
| 6489 |
|
| 6490 |
-
return
|
| 6491 |
} break;
|
| 6492 |
case GGML_OP_NONE:
|
| 6493 |
case GGML_OP_RESHAPE:
|
|
@@ -6992,15 +6986,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
|
|
| 6992 |
} else if (tensor->op == GGML_OP_ROPE) {
|
| 6993 |
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
| 6994 |
const int mode = ((int32_t *) tensor->op_params)[2];
|
| 6995 |
-
const int
|
| 6996 |
-
const int
|
| 6997 |
float freq_base = ((float *) tensor->op_params)[5];
|
| 6998 |
float freq_scale = ((float *) tensor->op_params)[6];
|
| 6999 |
float ext_factor = ((float *) tensor->op_params)[7];
|
| 7000 |
float attn_factor = ((float *) tensor->op_params)[8];
|
| 7001 |
float beta_fast = ((float *) tensor->op_params)[9];
|
| 7002 |
float beta_slow = ((float *) tensor->op_params)[10];
|
| 7003 |
-
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode,
|
| 7004 |
} else if (tensor->op == GGML_OP_UNARY) {
|
| 7005 |
switch (ggml_get_unary_op(tensor)) {
|
| 7006 |
case GGML_UNARY_OP_SILU:
|
|
|
|
| 3898 |
{
|
| 3899 |
const int mode = ((const int32_t *) dst->op_params)[2];
|
| 3900 |
const bool is_neox = mode & 2;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3901 |
|
| 3902 |
if (is_neox) {
|
| 3903 |
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
|
|
| 4396 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 4397 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 4398 |
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
| 4399 |
+
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
| 4400 |
const float freq_base = ((float *) dst->op_params)[5];
|
| 4401 |
const float freq_scale = ((float *) dst->op_params)[6];
|
| 4402 |
const float ext_factor = ((float *) dst->op_params)[7];
|
|
|
|
| 4405 |
const float beta_slow = ((float *) dst->op_params)[10];
|
| 4406 |
|
| 4407 |
const bool is_neox = mode & 2;
|
|
|
|
| 4408 |
|
| 4409 |
+
#pragma message("TODO: update rope NORM mode to match NEOX mode")
|
| 4410 |
+
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
|
| 4411 |
|
| 4412 |
float corr_dims[2];
|
| 4413 |
+
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
| 4414 |
|
| 4415 |
if (is_neox) {
|
| 4416 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
|
|
| 6480 |
case GGML_OP_ROPE:
|
| 6481 |
{
|
| 6482 |
const int mode = ((const int32_t *) op->op_params)[2];
|
|
|
|
| 6483 |
|
| 6484 |
+
return true;
|
| 6485 |
} break;
|
| 6486 |
case GGML_OP_NONE:
|
| 6487 |
case GGML_OP_RESHAPE:
|
|
|
|
| 6986 |
} else if (tensor->op == GGML_OP_ROPE) {
|
| 6987 |
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
| 6988 |
const int mode = ((int32_t *) tensor->op_params)[2];
|
| 6989 |
+
//const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
|
| 6990 |
+
const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4];
|
| 6991 |
float freq_base = ((float *) tensor->op_params)[5];
|
| 6992 |
float freq_scale = ((float *) tensor->op_params)[6];
|
| 6993 |
float ext_factor = ((float *) tensor->op_params)[7];
|
| 6994 |
float attn_factor = ((float *) tensor->op_params)[8];
|
| 6995 |
float beta_fast = ((float *) tensor->op_params)[9];
|
| 6996 |
float beta_slow = ((float *) tensor->op_params)[10];
|
| 6997 |
+
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
| 6998 |
} else if (tensor->op == GGML_OP_UNARY) {
|
| 6999 |
switch (ggml_get_unary_op(tensor)) {
|
| 7000 |
case GGML_UNARY_OP_SILU:
|
ggml.c
CHANGED
|
@@ -6250,16 +6250,13 @@ static struct ggml_tensor * ggml_rope_impl(
|
|
| 6250 |
struct ggml_tensor * c,
|
| 6251 |
int n_dims,
|
| 6252 |
int mode,
|
| 6253 |
-
int
|
| 6254 |
-
int n_orig_ctx,
|
| 6255 |
float freq_base,
|
| 6256 |
float freq_scale,
|
| 6257 |
float ext_factor,
|
| 6258 |
float attn_factor,
|
| 6259 |
float beta_fast,
|
| 6260 |
float beta_slow,
|
| 6261 |
-
float xpos_base,
|
| 6262 |
-
bool xpos_down,
|
| 6263 |
bool inplace) {
|
| 6264 |
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
|
| 6265 |
|
|
@@ -6280,15 +6277,13 @@ static struct ggml_tensor * ggml_rope_impl(
|
|
| 6280 |
|
| 6281 |
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
| 6282 |
|
| 6283 |
-
int32_t params[
|
| 6284 |
memcpy(params + 5, &freq_base, sizeof(float));
|
| 6285 |
memcpy(params + 6, &freq_scale, sizeof(float));
|
| 6286 |
memcpy(params + 7, &ext_factor, sizeof(float));
|
| 6287 |
memcpy(params + 8, &attn_factor, sizeof(float));
|
| 6288 |
memcpy(params + 9, &beta_fast, sizeof(float));
|
| 6289 |
memcpy(params + 10, &beta_slow, sizeof(float));
|
| 6290 |
-
memcpy(params + 11, &xpos_base, sizeof(float));
|
| 6291 |
-
memcpy(params + 12, &xpos_down, sizeof(bool));
|
| 6292 |
ggml_set_op_params(result, params, sizeof(params));
|
| 6293 |
|
| 6294 |
result->op = GGML_OP_ROPE;
|
|
@@ -6305,10 +6300,9 @@ struct ggml_tensor * ggml_rope(
|
|
| 6305 |
struct ggml_tensor * a,
|
| 6306 |
struct ggml_tensor * b,
|
| 6307 |
int n_dims,
|
| 6308 |
-
int mode
|
| 6309 |
-
int n_ctx) {
|
| 6310 |
return ggml_rope_impl(
|
| 6311 |
-
ctx, a, b, NULL, n_dims, mode,
|
| 6312 |
);
|
| 6313 |
}
|
| 6314 |
|
|
@@ -6317,10 +6311,9 @@ struct ggml_tensor * ggml_rope_inplace(
|
|
| 6317 |
struct ggml_tensor * a,
|
| 6318 |
struct ggml_tensor * b,
|
| 6319 |
int n_dims,
|
| 6320 |
-
int mode
|
| 6321 |
-
int n_ctx) {
|
| 6322 |
return ggml_rope_impl(
|
| 6323 |
-
ctx, a, b, NULL, n_dims, mode,
|
| 6324 |
);
|
| 6325 |
}
|
| 6326 |
|
|
@@ -6331,8 +6324,7 @@ struct ggml_tensor * ggml_rope_ext(
|
|
| 6331 |
struct ggml_tensor * c,
|
| 6332 |
int n_dims,
|
| 6333 |
int mode,
|
| 6334 |
-
int
|
| 6335 |
-
int n_orig_ctx,
|
| 6336 |
float freq_base,
|
| 6337 |
float freq_scale,
|
| 6338 |
float ext_factor,
|
|
@@ -6340,8 +6332,8 @@ struct ggml_tensor * ggml_rope_ext(
|
|
| 6340 |
float beta_fast,
|
| 6341 |
float beta_slow) {
|
| 6342 |
return ggml_rope_impl(
|
| 6343 |
-
ctx, a, b, c, n_dims, mode,
|
| 6344 |
-
ext_factor, attn_factor, beta_fast, beta_slow,
|
| 6345 |
);
|
| 6346 |
}
|
| 6347 |
|
|
@@ -6352,8 +6344,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
|
|
| 6352 |
struct ggml_tensor * c,
|
| 6353 |
int n_dims,
|
| 6354 |
int mode,
|
| 6355 |
-
int
|
| 6356 |
-
int n_orig_ctx,
|
| 6357 |
float freq_base,
|
| 6358 |
float freq_scale,
|
| 6359 |
float ext_factor,
|
|
@@ -6361,8 +6352,8 @@ struct ggml_tensor * ggml_rope_ext_inplace(
|
|
| 6361 |
float beta_fast,
|
| 6362 |
float beta_slow) {
|
| 6363 |
return ggml_rope_impl(
|
| 6364 |
-
ctx, a, b, c, n_dims, mode,
|
| 6365 |
-
ext_factor, attn_factor, beta_fast, beta_slow,
|
| 6366 |
);
|
| 6367 |
}
|
| 6368 |
|
|
@@ -6372,8 +6363,7 @@ struct ggml_tensor * ggml_rope_custom(
|
|
| 6372 |
struct ggml_tensor * b,
|
| 6373 |
int n_dims,
|
| 6374 |
int mode,
|
| 6375 |
-
int
|
| 6376 |
-
int n_orig_ctx,
|
| 6377 |
float freq_base,
|
| 6378 |
float freq_scale,
|
| 6379 |
float ext_factor,
|
|
@@ -6381,8 +6371,8 @@ struct ggml_tensor * ggml_rope_custom(
|
|
| 6381 |
float beta_fast,
|
| 6382 |
float beta_slow) {
|
| 6383 |
return ggml_rope_impl(
|
| 6384 |
-
ctx, a, b, NULL, n_dims, mode,
|
| 6385 |
-
ext_factor, attn_factor, beta_fast, beta_slow,
|
| 6386 |
);
|
| 6387 |
}
|
| 6388 |
|
|
@@ -6392,8 +6382,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
|
| 6392 |
struct ggml_tensor * b,
|
| 6393 |
int n_dims,
|
| 6394 |
int mode,
|
| 6395 |
-
int
|
| 6396 |
-
int n_orig_ctx,
|
| 6397 |
float freq_base,
|
| 6398 |
float freq_scale,
|
| 6399 |
float ext_factor,
|
|
@@ -6401,21 +6390,11 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
|
| 6401 |
float beta_fast,
|
| 6402 |
float beta_slow) {
|
| 6403 |
return ggml_rope_impl(
|
| 6404 |
-
ctx, a, b, NULL, n_dims, mode,
|
| 6405 |
-
ext_factor, attn_factor, beta_fast, beta_slow,
|
| 6406 |
);
|
| 6407 |
}
|
| 6408 |
|
| 6409 |
-
struct ggml_tensor * ggml_rope_xpos_inplace(
|
| 6410 |
-
struct ggml_context * ctx,
|
| 6411 |
-
struct ggml_tensor * a,
|
| 6412 |
-
struct ggml_tensor * b,
|
| 6413 |
-
int n_dims,
|
| 6414 |
-
float base,
|
| 6415 |
-
bool down) {
|
| 6416 |
-
return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
|
| 6417 |
-
}
|
| 6418 |
-
|
| 6419 |
// ggml_rope_back
|
| 6420 |
|
| 6421 |
struct ggml_tensor * ggml_rope_back(
|
|
@@ -6425,16 +6404,13 @@ struct ggml_tensor * ggml_rope_back(
|
|
| 6425 |
struct ggml_tensor * c,
|
| 6426 |
int n_dims,
|
| 6427 |
int mode,
|
| 6428 |
-
int
|
| 6429 |
-
int n_orig_ctx,
|
| 6430 |
float freq_base,
|
| 6431 |
float freq_scale,
|
| 6432 |
float ext_factor,
|
| 6433 |
float attn_factor,
|
| 6434 |
float beta_fast,
|
| 6435 |
-
float beta_slow
|
| 6436 |
-
float xpos_base,
|
| 6437 |
-
bool xpos_down) {
|
| 6438 |
GGML_ASSERT(ggml_is_vector(b));
|
| 6439 |
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
| 6440 |
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
|
@@ -6450,15 +6426,13 @@ struct ggml_tensor * ggml_rope_back(
|
|
| 6450 |
|
| 6451 |
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
| 6452 |
|
| 6453 |
-
int32_t params[
|
| 6454 |
memcpy(params + 5, &freq_base, sizeof(float));
|
| 6455 |
memcpy(params + 6, &freq_scale, sizeof(float));
|
| 6456 |
memcpy(params + 7, &ext_factor, sizeof(float));
|
| 6457 |
memcpy(params + 8, &attn_factor, sizeof(float));
|
| 6458 |
memcpy(params + 9, &beta_fast, sizeof(float));
|
| 6459 |
memcpy(params + 10, &beta_slow, sizeof(float));
|
| 6460 |
-
memcpy(params + 11, &xpos_base, sizeof(float));
|
| 6461 |
-
memcpy(params + 12, &xpos_down, sizeof(bool));
|
| 6462 |
ggml_set_op_params(result, params, sizeof(params));
|
| 6463 |
|
| 6464 |
result->op = GGML_OP_ROPE_BACK;
|
|
@@ -14227,8 +14201,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
| 14227 |
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
| 14228 |
static void rope_yarn(
|
| 14229 |
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
| 14230 |
-
float * cos_theta, float * sin_theta
|
| 14231 |
-
) {
|
| 14232 |
// Get n-d rotational scaling corrected for extrapolation
|
| 14233 |
float theta_interp = freq_scale * theta_extrap;
|
| 14234 |
float theta = theta_interp;
|
|
@@ -14245,18 +14218,19 @@ static void rope_yarn(
|
|
| 14245 |
|
| 14246 |
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
| 14247 |
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
| 14248 |
-
static float ggml_rope_yarn_corr_dim(int n_dims, int
|
| 14249 |
-
return n_dims * logf(
|
| 14250 |
}
|
| 14251 |
|
| 14252 |
static void ggml_rope_cache_init(
|
| 14253 |
-
float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
| 14254 |
-
float * cache, float sin_sign, float theta_scale
|
| 14255 |
-
|
| 14256 |
float theta = theta_base;
|
| 14257 |
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
|
|
|
| 14258 |
rope_yarn(
|
| 14259 |
-
theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
| 14260 |
);
|
| 14261 |
cache[i0 + 1] *= sin_sign;
|
| 14262 |
|
|
@@ -14265,11 +14239,11 @@ static void ggml_rope_cache_init(
|
|
| 14265 |
}
|
| 14266 |
|
| 14267 |
GGML_CALL void ggml_rope_yarn_corr_dims(
|
| 14268 |
-
int n_dims, int
|
| 14269 |
) {
|
| 14270 |
// start and end correction dims
|
| 14271 |
-
float start = floorf(ggml_rope_yarn_corr_dim(n_dims,
|
| 14272 |
-
float end = ceilf(ggml_rope_yarn_corr_dim(n_dims,
|
| 14273 |
dims[0] = MAX(0, start);
|
| 14274 |
dims[1] = MIN(n_dims - 1, end);
|
| 14275 |
}
|
|
@@ -14289,15 +14263,11 @@ static void ggml_compute_forward_rope_f32(
|
|
| 14289 |
|
| 14290 |
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
| 14291 |
|
| 14292 |
-
// these two only relevant for xPos RoPE:
|
| 14293 |
-
float xpos_base;
|
| 14294 |
-
bool xpos_down;
|
| 14295 |
-
|
| 14296 |
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 14297 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 14298 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 14299 |
-
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
| 14300 |
-
const int
|
| 14301 |
|
| 14302 |
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
| 14303 |
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
@@ -14305,8 +14275,6 @@ static void ggml_compute_forward_rope_f32(
|
|
| 14305 |
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
| 14306 |
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
| 14307 |
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
| 14308 |
-
memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
|
| 14309 |
-
memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
|
| 14310 |
|
| 14311 |
GGML_TENSOR_UNARY_OP_LOCALS
|
| 14312 |
|
|
@@ -14336,20 +14304,15 @@ static void ggml_compute_forward_rope_f32(
|
|
| 14336 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 14337 |
|
| 14338 |
float corr_dims[2];
|
| 14339 |
-
ggml_rope_yarn_corr_dims(n_dims,
|
| 14340 |
|
| 14341 |
const bool is_neox = mode & 2;
|
| 14342 |
-
const bool is_glm = mode & 4;
|
| 14343 |
|
| 14344 |
const float * freq_factors = NULL;
|
| 14345 |
-
if (
|
| 14346 |
-
|
| 14347 |
-
|
| 14348 |
-
|
| 14349 |
-
freq_factors = (const float *) src2->data;
|
| 14350 |
-
}
|
| 14351 |
-
} else {
|
| 14352 |
-
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
| 14353 |
}
|
| 14354 |
|
| 14355 |
// backward process uses inverse rotation by cos and sin.
|
|
@@ -14364,94 +14327,50 @@ static void ggml_compute_forward_rope_f32(
|
|
| 14364 |
const int64_t p = pos[i2];
|
| 14365 |
|
| 14366 |
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
| 14367 |
-
|
| 14368 |
-
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
| 14369 |
-
}
|
| 14370 |
|
| 14371 |
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
| 14372 |
if (ir++ < ir0) continue;
|
| 14373 |
if (ir > ir1) break;
|
| 14374 |
|
| 14375 |
-
|
| 14376 |
-
|
| 14377 |
-
if (is_glm) {
|
| 14378 |
-
theta_base = MIN(p, n_ctx - 2);
|
| 14379 |
-
float block_theta = MAX(p - (n_ctx - 2), 0);
|
| 14380 |
-
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
| 14381 |
-
const float cos_theta = cosf(theta_base);
|
| 14382 |
-
const float sin_theta = sinf(theta_base) * sin_sign;
|
| 14383 |
-
const float cos_block_theta = cosf(block_theta);
|
| 14384 |
-
const float sin_block_theta = sinf(block_theta) * sin_sign;
|
| 14385 |
-
|
| 14386 |
-
theta_base *= theta_scale;
|
| 14387 |
-
block_theta *= theta_scale;
|
| 14388 |
-
|
| 14389 |
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 14390 |
-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 14391 |
-
|
| 14392 |
-
const float x0 = src[0];
|
| 14393 |
-
const float x1 = src[n_dims/2];
|
| 14394 |
-
const float x2 = src[n_dims];
|
| 14395 |
-
const float x3 = src[n_dims/2*3];
|
| 14396 |
-
|
| 14397 |
-
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
| 14398 |
-
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
| 14399 |
-
dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
|
| 14400 |
-
dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
|
| 14401 |
-
}
|
| 14402 |
-
} else if (!is_neox) {
|
| 14403 |
-
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
| 14404 |
const float cos_theta = cache[i0 + 0];
|
| 14405 |
const float sin_theta = cache[i0 + 1];
|
| 14406 |
|
| 14407 |
-
// zeta scaling for xPos only:
|
| 14408 |
-
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
| 14409 |
-
if (xpos_down) zeta = 1.0f / zeta;
|
| 14410 |
-
|
| 14411 |
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 14412 |
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 14413 |
|
| 14414 |
const float x0 = src[0];
|
| 14415 |
const float x1 = src[1];
|
| 14416 |
|
| 14417 |
-
dst_data[0] = x0*cos_theta
|
| 14418 |
-
dst_data[1] = x0*sin_theta
|
| 14419 |
}
|
| 14420 |
} else {
|
| 14421 |
-
|
| 14422 |
-
|
| 14423 |
-
if (ic < n_dims) {
|
| 14424 |
-
const int64_t i0 = ic/2;
|
| 14425 |
-
|
| 14426 |
-
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
|
| 14427 |
-
|
| 14428 |
-
float cos_theta, sin_theta;
|
| 14429 |
-
rope_yarn(
|
| 14430 |
-
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
|
| 14431 |
-
&cos_theta, &sin_theta
|
| 14432 |
-
);
|
| 14433 |
|
| 14434 |
-
|
| 14435 |
-
|
| 14436 |
|
| 14437 |
-
|
| 14438 |
-
|
| 14439 |
|
| 14440 |
-
|
| 14441 |
-
|
| 14442 |
|
| 14443 |
-
|
| 14444 |
-
|
| 14445 |
-
|
| 14446 |
-
|
| 14447 |
|
| 14448 |
-
|
| 14449 |
-
|
|
|
|
| 14450 |
|
| 14451 |
-
|
| 14452 |
-
|
| 14453 |
-
}
|
| 14454 |
-
}
|
| 14455 |
}
|
| 14456 |
}
|
| 14457 |
}
|
|
@@ -14477,8 +14396,8 @@ static void ggml_compute_forward_rope_f16(
|
|
| 14477 |
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 14478 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 14479 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 14480 |
-
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
| 14481 |
-
const int
|
| 14482 |
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
| 14483 |
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
| 14484 |
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
@@ -14514,20 +14433,15 @@ static void ggml_compute_forward_rope_f16(
|
|
| 14514 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 14515 |
|
| 14516 |
float corr_dims[2];
|
| 14517 |
-
ggml_rope_yarn_corr_dims(n_dims,
|
| 14518 |
|
| 14519 |
const bool is_neox = mode & 2;
|
| 14520 |
-
const bool is_glm = mode & 4;
|
| 14521 |
|
| 14522 |
const float * freq_factors = NULL;
|
| 14523 |
-
if (
|
| 14524 |
-
|
| 14525 |
-
|
| 14526 |
-
|
| 14527 |
-
freq_factors = (const float *) src2->data;
|
| 14528 |
-
}
|
| 14529 |
-
} else {
|
| 14530 |
-
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
| 14531 |
}
|
| 14532 |
|
| 14533 |
// backward process uses inverse rotation by cos and sin.
|
|
@@ -14542,43 +14456,14 @@ static void ggml_compute_forward_rope_f16(
|
|
| 14542 |
const int64_t p = pos[i2];
|
| 14543 |
|
| 14544 |
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
| 14545 |
-
|
| 14546 |
-
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
| 14547 |
-
}
|
| 14548 |
|
| 14549 |
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
| 14550 |
if (ir++ < ir0) continue;
|
| 14551 |
if (ir > ir1) break;
|
| 14552 |
|
| 14553 |
-
|
| 14554 |
-
|
| 14555 |
-
if (is_glm) {
|
| 14556 |
-
theta_base = MIN(p, n_ctx - 2);
|
| 14557 |
-
float block_theta = MAX(p - (n_ctx - 2), 0);
|
| 14558 |
-
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
| 14559 |
-
const float cos_theta = cosf(theta_base);
|
| 14560 |
-
const float sin_theta = sinf(theta_base) * sin_sign;
|
| 14561 |
-
const float cos_block_theta = cosf(block_theta);
|
| 14562 |
-
const float sin_block_theta = sinf(block_theta) * sin_sign;
|
| 14563 |
-
|
| 14564 |
-
theta_base *= theta_scale;
|
| 14565 |
-
block_theta *= theta_scale;
|
| 14566 |
-
|
| 14567 |
-
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 14568 |
-
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 14569 |
-
|
| 14570 |
-
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
| 14571 |
-
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
| 14572 |
-
const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
|
| 14573 |
-
const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
|
| 14574 |
-
|
| 14575 |
-
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
| 14576 |
-
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
| 14577 |
-
dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
|
| 14578 |
-
dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
|
| 14579 |
-
}
|
| 14580 |
-
} else if (!is_neox) {
|
| 14581 |
-
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
| 14582 |
const float cos_theta = cache[i0 + 0];
|
| 14583 |
const float sin_theta = cache[i0 + 1];
|
| 14584 |
|
|
@@ -14592,40 +14477,29 @@ static void ggml_compute_forward_rope_f16(
|
|
| 14592 |
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
| 14593 |
}
|
| 14594 |
} else {
|
| 14595 |
-
|
| 14596 |
-
|
| 14597 |
-
if (ic < n_dims) {
|
| 14598 |
-
const int64_t i0 = ic/2;
|
| 14599 |
-
|
| 14600 |
-
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
|
| 14601 |
|
| 14602 |
-
|
| 14603 |
-
|
| 14604 |
-
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
|
| 14605 |
-
&cos_theta, &sin_theta
|
| 14606 |
-
);
|
| 14607 |
-
|
| 14608 |
-
sin_theta *= sin_sign;
|
| 14609 |
-
theta_base *= theta_scale;
|
| 14610 |
|
| 14611 |
-
|
| 14612 |
-
|
| 14613 |
|
| 14614 |
-
|
| 14615 |
-
|
| 14616 |
|
| 14617 |
-
|
| 14618 |
-
|
| 14619 |
-
|
| 14620 |
-
|
| 14621 |
|
| 14622 |
-
|
| 14623 |
-
|
|
|
|
| 14624 |
|
| 14625 |
-
|
| 14626 |
-
|
| 14627 |
-
}
|
| 14628 |
-
}
|
| 14629 |
}
|
| 14630 |
}
|
| 14631 |
}
|
|
@@ -18327,9 +18201,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
| 18327 |
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
| 18328 |
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
| 18329 |
const int mode = ((int32_t *) tensor->op_params)[2];
|
| 18330 |
-
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
| 18331 |
-
const int
|
| 18332 |
-
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
| 18333 |
|
| 18334 |
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
| 18335 |
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
|
|
@@ -18337,8 +18211,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
| 18337 |
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
|
| 18338 |
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
|
| 18339 |
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
|
| 18340 |
-
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
|
| 18341 |
-
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
|
| 18342 |
|
| 18343 |
src0->grad = ggml_add_or_set(ctx,
|
| 18344 |
src0->grad,
|
|
@@ -18348,16 +18220,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
| 18348 |
src2,
|
| 18349 |
n_dims,
|
| 18350 |
mode,
|
| 18351 |
-
|
| 18352 |
-
n_orig_ctx,
|
| 18353 |
freq_base,
|
| 18354 |
freq_scale,
|
| 18355 |
ext_factor,
|
| 18356 |
attn_factor,
|
| 18357 |
beta_fast,
|
| 18358 |
-
beta_slow,
|
| 18359 |
-
xpos_base,
|
| 18360 |
-
xpos_down),
|
| 18361 |
zero_table);
|
| 18362 |
}
|
| 18363 |
} break;
|
|
@@ -18367,9 +18236,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
| 18367 |
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
| 18368 |
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
| 18369 |
const int mode = ((int32_t *) tensor->op_params)[2];
|
| 18370 |
-
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
| 18371 |
-
const int
|
| 18372 |
-
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
| 18373 |
|
| 18374 |
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
| 18375 |
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
|
|
@@ -18377,8 +18246,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
| 18377 |
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
|
| 18378 |
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
|
| 18379 |
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
|
| 18380 |
-
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
|
| 18381 |
-
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
|
| 18382 |
|
| 18383 |
src0->grad = ggml_add_or_set(ctx,
|
| 18384 |
src0->grad,
|
|
@@ -18388,16 +18255,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
| 18388 |
src2,
|
| 18389 |
n_dims,
|
| 18390 |
mode,
|
| 18391 |
-
|
| 18392 |
-
n_orig_ctx,
|
| 18393 |
freq_base,
|
| 18394 |
freq_scale,
|
| 18395 |
ext_factor,
|
| 18396 |
attn_factor,
|
| 18397 |
beta_fast,
|
| 18398 |
beta_slow,
|
| 18399 |
-
xpos_base,
|
| 18400 |
-
xpos_down,
|
| 18401 |
false),
|
| 18402 |
zero_table);
|
| 18403 |
}
|
|
|
|
| 6250 |
struct ggml_tensor * c,
|
| 6251 |
int n_dims,
|
| 6252 |
int mode,
|
| 6253 |
+
int n_ctx_orig,
|
|
|
|
| 6254 |
float freq_base,
|
| 6255 |
float freq_scale,
|
| 6256 |
float ext_factor,
|
| 6257 |
float attn_factor,
|
| 6258 |
float beta_fast,
|
| 6259 |
float beta_slow,
|
|
|
|
|
|
|
| 6260 |
bool inplace) {
|
| 6261 |
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
|
| 6262 |
|
|
|
|
| 6277 |
|
| 6278 |
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
| 6279 |
|
| 6280 |
+
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
| 6281 |
memcpy(params + 5, &freq_base, sizeof(float));
|
| 6282 |
memcpy(params + 6, &freq_scale, sizeof(float));
|
| 6283 |
memcpy(params + 7, &ext_factor, sizeof(float));
|
| 6284 |
memcpy(params + 8, &attn_factor, sizeof(float));
|
| 6285 |
memcpy(params + 9, &beta_fast, sizeof(float));
|
| 6286 |
memcpy(params + 10, &beta_slow, sizeof(float));
|
|
|
|
|
|
|
| 6287 |
ggml_set_op_params(result, params, sizeof(params));
|
| 6288 |
|
| 6289 |
result->op = GGML_OP_ROPE;
|
|
|
|
| 6300 |
struct ggml_tensor * a,
|
| 6301 |
struct ggml_tensor * b,
|
| 6302 |
int n_dims,
|
| 6303 |
+
int mode) {
|
|
|
|
| 6304 |
return ggml_rope_impl(
|
| 6305 |
+
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
|
| 6306 |
);
|
| 6307 |
}
|
| 6308 |
|
|
|
|
| 6311 |
struct ggml_tensor * a,
|
| 6312 |
struct ggml_tensor * b,
|
| 6313 |
int n_dims,
|
| 6314 |
+
int mode) {
|
|
|
|
| 6315 |
return ggml_rope_impl(
|
| 6316 |
+
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
|
| 6317 |
);
|
| 6318 |
}
|
| 6319 |
|
|
|
|
| 6324 |
struct ggml_tensor * c,
|
| 6325 |
int n_dims,
|
| 6326 |
int mode,
|
| 6327 |
+
int n_ctx_orig,
|
|
|
|
| 6328 |
float freq_base,
|
| 6329 |
float freq_scale,
|
| 6330 |
float ext_factor,
|
|
|
|
| 6332 |
float beta_fast,
|
| 6333 |
float beta_slow) {
|
| 6334 |
return ggml_rope_impl(
|
| 6335 |
+
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
| 6336 |
+
ext_factor, attn_factor, beta_fast, beta_slow, false
|
| 6337 |
);
|
| 6338 |
}
|
| 6339 |
|
|
|
|
| 6344 |
struct ggml_tensor * c,
|
| 6345 |
int n_dims,
|
| 6346 |
int mode,
|
| 6347 |
+
int n_ctx_orig,
|
|
|
|
| 6348 |
float freq_base,
|
| 6349 |
float freq_scale,
|
| 6350 |
float ext_factor,
|
|
|
|
| 6352 |
float beta_fast,
|
| 6353 |
float beta_slow) {
|
| 6354 |
return ggml_rope_impl(
|
| 6355 |
+
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
| 6356 |
+
ext_factor, attn_factor, beta_fast, beta_slow, true
|
| 6357 |
);
|
| 6358 |
}
|
| 6359 |
|
|
|
|
| 6363 |
struct ggml_tensor * b,
|
| 6364 |
int n_dims,
|
| 6365 |
int mode,
|
| 6366 |
+
int n_ctx_orig,
|
|
|
|
| 6367 |
float freq_base,
|
| 6368 |
float freq_scale,
|
| 6369 |
float ext_factor,
|
|
|
|
| 6371 |
float beta_fast,
|
| 6372 |
float beta_slow) {
|
| 6373 |
return ggml_rope_impl(
|
| 6374 |
+
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
| 6375 |
+
ext_factor, attn_factor, beta_fast, beta_slow, false
|
| 6376 |
);
|
| 6377 |
}
|
| 6378 |
|
|
|
|
| 6382 |
struct ggml_tensor * b,
|
| 6383 |
int n_dims,
|
| 6384 |
int mode,
|
| 6385 |
+
int n_ctx_orig,
|
|
|
|
| 6386 |
float freq_base,
|
| 6387 |
float freq_scale,
|
| 6388 |
float ext_factor,
|
|
|
|
| 6390 |
float beta_fast,
|
| 6391 |
float beta_slow) {
|
| 6392 |
return ggml_rope_impl(
|
| 6393 |
+
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
| 6394 |
+
ext_factor, attn_factor, beta_fast, beta_slow, true
|
| 6395 |
);
|
| 6396 |
}
|
| 6397 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6398 |
// ggml_rope_back
|
| 6399 |
|
| 6400 |
struct ggml_tensor * ggml_rope_back(
|
|
|
|
| 6404 |
struct ggml_tensor * c,
|
| 6405 |
int n_dims,
|
| 6406 |
int mode,
|
| 6407 |
+
int n_ctx_orig,
|
|
|
|
| 6408 |
float freq_base,
|
| 6409 |
float freq_scale,
|
| 6410 |
float ext_factor,
|
| 6411 |
float attn_factor,
|
| 6412 |
float beta_fast,
|
| 6413 |
+
float beta_slow) {
|
|
|
|
|
|
|
| 6414 |
GGML_ASSERT(ggml_is_vector(b));
|
| 6415 |
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
| 6416 |
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
|
|
|
| 6426 |
|
| 6427 |
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
| 6428 |
|
| 6429 |
+
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
| 6430 |
memcpy(params + 5, &freq_base, sizeof(float));
|
| 6431 |
memcpy(params + 6, &freq_scale, sizeof(float));
|
| 6432 |
memcpy(params + 7, &ext_factor, sizeof(float));
|
| 6433 |
memcpy(params + 8, &attn_factor, sizeof(float));
|
| 6434 |
memcpy(params + 9, &beta_fast, sizeof(float));
|
| 6435 |
memcpy(params + 10, &beta_slow, sizeof(float));
|
|
|
|
|
|
|
| 6436 |
ggml_set_op_params(result, params, sizeof(params));
|
| 6437 |
|
| 6438 |
result->op = GGML_OP_ROPE_BACK;
|
|
|
|
| 14201 |
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
| 14202 |
static void rope_yarn(
|
| 14203 |
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
| 14204 |
+
float * cos_theta, float * sin_theta) {
|
|
|
|
| 14205 |
// Get n-d rotational scaling corrected for extrapolation
|
| 14206 |
float theta_interp = freq_scale * theta_extrap;
|
| 14207 |
float theta = theta_interp;
|
|
|
|
| 14218 |
|
| 14219 |
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
| 14220 |
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
| 14221 |
+
static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
| 14222 |
+
return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
| 14223 |
}
|
| 14224 |
|
| 14225 |
static void ggml_rope_cache_init(
|
| 14226 |
+
float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
| 14227 |
+
float * cache, float sin_sign, float theta_scale) {
|
| 14228 |
+
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
| 14229 |
float theta = theta_base;
|
| 14230 |
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
| 14231 |
+
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
|
| 14232 |
rope_yarn(
|
| 14233 |
+
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
| 14234 |
);
|
| 14235 |
cache[i0 + 1] *= sin_sign;
|
| 14236 |
|
|
|
|
| 14239 |
}
|
| 14240 |
|
| 14241 |
GGML_CALL void ggml_rope_yarn_corr_dims(
|
| 14242 |
+
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
| 14243 |
) {
|
| 14244 |
// start and end correction dims
|
| 14245 |
+
float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
|
| 14246 |
+
float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
|
| 14247 |
dims[0] = MAX(0, start);
|
| 14248 |
dims[1] = MIN(n_dims - 1, end);
|
| 14249 |
}
|
|
|
|
| 14263 |
|
| 14264 |
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
| 14265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14266 |
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 14267 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 14268 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 14269 |
+
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
| 14270 |
+
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
| 14271 |
|
| 14272 |
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
| 14273 |
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
|
|
| 14275 |
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
| 14276 |
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
| 14277 |
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
|
|
|
|
|
| 14278 |
|
| 14279 |
GGML_TENSOR_UNARY_OP_LOCALS
|
| 14280 |
|
|
|
|
| 14304 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 14305 |
|
| 14306 |
float corr_dims[2];
|
| 14307 |
+
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
| 14308 |
|
| 14309 |
const bool is_neox = mode & 2;
|
|
|
|
| 14310 |
|
| 14311 |
const float * freq_factors = NULL;
|
| 14312 |
+
if (src2 != NULL) {
|
| 14313 |
+
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
| 14314 |
+
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
| 14315 |
+
freq_factors = (const float *) src2->data;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14316 |
}
|
| 14317 |
|
| 14318 |
// backward process uses inverse rotation by cos and sin.
|
|
|
|
| 14327 |
const int64_t p = pos[i2];
|
| 14328 |
|
| 14329 |
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
| 14330 |
+
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
|
|
|
|
|
| 14331 |
|
| 14332 |
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
| 14333 |
if (ir++ < ir0) continue;
|
| 14334 |
if (ir > ir1) break;
|
| 14335 |
|
| 14336 |
+
if (!is_neox) {
|
| 14337 |
+
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14338 |
const float cos_theta = cache[i0 + 0];
|
| 14339 |
const float sin_theta = cache[i0 + 1];
|
| 14340 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14341 |
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 14342 |
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 14343 |
|
| 14344 |
const float x0 = src[0];
|
| 14345 |
const float x1 = src[1];
|
| 14346 |
|
| 14347 |
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
| 14348 |
+
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
| 14349 |
}
|
| 14350 |
} else {
|
| 14351 |
+
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
| 14352 |
+
const int64_t ic = i0/2;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14353 |
|
| 14354 |
+
const float cos_theta = cache[i0 + 0];
|
| 14355 |
+
const float sin_theta = cache[i0 + 1];
|
| 14356 |
|
| 14357 |
+
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
| 14358 |
+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
| 14359 |
|
| 14360 |
+
const float x0 = src[0];
|
| 14361 |
+
const float x1 = src[n_dims/2];
|
| 14362 |
|
| 14363 |
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
| 14364 |
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
| 14365 |
+
}
|
| 14366 |
+
}
|
| 14367 |
|
| 14368 |
+
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
| 14369 |
+
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 14370 |
+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 14371 |
|
| 14372 |
+
dst_data[0] = src[0];
|
| 14373 |
+
dst_data[1] = src[1];
|
|
|
|
|
|
|
| 14374 |
}
|
| 14375 |
}
|
| 14376 |
}
|
|
|
|
| 14396 |
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 14397 |
const int n_dims = ((int32_t *) dst->op_params)[1];
|
| 14398 |
const int mode = ((int32_t *) dst->op_params)[2];
|
| 14399 |
+
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
| 14400 |
+
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
| 14401 |
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
| 14402 |
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
| 14403 |
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
|
|
| 14433 |
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
| 14434 |
|
| 14435 |
float corr_dims[2];
|
| 14436 |
+
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
| 14437 |
|
| 14438 |
const bool is_neox = mode & 2;
|
|
|
|
| 14439 |
|
| 14440 |
const float * freq_factors = NULL;
|
| 14441 |
+
if (src2 != NULL) {
|
| 14442 |
+
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
| 14443 |
+
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
| 14444 |
+
freq_factors = (const float *) src2->data;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14445 |
}
|
| 14446 |
|
| 14447 |
// backward process uses inverse rotation by cos and sin.
|
|
|
|
| 14456 |
const int64_t p = pos[i2];
|
| 14457 |
|
| 14458 |
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
| 14459 |
+
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
|
|
|
|
|
| 14460 |
|
| 14461 |
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
| 14462 |
if (ir++ < ir0) continue;
|
| 14463 |
if (ir > ir1) break;
|
| 14464 |
|
| 14465 |
+
if (!is_neox) {
|
| 14466 |
+
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14467 |
const float cos_theta = cache[i0 + 0];
|
| 14468 |
const float sin_theta = cache[i0 + 1];
|
| 14469 |
|
|
|
|
| 14477 |
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
| 14478 |
}
|
| 14479 |
} else {
|
| 14480 |
+
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
| 14481 |
+
const int64_t ic = i0/2;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14482 |
|
| 14483 |
+
const float cos_theta = cache[i0 + 0];
|
| 14484 |
+
const float sin_theta = cache[i0 + 1];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14485 |
|
| 14486 |
+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
| 14487 |
+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
| 14488 |
|
| 14489 |
+
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
| 14490 |
+
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
| 14491 |
|
| 14492 |
+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
| 14493 |
+
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
| 14494 |
+
}
|
| 14495 |
+
}
|
| 14496 |
|
| 14497 |
+
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
| 14498 |
+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
| 14499 |
+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
| 14500 |
|
| 14501 |
+
dst_data[0] = src[0];
|
| 14502 |
+
dst_data[1] = src[1];
|
|
|
|
|
|
|
| 14503 |
}
|
| 14504 |
}
|
| 14505 |
}
|
|
|
|
| 18201 |
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
| 18202 |
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
| 18203 |
const int mode = ((int32_t *) tensor->op_params)[2];
|
| 18204 |
+
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
| 18205 |
+
const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
|
| 18206 |
+
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
| 18207 |
|
| 18208 |
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
| 18209 |
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
|
|
|
|
| 18211 |
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
|
| 18212 |
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
|
| 18213 |
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
|
|
|
|
|
|
|
| 18214 |
|
| 18215 |
src0->grad = ggml_add_or_set(ctx,
|
| 18216 |
src0->grad,
|
|
|
|
| 18220 |
src2,
|
| 18221 |
n_dims,
|
| 18222 |
mode,
|
| 18223 |
+
n_ctx_orig,
|
|
|
|
| 18224 |
freq_base,
|
| 18225 |
freq_scale,
|
| 18226 |
ext_factor,
|
| 18227 |
attn_factor,
|
| 18228 |
beta_fast,
|
| 18229 |
+
beta_slow),
|
|
|
|
|
|
|
| 18230 |
zero_table);
|
| 18231 |
}
|
| 18232 |
} break;
|
|
|
|
| 18236 |
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
| 18237 |
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
| 18238 |
const int mode = ((int32_t *) tensor->op_params)[2];
|
| 18239 |
+
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
| 18240 |
+
const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
|
| 18241 |
+
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
| 18242 |
|
| 18243 |
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
| 18244 |
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
|
|
|
|
| 18246 |
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
|
| 18247 |
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
|
| 18248 |
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
|
|
|
|
|
|
|
| 18249 |
|
| 18250 |
src0->grad = ggml_add_or_set(ctx,
|
| 18251 |
src0->grad,
|
|
|
|
| 18255 |
src2,
|
| 18256 |
n_dims,
|
| 18257 |
mode,
|
| 18258 |
+
n_ctx_orig,
|
|
|
|
| 18259 |
freq_base,
|
| 18260 |
freq_scale,
|
| 18261 |
ext_factor,
|
| 18262 |
attn_factor,
|
| 18263 |
beta_fast,
|
| 18264 |
beta_slow,
|
|
|
|
|
|
|
| 18265 |
false),
|
| 18266 |
zero_table);
|
| 18267 |
}
|
ggml.h
CHANGED
|
@@ -1465,7 +1465,6 @@ extern "C" {
|
|
| 1465 |
// rotary position embedding
|
| 1466 |
// if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
|
| 1467 |
// if (mode & 2) != 0, GPT-NeoX style
|
| 1468 |
-
// if mode & 4 == 1, ChatGLM style
|
| 1469 |
//
|
| 1470 |
// b is an int32 vector with size a->ne[2], it contains the positions
|
| 1471 |
// c is freq factors (e.g. phi3-128k), (optional)
|
|
@@ -1474,8 +1473,7 @@ extern "C" {
|
|
| 1474 |
struct ggml_tensor * a,
|
| 1475 |
struct ggml_tensor * b,
|
| 1476 |
int n_dims,
|
| 1477 |
-
int mode
|
| 1478 |
-
int n_ctx);
|
| 1479 |
|
| 1480 |
// in-place, returns view(a)
|
| 1481 |
GGML_API struct ggml_tensor * ggml_rope_inplace(
|
|
@@ -1483,8 +1481,7 @@ extern "C" {
|
|
| 1483 |
struct ggml_tensor * a,
|
| 1484 |
struct ggml_tensor * b,
|
| 1485 |
int n_dims,
|
| 1486 |
-
int mode
|
| 1487 |
-
int n_ctx);
|
| 1488 |
|
| 1489 |
// custom RoPE
|
| 1490 |
GGML_API struct ggml_tensor * ggml_rope_ext(
|
|
@@ -1494,8 +1491,7 @@ extern "C" {
|
|
| 1494 |
struct ggml_tensor * c,
|
| 1495 |
int n_dims,
|
| 1496 |
int mode,
|
| 1497 |
-
int
|
| 1498 |
-
int n_orig_ctx,
|
| 1499 |
float freq_base,
|
| 1500 |
float freq_scale,
|
| 1501 |
float ext_factor,
|
|
@@ -1511,8 +1507,7 @@ extern "C" {
|
|
| 1511 |
struct ggml_tensor * c,
|
| 1512 |
int n_dims,
|
| 1513 |
int mode,
|
| 1514 |
-
int
|
| 1515 |
-
int n_orig_ctx,
|
| 1516 |
float freq_base,
|
| 1517 |
float freq_scale,
|
| 1518 |
float ext_factor,
|
|
@@ -1526,8 +1521,7 @@ extern "C" {
|
|
| 1526 |
struct ggml_tensor * b,
|
| 1527 |
int n_dims,
|
| 1528 |
int mode,
|
| 1529 |
-
int
|
| 1530 |
-
int n_orig_ctx,
|
| 1531 |
float freq_base,
|
| 1532 |
float freq_scale,
|
| 1533 |
float ext_factor,
|
|
@@ -1542,8 +1536,7 @@ extern "C" {
|
|
| 1542 |
struct ggml_tensor * b,
|
| 1543 |
int n_dims,
|
| 1544 |
int mode,
|
| 1545 |
-
int
|
| 1546 |
-
int n_orig_ctx,
|
| 1547 |
float freq_base,
|
| 1548 |
float freq_scale,
|
| 1549 |
float ext_factor,
|
|
@@ -1552,17 +1545,9 @@ extern "C" {
|
|
| 1552 |
float beta_slow),
|
| 1553 |
"use ggml_rope_ext_inplace instead");
|
| 1554 |
|
| 1555 |
-
struct ggml_tensor * ggml_rope_xpos_inplace(
|
| 1556 |
-
struct ggml_context * ctx,
|
| 1557 |
-
struct ggml_tensor * a,
|
| 1558 |
-
struct ggml_tensor * b,
|
| 1559 |
-
int n_dims,
|
| 1560 |
-
float base,
|
| 1561 |
-
bool down);
|
| 1562 |
-
|
| 1563 |
// compute correction dims for YaRN RoPE scaling
|
| 1564 |
GGML_CALL void ggml_rope_yarn_corr_dims(
|
| 1565 |
-
int n_dims, int
|
| 1566 |
|
| 1567 |
// rotary position embedding backward, i.e., compute dx from dy
|
| 1568 |
// a - dy
|
|
@@ -1573,16 +1558,13 @@ extern "C" {
|
|
| 1573 |
struct ggml_tensor * c,
|
| 1574 |
int n_dims,
|
| 1575 |
int mode,
|
| 1576 |
-
int
|
| 1577 |
-
int n_orig_ctx,
|
| 1578 |
float freq_base,
|
| 1579 |
float freq_scale,
|
| 1580 |
float ext_factor,
|
| 1581 |
float attn_factor,
|
| 1582 |
float beta_fast,
|
| 1583 |
-
float beta_slow
|
| 1584 |
-
float xpos_base,
|
| 1585 |
-
bool xpos_down);
|
| 1586 |
|
| 1587 |
// clamp
|
| 1588 |
// in-place, returns view(a)
|
|
|
|
| 1465 |
// rotary position embedding
|
| 1466 |
// if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
|
| 1467 |
// if (mode & 2) != 0, GPT-NeoX style
|
|
|
|
| 1468 |
//
|
| 1469 |
// b is an int32 vector with size a->ne[2], it contains the positions
|
| 1470 |
// c is freq factors (e.g. phi3-128k), (optional)
|
|
|
|
| 1473 |
struct ggml_tensor * a,
|
| 1474 |
struct ggml_tensor * b,
|
| 1475 |
int n_dims,
|
| 1476 |
+
int mode);
|
|
|
|
| 1477 |
|
| 1478 |
// in-place, returns view(a)
|
| 1479 |
GGML_API struct ggml_tensor * ggml_rope_inplace(
|
|
|
|
| 1481 |
struct ggml_tensor * a,
|
| 1482 |
struct ggml_tensor * b,
|
| 1483 |
int n_dims,
|
| 1484 |
+
int mode);
|
|
|
|
| 1485 |
|
| 1486 |
// custom RoPE
|
| 1487 |
GGML_API struct ggml_tensor * ggml_rope_ext(
|
|
|
|
| 1491 |
struct ggml_tensor * c,
|
| 1492 |
int n_dims,
|
| 1493 |
int mode,
|
| 1494 |
+
int n_ctx_orig,
|
|
|
|
| 1495 |
float freq_base,
|
| 1496 |
float freq_scale,
|
| 1497 |
float ext_factor,
|
|
|
|
| 1507 |
struct ggml_tensor * c,
|
| 1508 |
int n_dims,
|
| 1509 |
int mode,
|
| 1510 |
+
int n_ctx_orig,
|
|
|
|
| 1511 |
float freq_base,
|
| 1512 |
float freq_scale,
|
| 1513 |
float ext_factor,
|
|
|
|
| 1521 |
struct ggml_tensor * b,
|
| 1522 |
int n_dims,
|
| 1523 |
int mode,
|
| 1524 |
+
int n_ctx_orig,
|
|
|
|
| 1525 |
float freq_base,
|
| 1526 |
float freq_scale,
|
| 1527 |
float ext_factor,
|
|
|
|
| 1536 |
struct ggml_tensor * b,
|
| 1537 |
int n_dims,
|
| 1538 |
int mode,
|
| 1539 |
+
int n_ctx_orig,
|
|
|
|
| 1540 |
float freq_base,
|
| 1541 |
float freq_scale,
|
| 1542 |
float ext_factor,
|
|
|
|
| 1545 |
float beta_slow),
|
| 1546 |
"use ggml_rope_ext_inplace instead");
|
| 1547 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1548 |
// compute correction dims for YaRN RoPE scaling
|
| 1549 |
GGML_CALL void ggml_rope_yarn_corr_dims(
|
| 1550 |
+
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
| 1551 |
|
| 1552 |
// rotary position embedding backward, i.e., compute dx from dy
|
| 1553 |
// a - dy
|
|
|
|
| 1558 |
struct ggml_tensor * c,
|
| 1559 |
int n_dims,
|
| 1560 |
int mode,
|
| 1561 |
+
int n_ctx_orig,
|
|
|
|
| 1562 |
float freq_base,
|
| 1563 |
float freq_scale,
|
| 1564 |
float ext_factor,
|
| 1565 |
float attn_factor,
|
| 1566 |
float beta_fast,
|
| 1567 |
+
float beta_slow);
|
|
|
|
|
|
|
| 1568 |
|
| 1569 |
// clamp
|
| 1570 |
// in-place, returns view(a)
|