Spaces:
Running
Running
llama : fix FA when KV cache is not used (i.e. embeddings) (llama/12825)
Browse files
* ggml : FA supports F32 V
* graph : cast KV to F16 when the KV cache is not used
ggml-ci
* server : add test that exercises embeddings with FA enabled
ggml-ci
ggml/src/ggml-cpu/ops.cpp
CHANGED
|
@@ -6769,8 +6769,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 6769 |
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
|
| 6770 |
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
|
| 6771 |
|
| 6772 |
-
GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
|
| 6773 |
-
GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
|
| 6774 |
|
| 6775 |
// loop over n_batch and n_head
|
| 6776 |
for (int ir = ir0; ir < ir1; ++ir) {
|
|
@@ -6866,10 +6866,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 6866 |
vs = expf(s - M);
|
| 6867 |
}
|
| 6868 |
|
| 6869 |
-
v_to_float(v_data, V32, DV);
|
| 6870 |
-
|
| 6871 |
// V += v*expf(s - M)
|
| 6872 |
-
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6873 |
}
|
| 6874 |
|
| 6875 |
S = S*ms + vs; // scale and increment sum with partial sum
|
|
|
|
| 6769 |
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
|
| 6770 |
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
|
| 6771 |
|
| 6772 |
+
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
|
| 6773 |
+
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
|
| 6774 |
|
| 6775 |
// loop over n_batch and n_head
|
| 6776 |
for (int ir = ir0; ir < ir1; ++ir) {
|
|
|
|
| 6866 |
vs = expf(s - M);
|
| 6867 |
}
|
| 6868 |
|
|
|
|
|
|
|
| 6869 |
// V += v*expf(s - M)
|
| 6870 |
+
if (v_to_float) {
|
| 6871 |
+
v_to_float(v_data, V32, DV);
|
| 6872 |
+
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
|
| 6873 |
+
} else {
|
| 6874 |
+
// V is F32
|
| 6875 |
+
ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
|
| 6876 |
+
}
|
| 6877 |
}
|
| 6878 |
|
| 6879 |
S = S*ms + vs; // scale and increment sum with partial sum
|
ggml/src/ggml-metal/ggml-metal.m
CHANGED
|
@@ -1346,6 +1346,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1346 |
case GGML_OP_ARANGE:
|
| 1347 |
return true;
|
| 1348 |
case GGML_OP_FLASH_ATTN_EXT:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1349 |
if (op->src[1]->type != op->src[2]->type) {
|
| 1350 |
return false;
|
| 1351 |
}
|
|
|
|
| 1346 |
case GGML_OP_ARANGE:
|
| 1347 |
return true;
|
| 1348 |
case GGML_OP_FLASH_ATTN_EXT:
|
| 1349 |
+
if (op->src[0]->ne[0] == 32) {
|
| 1350 |
+
// head size == 32 (e.g. bert-bge-small)
|
| 1351 |
+
// TODO: not sure if it is worth adding kernels for this size
|
| 1352 |
+
return false;
|
| 1353 |
+
}
|
| 1354 |
if (op->src[1]->type != op->src[2]->type) {
|
| 1355 |
return false;
|
| 1356 |
}
|