ggerganov committed on
Commit
99c4239
·
1 Parent(s): 5aceb3d

metal : use F32 prec for K*Q in vec FA (llama/9595)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-metal.metal +7 -7
ggml/src/ggml-metal.metal CHANGED
@@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
2631
  const short iv3 = iq3 / rv3;
2632
 
2633
  // load the queries from shared memory into local memory
2634
- half4 mq[D4];
2635
 
2636
  for (short ii = 0; ii < D4; ii += NW) {
2637
  short i = ii + tiisg;
2638
- mq[i] = sq4[i];
2639
  }
2640
 
2641
  // pointer to the mask
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
2661
  for (short ii = 0; ii < D4; ii += NW) {
2662
  const short i = ii + tiisg;
2663
 
2664
- half4x4 mk;
2665
- mk[0] = pk4[i + 0*(nb11/8)];
2666
- mk[1] = pk4[i + 1*(nb11/8)];
2667
- mk[2] = pk4[i + 2*(nb11/8)];
2668
- mk[3] = pk4[i + 3*(nb11/8)];
2669
 
2670
  mqk += (float4) (mq[i] * mk);
2671
  }
 
2631
  const short iv3 = iq3 / rv3;
2632
 
2633
  // load the queries from shared memory into local memory
2634
+ float4 mq[D4];
2635
 
2636
  for (short ii = 0; ii < D4; ii += NW) {
2637
  short i = ii + tiisg;
2638
+ mq[i] = (float4) sq4[i];
2639
  }
2640
 
2641
  // pointer to the mask
 
2661
  for (short ii = 0; ii < D4; ii += NW) {
2662
  const short i = ii + tiisg;
2663
 
2664
+ float4x4 mk;
2665
+ mk[0] = (float4) pk4[i + 0*(nb11/8)];
2666
+ mk[1] = (float4) pk4[i + 1*(nb11/8)];
2667
+ mk[2] = (float4) pk4[i + 2*(nb11/8)];
2668
+ mk[3] = (float4) pk4[i + 3*(nb11/8)];
2669
 
2670
  mqk += (float4) (mq[i] * mk);
2671
  }