Spaces:
Running
Running
metal : separate scale and mask from QKT in FA kernel (llama/9189)
Browse files
* metal : separate scale and mask from QKT in FA kernel
* metal : ne01 check no longer necessary
* metal : keep data in local memory
- ggml/src/ggml-metal.metal +13 -22
ggml/src/ggml-metal.metal
CHANGED
|
@@ -2341,24 +2341,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2341 |
}
|
| 2342 |
|
| 2343 |
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
| 2344 |
-
|
| 2345 |
-
const short tx = tiisg%4;
|
| 2346 |
-
const short ty = tiisg/4;
|
| 2347 |
-
|
| 2348 |
-
// mqk = mqk*scale
|
| 2349 |
-
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
| 2350 |
-
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
| 2351 |
-
|
| 2352 |
-
if (logit_softcap != 0.0f) {
|
| 2353 |
-
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
|
| 2354 |
-
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
|
| 2355 |
-
}
|
| 2356 |
-
|
| 2357 |
-
if (mask != q) {
|
| 2358 |
-
// mqk = mqk + mask*slope
|
| 2359 |
-
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
| 2360 |
-
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
| 2361 |
-
}
|
| 2362 |
}
|
| 2363 |
}
|
| 2364 |
|
|
@@ -2370,10 +2352,19 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2370 |
float ms[Q];
|
| 2371 |
|
| 2372 |
for (short j = 0; j < Q; ++j) {
|
| 2373 |
-
const short p = tiisg;
|
| 2374 |
-
|
| 2375 |
const float m = M[j];
|
| 2376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2377 |
|
| 2378 |
smax = simd_max(max(smax, s));
|
| 2379 |
M[j] = simd_max(max(M[j], s));
|
|
@@ -2384,7 +2375,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2384 |
S[j] = S[j]*ms[j] + simd_sum(vs);
|
| 2385 |
|
| 2386 |
// the P matrix from the paper (Q rows, C columns)
|
| 2387 |
-
ss[j*TF + p] = vs;
|
| 2388 |
}
|
| 2389 |
|
| 2390 |
// create a QxQ diagonal matrix for rescaling the output
|
|
|
|
| 2341 |
}
|
| 2342 |
|
| 2343 |
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2344 |
}
|
| 2345 |
}
|
| 2346 |
|
|
|
|
| 2352 |
float ms[Q];
|
| 2353 |
|
| 2354 |
for (short j = 0; j < Q; ++j) {
|
|
|
|
|
|
|
| 2355 |
const float m = M[j];
|
| 2356 |
+
|
| 2357 |
+
// scale and apply the logitcap / mask
|
| 2358 |
+
float s = ss[j*TF + tiisg]*scale;
|
| 2359 |
+
|
| 2360 |
+
if (logit_softcap != 0.0f) {
|
| 2361 |
+
s = logit_softcap*precise::tanh(s);
|
| 2362 |
+
}
|
| 2363 |
+
|
| 2364 |
+
if (mask != q) {
|
| 2365 |
+
// mqk = mqk + mask*slope
|
| 2366 |
+
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
|
| 2367 |
+
}
|
| 2368 |
|
| 2369 |
smax = simd_max(max(smax, s));
|
| 2370 |
M[j] = simd_max(max(M[j], s));
|
|
|
|
| 2375 |
S[j] = S[j]*ms[j] + simd_sum(vs);
|
| 2376 |
|
| 2377 |
// the P matrix from the paper (Q rows, C columns)
|
| 2378 |
+
ss[j*TF + tiisg] = vs;
|
| 2379 |
}
|
| 2380 |
|
| 2381 |
// create a QxQ diagonal matrix for rescaling the output
|