ggerganov committed on
Commit
90cc3cd
·
1 Parent(s): b6e7294

metal : separate scale and mask from QKT in FA kernel (llama/9189)

Browse files

* metal : separate scale and mask from QKT in FA kernel

* metal : ne01 check no longer necessary

* metal : keep data in local memory

Files changed (1) hide show
  1. ggml/src/ggml-metal.metal +13 -22
ggml/src/ggml-metal.metal CHANGED
@@ -2341,24 +2341,6 @@ kernel void kernel_flash_attn_ext_f16(
2341
  }
2342
 
2343
  simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2344
-
2345
- const short tx = tiisg%4;
2346
- const short ty = tiisg/4;
2347
-
2348
- // mqk = mqk*scale
2349
- ss[8*cc + ty*TF + 2*tx + 0] *= scale;
2350
- ss[8*cc + ty*TF + 2*tx + 1] *= scale;
2351
-
2352
- if (logit_softcap != 0.0f) {
2353
- ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
2354
- ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
2355
- }
2356
-
2357
- if (mask != q) {
2358
- // mqk = mqk + mask*slope
2359
- ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
2360
- ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
2361
- }
2362
  }
2363
  }
2364
 
@@ -2370,10 +2352,19 @@ kernel void kernel_flash_attn_ext_f16(
2370
  float ms[Q];
2371
 
2372
  for (short j = 0; j < Q; ++j) {
2373
- const short p = tiisg;
2374
-
2375
  const float m = M[j];
2376
- const float s = ss[j*TF + p];
 
 
 
 
 
 
 
 
 
 
 
2377
 
2378
  smax = simd_max(max(smax, s));
2379
  M[j] = simd_max(max(M[j], s));
@@ -2384,7 +2375,7 @@ kernel void kernel_flash_attn_ext_f16(
2384
  S[j] = S[j]*ms[j] + simd_sum(vs);
2385
 
2386
  // the P matrix from the paper (Q rows, C columns)
2387
- ss[j*TF + p] = vs;
2388
  }
2389
 
2390
  // create a QxQ diagonal matrix for rescaling the output
 
2341
  }
2342
 
2343
  simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2344
  }
2345
  }
2346
 
 
2352
  float ms[Q];
2353
 
2354
  for (short j = 0; j < Q; ++j) {
 
 
2355
  const float m = M[j];
2356
+
2357
+ // scale and apply the logitcap / mask
2358
+ float s = ss[j*TF + tiisg]*scale;
2359
+
2360
+ if (logit_softcap != 0.0f) {
2361
+ s = logit_softcap*precise::tanh(s);
2362
+ }
2363
+
2364
+ if (mask != q) {
2365
+ // mqk = mqk + mask*slope
2366
+ s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
2367
+ }
2368
 
2369
  smax = simd_max(max(smax, s));
2370
  M[j] = simd_max(max(M[j], s));
 
2375
  S[j] = S[j]*ms[j] + simd_sum(vs);
2376
 
2377
  // the P matrix from the paper (Q rows, C columns)
2378
+ ss[j*TF + tiisg] = vs;
2379
  }
2380
 
2381
  // create a QxQ diagonal matrix for rescaling the output