Spaces:
Running
Running
ggml : improve vec_dot_f16 unrolling in flash_attn_f16
Browse files
- examples/command/command.cpp +1 -1
- ggml.c +53 -44
examples/command/command.cpp
CHANGED
|
@@ -781,7 +781,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
|
| 781 |
std::string prompt;
|
| 782 |
std::string command;
|
| 783 |
|
| 784 |
-
for (int i = 0; i < words.size(); ++i) {
|
| 785 |
if (i < k_prompt_length) {
|
| 786 |
prompt += words[i] + " ";
|
| 787 |
} else {
|
|
|
|
| 781 |
std::string prompt;
|
| 782 |
std::string command;
|
| 783 |
|
| 784 |
+
for (int i = 0; i < (int) words.size(); ++i) {
|
| 785 |
if (i < k_prompt_length) {
|
| 786 |
prompt += words[i] + " ";
|
| 787 |
} else {
|
ggml.c
CHANGED
|
@@ -84,7 +84,7 @@ typedef void* thread_ret_t;
|
|
| 84 |
#define GGML_GELU_FP16
|
| 85 |
|
| 86 |
#define GGML_SOFT_MAX_UNROLL 4
|
| 87 |
-
#define GGML_VEC_DOT_UNROLL
|
| 88 |
|
| 89 |
#ifdef GGML_USE_ACCELERATE
|
| 90 |
// uncomment to use vDSP for soft max computation
|
|
@@ -923,9 +923,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
|
| 923 |
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
| 924 |
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
|
| 925 |
|
| 926 |
-
|
| 927 |
|
| 928 |
-
for (int i =
|
| 929 |
x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
|
| 930 |
}
|
| 931 |
|
|
@@ -6158,40 +6158,37 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
| 6158 |
S[i] = -INFINITY;
|
| 6159 |
}
|
| 6160 |
|
| 6161 |
-
|
| 6162 |
-
|
| 6163 |
-
|
| 6164 |
-
|
| 6165 |
-
|
| 6166 |
-
|
| 6167 |
-
const int ik1 = ic;
|
| 6168 |
-
|
| 6169 |
-
// S indices
|
| 6170 |
-
const int i1 = ik1;
|
| 6171 |
-
|
| 6172 |
-
ggml_vec_dot_f16(neq0,
|
| 6173 |
-
S + i1,
|
| 6174 |
-
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
| 6175 |
-
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
| 6176 |
-
}
|
| 6177 |
-
#else
|
| 6178 |
-
GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0);
|
| 6179 |
-
|
| 6180 |
-
for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
| 6181 |
-
// k indices
|
| 6182 |
-
const int ik3 = iq3;
|
| 6183 |
-
const int ik2 = iq2;
|
| 6184 |
-
const int ik1 = ic;
|
| 6185 |
|
| 6186 |
-
|
| 6187 |
-
|
| 6188 |
|
| 6189 |
-
|
| 6190 |
-
|
| 6191 |
-
|
| 6192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6193 |
}
|
| 6194 |
-
#endif
|
| 6195 |
|
| 6196 |
// scale
|
| 6197 |
ggml_vec_scale_f32(nek1, S, scale);
|
|
@@ -6261,18 +6258,30 @@ static void ggml_compute_forward_flash_attn_f16(
|
|
| 6261 |
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
| 6262 |
}
|
| 6263 |
|
| 6264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6265 |
|
| 6266 |
-
|
| 6267 |
-
|
| 6268 |
-
|
| 6269 |
-
|
| 6270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6271 |
|
| 6272 |
-
|
| 6273 |
-
|
| 6274 |
-
|
| 6275 |
-
|
|
|
|
| 6276 |
}
|
| 6277 |
}
|
| 6278 |
}
|
|
|
|
| 84 |
#define GGML_GELU_FP16
|
| 85 |
|
| 86 |
#define GGML_SOFT_MAX_UNROLL 4
|
| 87 |
+
#define GGML_VEC_DOT_UNROLL 2
|
| 88 |
|
| 89 |
#ifdef GGML_USE_ACCELERATE
|
| 90 |
// uncomment to use vDSP for soft max computation
|
|
|
|
| 923 |
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
| 924 |
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
|
| 925 |
|
| 926 |
+
ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
|
| 927 |
|
| 928 |
+
for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
|
| 929 |
x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
|
| 930 |
}
|
| 931 |
|
|
|
|
| 6158 |
S[i] = -INFINITY;
|
| 6159 |
}
|
| 6160 |
|
| 6161 |
+
if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
|
| 6162 |
+
for (int ic = 0; ic < nek1; ++ic) {
|
| 6163 |
+
// k indices
|
| 6164 |
+
const int ik3 = iq3;
|
| 6165 |
+
const int ik2 = iq2;
|
| 6166 |
+
const int ik1 = ic;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6167 |
|
| 6168 |
+
// S indices
|
| 6169 |
+
const int i1 = ik1;
|
| 6170 |
|
| 6171 |
+
ggml_vec_dot_f16(neq0,
|
| 6172 |
+
S + i1,
|
| 6173 |
+
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
| 6174 |
+
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
| 6175 |
+
}
|
| 6176 |
+
} else {
|
| 6177 |
+
for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
| 6178 |
+
// k indices
|
| 6179 |
+
const int ik3 = iq3;
|
| 6180 |
+
const int ik2 = iq2;
|
| 6181 |
+
const int ik1 = ic;
|
| 6182 |
+
|
| 6183 |
+
// S indices
|
| 6184 |
+
const int i1 = ik1;
|
| 6185 |
+
|
| 6186 |
+
ggml_vec_dot_f16_unroll(neq0, nbk1,
|
| 6187 |
+
S + i1,
|
| 6188 |
+
((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
| 6189 |
+
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
| 6190 |
+
}
|
| 6191 |
}
|
|
|
|
| 6192 |
|
| 6193 |
// scale
|
| 6194 |
ggml_vec_scale_f32(nek1, S, scale);
|
|
|
|
| 6258 |
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
| 6259 |
}
|
| 6260 |
|
| 6261 |
+
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
|
| 6262 |
+
for (int ic = 0; ic < nev1; ++ic) {
|
| 6263 |
+
// dst indices
|
| 6264 |
+
const int i1 = iq1;
|
| 6265 |
+
const int i2 = iq2;
|
| 6266 |
+
const int i3 = iq3;
|
| 6267 |
|
| 6268 |
+
ggml_vec_dot_f16(nek1,
|
| 6269 |
+
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
| 6270 |
+
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
| 6271 |
+
S16);
|
| 6272 |
+
}
|
| 6273 |
+
} else {
|
| 6274 |
+
for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
|
| 6275 |
+
// dst indices
|
| 6276 |
+
const int i1 = iq1;
|
| 6277 |
+
const int i2 = iq2;
|
| 6278 |
+
const int i3 = iq3;
|
| 6279 |
|
| 6280 |
+
ggml_vec_dot_f16_unroll(nek1, nbv1,
|
| 6281 |
+
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
| 6282 |
+
((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
| 6283 |
+
S16);
|
| 6284 |
+
}
|
| 6285 |
}
|
| 6286 |
}
|
| 6287 |
}
|