Spaces:
Sleeping
Sleeping
cuda : fix rope + add tests (llama/7452)
Browse files
* cuda : fix rope pos data
ggml-ci
* ggml : drop mode & 1 == 1 support for ggml_rope
ggml-ci
* ggml : support freq_factors for f16 rope (CPU)
ggml-ci
* tests : add rope tests using frequency factors
ggml-ci
- ggml-cuda/rope.cu +2 -2
- ggml.c +18 -2
- ggml.h +1 -1
ggml-cuda/rope.cu
CHANGED
|
@@ -283,9 +283,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
| 283 |
const bool is_neox = mode & 2;
|
| 284 |
const bool is_glm = mode & 4;
|
| 285 |
|
| 286 |
-
|
| 287 |
-
pos = (const int32_t *) src1_d;
|
| 288 |
|
|
|
|
| 289 |
if (src2 != nullptr) {
|
| 290 |
freq_factors = (const float *) src2->data;
|
| 291 |
}
|
|
|
|
| 283 |
const bool is_neox = mode & 2;
|
| 284 |
const bool is_glm = mode & 4;
|
| 285 |
|
| 286 |
+
pos = (const int32_t *) src1_d;
|
|
|
|
| 287 |
|
| 288 |
+
if (is_neox) {
|
| 289 |
if (src2 != nullptr) {
|
| 290 |
freq_factors = (const float *) src2->data;
|
| 291 |
}
|
ggml.c
CHANGED
|
@@ -6245,6 +6245,8 @@ static struct ggml_tensor * ggml_rope_impl(
|
|
| 6245 |
float xpos_base,
|
| 6246 |
bool xpos_down,
|
| 6247 |
bool inplace) {
|
|
|
|
|
|
|
| 6248 |
GGML_ASSERT(ggml_is_vector(b));
|
| 6249 |
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
| 6250 |
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
|
@@ -14413,7 +14415,7 @@ static void ggml_compute_forward_rope_f32(
|
|
| 14413 |
freq_factors = (const float *) src2->data;
|
| 14414 |
}
|
| 14415 |
} else {
|
| 14416 |
-
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
| 14417 |
}
|
| 14418 |
|
| 14419 |
// backward process uses inverse rotation by cos and sin.
|
|
@@ -14529,6 +14531,7 @@ static void ggml_compute_forward_rope_f32(
|
|
| 14529 |
}
|
| 14530 |
}
|
| 14531 |
|
|
|
|
| 14532 |
static void ggml_compute_forward_rope_f16(
|
| 14533 |
const struct ggml_compute_params * params,
|
| 14534 |
struct ggml_tensor * dst,
|
|
@@ -14536,6 +14539,7 @@ static void ggml_compute_forward_rope_f16(
|
|
| 14536 |
|
| 14537 |
const struct ggml_tensor * src0 = dst->src[0];
|
| 14538 |
const struct ggml_tensor * src1 = dst->src[1];
|
|
|
|
| 14539 |
|
| 14540 |
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
| 14541 |
return;
|
|
@@ -14588,6 +14592,17 @@ static void ggml_compute_forward_rope_f16(
|
|
| 14588 |
const bool is_neox = mode & 2;
|
| 14589 |
const bool is_glm = mode & 4;
|
| 14590 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14591 |
// backward process uses inverse rotation by cos and sin.
|
| 14592 |
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
| 14593 |
// this essentially just switches the sign of sin.
|
|
@@ -14660,10 +14675,11 @@ static void ggml_compute_forward_rope_f16(
|
|
| 14660 |
|
| 14661 |
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
| 14662 |
float cur_rot = inv_ndims * ic - ib;
|
|
|
|
| 14663 |
|
| 14664 |
float cos_theta, sin_theta;
|
| 14665 |
rope_yarn(
|
| 14666 |
-
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
| 14667 |
&cos_theta, &sin_theta
|
| 14668 |
);
|
| 14669 |
sin_theta *= sin_sign;
|
|
|
|
| 6245 |
float xpos_base,
|
| 6246 |
bool xpos_down,
|
| 6247 |
bool inplace) {
|
| 6248 |
+
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
|
| 6249 |
+
|
| 6250 |
GGML_ASSERT(ggml_is_vector(b));
|
| 6251 |
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
| 6252 |
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
|
|
|
| 14415 |
freq_factors = (const float *) src2->data;
|
| 14416 |
}
|
| 14417 |
} else {
|
| 14418 |
+
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
| 14419 |
}
|
| 14420 |
|
| 14421 |
// backward process uses inverse rotation by cos and sin.
|
|
|
|
| 14531 |
}
|
| 14532 |
}
|
| 14533 |
|
| 14534 |
+
// TODO: deduplicate f16/f32 code
|
| 14535 |
static void ggml_compute_forward_rope_f16(
|
| 14536 |
const struct ggml_compute_params * params,
|
| 14537 |
struct ggml_tensor * dst,
|
|
|
|
| 14539 |
|
| 14540 |
const struct ggml_tensor * src0 = dst->src[0];
|
| 14541 |
const struct ggml_tensor * src1 = dst->src[1];
|
| 14542 |
+
const struct ggml_tensor * src2 = dst->src[2];
|
| 14543 |
|
| 14544 |
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
| 14545 |
return;
|
|
|
|
| 14592 |
const bool is_neox = mode & 2;
|
| 14593 |
const bool is_glm = mode & 4;
|
| 14594 |
|
| 14595 |
+
const float * freq_factors = NULL;
|
| 14596 |
+
if (is_neox) {
|
| 14597 |
+
if (src2 != NULL) {
|
| 14598 |
+
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
| 14599 |
+
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
| 14600 |
+
freq_factors = (const float *) src2->data;
|
| 14601 |
+
}
|
| 14602 |
+
} else {
|
| 14603 |
+
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
| 14604 |
+
}
|
| 14605 |
+
|
| 14606 |
// backward process uses inverse rotation by cos and sin.
|
| 14607 |
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
| 14608 |
// this essentially just switches the sign of sin.
|
|
|
|
| 14675 |
|
| 14676 |
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
| 14677 |
float cur_rot = inv_ndims * ic - ib;
|
| 14678 |
+
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
|
| 14679 |
|
| 14680 |
float cos_theta, sin_theta;
|
| 14681 |
rope_yarn(
|
| 14682 |
+
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
| 14683 |
&cos_theta, &sin_theta
|
| 14684 |
);
|
| 14685 |
sin_theta *= sin_sign;
|
ggml.h
CHANGED
|
@@ -1460,7 +1460,7 @@ extern "C" {
|
|
| 1460 |
struct ggml_tensor * b);
|
| 1461 |
|
| 1462 |
// rotary position embedding
|
| 1463 |
-
// if mode & 1 == 1, skip n_past elements (
|
| 1464 |
// if mode & 2 == 1, GPT-NeoX style
|
| 1465 |
// if mode & 4 == 1, ChatGLM style
|
| 1466 |
//
|
|
|
|
| 1460 |
struct ggml_tensor * b);
|
| 1461 |
|
| 1462 |
// rotary position embedding
|
| 1463 |
+
// if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
|
| 1464 |
// if mode & 2 == 1, GPT-NeoX style
|
| 1465 |
// if mode & 4 == 1, ChatGLM style
|
| 1466 |
//
|