Spaces:
Running
Running
metal : use F32 prec for K*Q in vec FA (llama/9595)
Browse files
ggml/src/ggml-metal.metal
CHANGED
|
@@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 2631 |
const short iv3 = iq3 / rv3;
|
| 2632 |
|
| 2633 |
// load the queries from shared memory into local memory
|
| 2634 |
-
|
| 2635 |
|
| 2636 |
for (short ii = 0; ii < D4; ii += NW) {
|
| 2637 |
short i = ii + tiisg;
|
| 2638 |
-
mq[i] = sq4[i];
|
| 2639 |
}
|
| 2640 |
|
| 2641 |
// pointer to the mask
|
|
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 2661 |
for (short ii = 0; ii < D4; ii += NW) {
|
| 2662 |
const short i = ii + tiisg;
|
| 2663 |
|
| 2664 |
-
|
| 2665 |
-
mk[0] = pk4[i + 0*(nb11/8)];
|
| 2666 |
-
mk[1] = pk4[i + 1*(nb11/8)];
|
| 2667 |
-
mk[2] = pk4[i + 2*(nb11/8)];
|
| 2668 |
-
mk[3] = pk4[i + 3*(nb11/8)];
|
| 2669 |
|
| 2670 |
mqk += (float4) (mq[i] * mk);
|
| 2671 |
}
|
|
|
|
| 2631 |
const short iv3 = iq3 / rv3;
|
| 2632 |
|
| 2633 |
// load the queries from shared memory into local memory
|
| 2634 |
+
float4 mq[D4];
|
| 2635 |
|
| 2636 |
for (short ii = 0; ii < D4; ii += NW) {
|
| 2637 |
short i = ii + tiisg;
|
| 2638 |
+
mq[i] = (float4) sq4[i];
|
| 2639 |
}
|
| 2640 |
|
| 2641 |
// pointer to the mask
|
|
|
|
| 2661 |
for (short ii = 0; ii < D4; ii += NW) {
|
| 2662 |
const short i = ii + tiisg;
|
| 2663 |
|
| 2664 |
+
float4x4 mk;
|
| 2665 |
+
mk[0] = (float4) pk4[i + 0*(nb11/8)];
|
| 2666 |
+
mk[1] = (float4) pk4[i + 1*(nb11/8)];
|
| 2667 |
+
mk[2] = (float4) pk4[i + 2*(nb11/8)];
|
| 2668 |
+
mk[3] = (float4) pk4[i + 3*(nb11/8)];
|
| 2669 |
|
| 2670 |
mqk += (float4) (mq[i] * mk);
|
| 2671 |
}
|