Spaces:
Running
Running
ggml: aarch64: Implement SVE F32 kernels for Mamba Sequential Scan Algorithm (llama/13882)
Browse files
* F32-Mamba-Seq_Scan-SVE
* Fix formatting
* ggml : missing space
---------
Co-authored-by: Georgi Gerganov <[email protected]>
- ggml/src/ggml-cpu/ops.cpp +74 -30
- ggml/src/ggml-cpu/vec.h +36 -0
ggml/src/ggml-cpu/ops.cpp
CHANGED
|
@@ -7633,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|
| 7633 |
const int ir1 = MIN(ir0 + dr, nr);
|
| 7634 |
const int ir = ir1 - ir0;
|
| 7635 |
|
| 7636 |
-
|
| 7637 |
-
for (int
|
| 7638 |
-
|
| 7639 |
-
|
| 7640 |
-
|
| 7641 |
-
|
| 7642 |
-
|
| 7643 |
-
|
| 7644 |
-
float *
|
| 7645 |
-
|
| 7646 |
-
|
| 7647 |
-
|
| 7648 |
-
|
| 7649 |
-
|
| 7650 |
-
|
| 7651 |
-
|
| 7652 |
-
|
| 7653 |
-
|
| 7654 |
-
|
| 7655 |
-
|
| 7656 |
-
|
| 7657 |
-
|
| 7658 |
-
|
| 7659 |
-
|
| 7660 |
-
|
| 7661 |
-
|
| 7662 |
-
|
| 7663 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7664 |
}
|
| 7665 |
-
y[i1] = sumf;
|
| 7666 |
}
|
| 7667 |
}
|
| 7668 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7669 |
}
|
| 7670 |
|
| 7671 |
void ggml_compute_forward_ssm_scan(
|
|
|
|
| 7633 |
const int ir1 = MIN(ir0 + dr, nr);
|
| 7634 |
const int ir = ir1 - ir0;
|
| 7635 |
|
| 7636 |
+
#ifdef __ARM_FEATURE_SVE
|
| 7637 |
+
for (int i3 = 0; i3 < n_s; ++i3) {
|
| 7638 |
+
for (int i2 = 0; i2 < n_t; ++i2) {
|
| 7639 |
+
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
|
| 7640 |
+
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
| 7641 |
+
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
|
| 7642 |
+
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
| 7643 |
+
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
|
| 7644 |
+
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
|
| 7645 |
+
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
| 7646 |
+
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
|
| 7647 |
+
|
| 7648 |
+
// use the output as the source for the next token-wise iterations
|
| 7649 |
+
if (i2 > 0) { s0 = s; }
|
| 7650 |
+
|
| 7651 |
+
// d_inner
|
| 7652 |
+
for (int i1 = 0; i1 < ir; ++i1) {
|
| 7653 |
+
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
|
| 7654 |
+
float x_dt = x[i1] * dt_soft_plus;
|
| 7655 |
+
svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
|
| 7656 |
+
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
|
| 7657 |
+
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
|
| 7658 |
+
|
| 7659 |
+
for (int64_t k = 0; k < nc; k += svcntw()) {
|
| 7660 |
+
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
|
| 7661 |
+
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
|
| 7662 |
+
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
|
| 7663 |
+
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
|
| 7664 |
+
|
| 7665 |
+
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
|
| 7666 |
+
t1 = exp_ps_sve(svptrue_b32(), t1);
|
| 7667 |
+
svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
|
| 7668 |
+
|
| 7669 |
+
vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
|
| 7670 |
+
r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
|
| 7671 |
+
|
| 7672 |
+
GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
|
| 7673 |
+
}
|
| 7674 |
+
y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
|
| 7675 |
}
|
|
|
|
| 7676 |
}
|
| 7677 |
}
|
| 7678 |
+
#else
|
| 7679 |
+
for (int i3 = 0; i3 < n_s; ++i3) {
|
| 7680 |
+
for (int i2 = 0; i2 < n_t; ++i2) {
|
| 7681 |
+
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
|
| 7682 |
+
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
| 7683 |
+
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
|
| 7684 |
+
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
| 7685 |
+
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
|
| 7686 |
+
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
|
| 7687 |
+
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
| 7688 |
+
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
|
| 7689 |
+
|
| 7690 |
+
// use the output as the source for the next token-wise iterations
|
| 7691 |
+
if (i2 > 0) { s0 = s; }
|
| 7692 |
+
|
| 7693 |
+
// d_inner
|
| 7694 |
+
for (int i1 = 0; i1 < ir; ++i1) {
|
| 7695 |
+
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
|
| 7696 |
+
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
|
| 7697 |
+
float x_dt = x[i1] * dt_soft_plus;
|
| 7698 |
+
float sumf = 0.0f;
|
| 7699 |
+
// d_state
|
| 7700 |
+
for (int i0 = 0; i0 < nc; ++i0) {
|
| 7701 |
+
int i = i0 + i1*nc;
|
| 7702 |
+
// state = prev_state * dA + dB * x
|
| 7703 |
+
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
| 7704 |
+
// y = rowwise_dotprod(state, C)
|
| 7705 |
+
sumf += state * C[i0];
|
| 7706 |
+
s[i] = state;
|
| 7707 |
+
}
|
| 7708 |
+
y[i1] = sumf;
|
| 7709 |
+
}
|
| 7710 |
+
}
|
| 7711 |
+
}
|
| 7712 |
+
#endif
|
| 7713 |
}
|
| 7714 |
|
| 7715 |
void ggml_compute_forward_ssm_scan(
|
ggml/src/ggml-cpu/vec.h
CHANGED
|
@@ -647,6 +647,42 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
|
|
| 647 |
#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
|
| 648 |
#endif
|
| 649 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
#if defined(__ARM_NEON) && defined(__aarch64__)
|
| 651 |
|
| 652 |
// adapted from arm limited optimized routine
|
|
|
|
| 647 |
#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
|
| 648 |
#endif
|
| 649 |
|
| 650 |
+
/* The function below was borrowed from the OpenVINO GitHub repository:
|
| 651 |
+
https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */
|
| 652 |
+
#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
|
| 653 |
+
// Vectorized approximation of exp(src) on SVE float32 lanes, evaluated under predicate pg.
// Range reduction: with y = src * log2(e), n = floor(y) and b = (y - n) + 1, the result is
// fexpa(b >> 17) * 2^n * P(z), where z is the part of b carried by its low 17 bits and
// P(z) = 1 + ln2 * z + half_ln2_sq * z * z.
inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
|
| 654 |
+
// Constants (each scalar is duplicated into a vector with svdup_n_*)
// NOTE(review): ln2 and half_ln2_sq below are not the exact values of ln(2) (0.6931471806)
// and ln(2)^2 / 2 (0.2402265); presumably fitted polynomial coefficients carried over from
// the upstream kernel - confirm before "correcting" them.
|
| 655 |
+
const svfloat32_t log2_e = svdup_n_f32(1.4426950409f);
|
| 656 |
+
const svfloat32_t ln2 = svdup_n_f32(0.6931473921f);
|
| 657 |
+
const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f);
|
| 658 |
+
const svuint32_t not_mask17 = svdup_n_u32(~((1u << 17) - 1)); // keeps bits 17..31, clears bits 0..16
|
| 659 |
+
const svfloat32_t one = svdup_n_f32(1.0f);
|
| 660 |
+
const svfloat32_t inactive1 = svdup_n_f32(0.0f); // first operand of svrintm_f32_m below; presumably the value for lanes where pg is false - confirm
|
| 661 |
+
const svint32_t inactive2 = svdup_n_s32(0); // first operand of svcvt_s32_f32_m below; presumably the value for lanes where pg is false - confirm
|
| 662 |
+
|
| 663 |
+
// Algorithm starts here
|
| 664 |
+
svfloat32_t t0 = svmul_f32_m(pg, src, log2_e); // y = x * log2(e)
|
| 665 |
+
svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0); // floor(y), still as float
|
| 666 |
+
svint32_t t2 = svcvt_s32_f32_m(inactive2, pg, t1); // n = floor(y) as int32
|
| 667 |
+
|
| 668 |
+
t1 = svsub_f32_m(pg, t0, t1); // a = y - floor(y)
|
| 669 |
+
t1 = svadd_f32_m(pg, t1, one); // b = a + 1
|
| 670 |
+
|
| 671 |
+
svuint32_t t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17); // v = b >> 17 (u32): drops the low 17 bits of b's bit pattern
|
| 672 |
+
svfloat32_t t4 = svexpa_f32(t3); // c = fexpa(v); NOTE(review): this call takes no pg, unlike every other step here
|
| 673 |
+
t4 = svscale_f32_m(pg, t4, t2); // fexpa(v) * 2^(n)
|
| 674 |
+
|
| 675 |
+
// t5 = b with its low 17 bits cleared; original note: and_(t2.d, t1.d, not_mask17.d)
|
| 676 |
+
svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17));
|
| 677 |
+
t5 = svsub_f32_m(pg, t1, t5); // z = b - (b with low 17 bits cleared)
|
| 678 |
+
t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq); // ln2 + half_ln2_sq * z
|
| 679 |
+
t0 = svmla_f32_m(pg, one, t5, t0); // 1 + (ln2 * z) + (half_ln2_sq * z * z)
|
| 680 |
+
t0 = svmul_f32_m(pg, t0, t4); // Final result: polynomial * fexpa(v) * 2^(n)
|
| 681 |
+
|
| 682 |
+
return t0;
|
| 683 |
+
}
|
| 684 |
+
#endif
|
| 685 |
+
|
| 686 |
#if defined(__ARM_NEON) && defined(__aarch64__)
|
| 687 |
|
| 688 |
// adapted from arm limited optimized routine
|