ggerganov committed on
Commit
6e57274
·
unverified ·
1 Parent(s): fdd70c9

ggml : improve vec_dot_f16 unrolling in flash_attn_f16

Browse files
Files changed (2) hide show
  1. examples/command/command.cpp +1 -1
  2. ggml.c +53 -44
examples/command/command.cpp CHANGED
@@ -781,7 +781,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
781
  std::string prompt;
782
  std::string command;
783
 
784
- for (int i = 0; i < words.size(); ++i) {
785
  if (i < k_prompt_length) {
786
  prompt += words[i] + " ";
787
  } else {
 
781
  std::string prompt;
782
  std::string command;
783
 
784
+ for (int i = 0; i < (int) words.size(); ++i) {
785
  if (i < k_prompt_length) {
786
  prompt += words[i] + " ";
787
  } else {
ggml.c CHANGED
@@ -84,7 +84,7 @@ typedef void* thread_ret_t;
84
  #define GGML_GELU_FP16
85
 
86
  #define GGML_SOFT_MAX_UNROLL 4
87
- #define GGML_VEC_DOT_UNROLL 4
88
 
89
  #ifdef GGML_USE_ACCELERATE
90
  // uncomment to use vDSP for soft max computation
@@ -923,9 +923,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
923
  inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
924
  ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
925
 
926
- const ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL] = { xv };
927
 
928
- for (int i = 1; i < GGML_VEC_DOT_UNROLL; ++i) {
929
  x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
930
  }
931
 
@@ -6158,40 +6158,37 @@ static void ggml_compute_forward_flash_attn_f16(
6158
  S[i] = -INFINITY;
6159
  }
6160
 
6161
- // looks like unrolling here does not help
6162
- #if 1
6163
- for (int ic = 0; ic < nek1; ++ic) {
6164
- // k indices
6165
- const int ik3 = iq3;
6166
- const int ik2 = iq2;
6167
- const int ik1 = ic;
6168
-
6169
- // S indices
6170
- const int i1 = ik1;
6171
-
6172
- ggml_vec_dot_f16(neq0,
6173
- S + i1,
6174
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
6175
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
6176
- }
6177
- #else
6178
- GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0);
6179
-
6180
- for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
6181
- // k indices
6182
- const int ik3 = iq3;
6183
- const int ik2 = iq2;
6184
- const int ik1 = ic;
6185
 
6186
- // S indices
6187
- const int i1 = ik1;
6188
 
6189
- ggml_vec_dot_f16_unroll(neq0, nbk1,
6190
- S + i1,
6191
- ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
6192
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6193
  }
6194
- #endif
6195
 
6196
  // scale
6197
  ggml_vec_scale_f32(nek1, S, scale);
@@ -6261,18 +6258,30 @@ static void ggml_compute_forward_flash_attn_f16(
6261
  S16[i] = GGML_FP32_TO_FP16(S[i]);
6262
  }
6263
 
6264
- GGML_ASSERT(nev1 % GGML_VEC_DOT_UNROLL == 0);
 
 
 
 
 
6265
 
6266
- for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
6267
- // dst indices
6268
- const int i1 = iq1;
6269
- const int i2 = iq2;
6270
- const int i3 = iq3;
 
 
 
 
 
 
6271
 
6272
- ggml_vec_dot_f16_unroll(nek1, nbv1,
6273
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
6274
- ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
6275
- S16);
 
6276
  }
6277
  }
6278
  }
 
84
  #define GGML_GELU_FP16
85
 
86
  #define GGML_SOFT_MAX_UNROLL 4
87
+ #define GGML_VEC_DOT_UNROLL 2
88
 
89
  #ifdef GGML_USE_ACCELERATE
90
  // uncomment to use vDSP for soft max computation
 
923
  inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
924
  ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
925
 
926
+ ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
927
 
928
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
929
  x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
930
  }
931
 
 
6158
  S[i] = -INFINITY;
6159
  }
6160
 
6161
+ if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
6162
+ for (int ic = 0; ic < nek1; ++ic) {
6163
+ // k indices
6164
+ const int ik3 = iq3;
6165
+ const int ik2 = iq2;
6166
+ const int ik1 = ic;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6167
 
6168
+ // S indices
6169
+ const int i1 = ik1;
6170
 
6171
+ ggml_vec_dot_f16(neq0,
6172
+ S + i1,
6173
+ (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
6174
+ (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
6175
+ }
6176
+ } else {
6177
+ for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
6178
+ // k indices
6179
+ const int ik3 = iq3;
6180
+ const int ik2 = iq2;
6181
+ const int ik1 = ic;
6182
+
6183
+ // S indices
6184
+ const int i1 = ik1;
6185
+
6186
+ ggml_vec_dot_f16_unroll(neq0, nbk1,
6187
+ S + i1,
6188
+ ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
6189
+ (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
6190
+ }
6191
  }
 
6192
 
6193
  // scale
6194
  ggml_vec_scale_f32(nek1, S, scale);
 
6258
  S16[i] = GGML_FP32_TO_FP16(S[i]);
6259
  }
6260
 
6261
+ if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
6262
+ for (int ic = 0; ic < nev1; ++ic) {
6263
+ // dst indices
6264
+ const int i1 = iq1;
6265
+ const int i2 = iq2;
6266
+ const int i3 = iq3;
6267
 
6268
+ ggml_vec_dot_f16(nek1,
6269
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
6270
+ (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
6271
+ S16);
6272
+ }
6273
+ } else {
6274
+ for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
6275
+ // dst indices
6276
+ const int i1 = iq1;
6277
+ const int i2 = iq2;
6278
+ const int i3 = iq3;
6279
 
6280
+ ggml_vec_dot_f16_unroll(nek1, nbv1,
6281
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
6282
+ ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
6283
+ S16);
6284
+ }
6285
  }
6286
  }
6287
  }