Vineel Abhinav, ggerganov committed
Commit bfc960a · 1 Parent(s): 7941e9b

ggml: aarch64: Implement SVE F32 kernels for Mamba Sequential Scan Algorithm (llama/13882)


* F32-Mamba-Seq_Scan-SVE

* Fix formatting

* ggml : missing space

---------

Co-authored-by: Georgi Gerganov <[email protected]>

Files changed (2)
  1. ggml/src/ggml-cpu/ops.cpp +74 -30
  2. ggml/src/ggml-cpu/vec.h +36 -0
ggml/src/ggml-cpu/ops.cpp CHANGED
@@ -7633,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                  float * y  = (      float *) ((      char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                  float * s  = (      float *) ((      char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
-
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
-                }
-                y[i1] = sumf;
-            }
-        }
-    }
+#ifdef __ARM_FEATURE_SVE
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+                  float * y  = (      float *) ((      char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                  float * s  = (      float *) ((      char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+                svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+                for (int64_t k = 0; k < nc; k += svcntw()) {
+                    svfloat32_t vA  = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
+                    svfloat32_t vB  = GGML_F32_VEC_LOAD(&B[k]);
+                    svfloat32_t vC  = GGML_F32_VEC_LOAD(&C[k]);
+                    svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
+
+                    svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                    t1 = exp_ps_sve(svptrue_b32(), t1);
+                    svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                    vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
+                    r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                    GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                }
+                y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
+            }
+        }
+    }
+#else
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+                  float * y  = (      float *) ((      char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                  float * s  = (      float *) ((      char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
+#endif
 }
 
 void ggml_compute_forward_ssm_scan(
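For reference, both branches compute the same selective-scan step spelled out in the scalar comments (state = prev_state * dA + dB * x, followed by a row-wise dot product with C). Writing $\Delta_{i_1} = \mathrm{softplus}(dt_{i_1})$ (passed through unchanged when $dt_{i_1} > 20$), each inner channel $i_1$ updates its $n_c$ states and emits one output:

$$ s_{i_1 i_0} \;=\; s^{0}_{i_1 i_0}\, e^{\Delta_{i_1} A_{i_1 i_0}} \;+\; B_{i_0}\, x_{i_1}\, \Delta_{i_1}, \qquad y_{i_1} \;=\; \sum_{i_0=0}^{n_c-1} s_{i_1 i_0}\, C_{i_0} $$

The SVE branch evaluates exactly this, svcntw() states per iteration, accumulating the dot product in r1_vector and folding it to a scalar once per channel with GGML_F32xt_REDUCE_ONE. Note that it loads with an all-true predicate and strides by the full vector width, so it assumes nc (d_state) is a multiple of the SVE vector length.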
ggml/src/ggml-cpu/vec.h CHANGED
@@ -647,6 +647,42 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
     #error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
 #endif
 
+/* The function below was borrowed from the GitHub repository:
+   https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */
+#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
+    // Constants
+    const svfloat32_t log2_e = svdup_n_f32(1.4426950409f);
+    const svfloat32_t ln2 = svdup_n_f32(0.6931473921f);
+    const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f);
+    const svuint32_t not_mask17 = svdup_n_u32(~((1u << 17) - 1));
+    const svfloat32_t one = svdup_n_f32(1.0f);
+    const svfloat32_t inactive1 = svdup_n_f32(0.0f);
+    const svint32_t inactive2 = svdup_n_s32(0);
+
+    // Algorithm starts here
+    svfloat32_t t0 = svmul_f32_m(pg, src, log2_e);       // y = x * log2(e)
+    svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0);   // round down to int (as float)
+    svint32_t   t2 = svcvt_s32_f32_m(inactive2, pg, t1); // n
+
+    t1 = svsub_f32_m(pg, t0, t1);  // a = y - floor(y)
+    t1 = svadd_f32_m(pg, t1, one); // b = a + 1
+
+    svuint32_t  t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17); // v = b >> 17 (u32)
+    svfloat32_t t4 = svexpa_f32(t3); // c = fexpa(v)
+    t4 = svscale_f32_m(pg, t4, t2);  // fexpa(v) * 2^(n)
+
+    // and_(t2.d, t1.d, not_mask17.d)
+    svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17));
+    t5 = svsub_f32_m(pg, t1, t5);               // z
+    t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq); // ln2 + half_ln2_sq * z
+    t0 = svmla_f32_m(pg, one, t5, t0);          // 1 + (ln2 * z) + (half_ln2_sq * z * z)
+    t0 = svmul_f32_m(pg, t0, t4);               // Final result
+
+    return t0;
+}
+#endif
+
 #if defined(__ARM_NEON) && defined(__aarch64__)
 
 // adapted from arm limited optimized routine
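Usage note (illustration only, not part of this commit): because exp_ps_sve takes an explicit predicate, a caller that cannot guarantee a vector-width multiple can mask the tail with svwhilelt instead of the all-true predicate the scan kernel above uses. A minimal sketch, assuming exp_ps_sve is in scope; the helper name vec_exp_f32_sve is hypothetical:

#include <arm_sve.h>  // SVE ACLE intrinsics; compile with SVE enabled

#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
// y[i] = exp(x[i]) for 0 <= i < n, for any n (tail handled by the predicate)
static void vec_exp_f32_sve(int64_t n, float * y, const float * x) {
    for (int64_t i = 0; i < n; i += (int64_t) svcntw()) {
        const svbool_t    pg = svwhilelt_b32_s64(i, n);  // lanes i .. min(i + VL, n) - 1 active
        const svfloat32_t vx = svld1_f32(pg, x + i);     // predicated load; inactive lanes read as zero
        svst1_f32(pg, y + i, exp_ps_sve(pg, vx));        // predicated store of the approximation
    }
}
#endif

Since the routine is a FEXPA-based polynomial approximation rather than a libm call, it trades some accuracy against expf for throughput.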