ggerganov HF Staff committed on
Commit
e7cb2dc
·
1 Parent(s): 87f1ea3

llama : fix FA when KV cache is not used (i.e. embeddings) (llama/12825)

Browse files

* ggml : FA supports F32 V

* graph : cast KV to F16 when the KV cache is not used

ggml-ci

* server : add test that exercises embeddings with FA enabled

ggml-ci

ggml/src/ggml-cpu/ops.cpp CHANGED
@@ -6769,8 +6769,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
6769
  ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
6770
  ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
6771
 
6772
- GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
6773
- GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
6774
 
6775
  // loop over n_batch and n_head
6776
  for (int ir = ir0; ir < ir1; ++ir) {
@@ -6866,10 +6866,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
6866
  vs = expf(s - M);
6867
  }
6868
 
6869
- v_to_float(v_data, V32, DV);
6870
-
6871
  // V += v*expf(s - M)
6872
- ggml_vec_mad_f32(DV, VKQ32, V32, vs);
 
 
 
 
 
 
6873
  }
6874
 
6875
  S = S*ms + vs; // scale and increment sum with partial sum
 
6769
  ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
6770
  ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
6771
 
6772
+ GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
6773
+ GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
6774
 
6775
  // loop over n_batch and n_head
6776
  for (int ir = ir0; ir < ir1; ++ir) {
 
6866
  vs = expf(s - M);
6867
  }
6868
 
 
 
6869
  // V += v*expf(s - M)
6870
+ if (v_to_float) {
6871
+ v_to_float(v_data, V32, DV);
6872
+ ggml_vec_mad_f32(DV, VKQ32, V32, vs);
6873
+ } else {
6874
+ // V is F32
6875
+ ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
6876
+ }
6877
  }
6878
 
6879
  S = S*ms + vs; // scale and increment sum with partial sum
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -1346,6 +1346,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1346
  case GGML_OP_ARANGE:
1347
  return true;
1348
  case GGML_OP_FLASH_ATTN_EXT:
 
 
 
 
 
1349
  if (op->src[1]->type != op->src[2]->type) {
1350
  return false;
1351
  }
 
1346
  case GGML_OP_ARANGE:
1347
  return true;
1348
  case GGML_OP_FLASH_ATTN_EXT:
1349
+ if (op->src[0]->ne[0] == 32) {
1350
+ // head size == 32 (e.g. bert-bge-small)
1351
+ // TODO: not sure if it is worth adding kernels for this size
1352
+ return false;
1353
+ }
1354
  if (op->src[1]->type != op->src[2]->type) {
1355
  return false;
1356
  }