ggerganov committed on
Commit
8d153a7
·
1 Parent(s): f16510d

metal : handle F16 inf values, fix FA partial offload (llama/7434)

Browse files
Files changed (1) hide show
  1. ggml-metal.metal +12 -15
ggml-metal.metal CHANGED
@@ -2204,11 +2204,7 @@ kernel void kernel_flash_attn_ext_f16(
2204
  // pointer to the mask
2205
  device const half * mp = (device const half *) (mask + iq1*nb31);
2206
 
2207
- // prepare diagonal scale matrix
2208
- simdgroup_float8x8 mscale(scale);
2209
-
2210
- // prepare diagonal slope matrix
2211
- simdgroup_float8x8 mslope(1.0f);
2212
 
2213
  // ALiBi
2214
  if (max_bias > 0.0f) {
@@ -2217,7 +2213,7 @@ kernel void kernel_flash_attn_ext_f16(
2217
  const float base = h < n_head_log2 ? m0 : m1;
2218
  const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2219
 
2220
- mslope = simdgroup_float8x8(pow(base, exph));
2221
  }
2222
 
2223
  // loop over the KV cache
@@ -2242,18 +2238,20 @@ kernel void kernel_flash_attn_ext_f16(
2242
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2243
  }
2244
 
 
 
 
 
 
2245
  if (mask != q) {
2246
  // mqk = mqk*scale + mask*slope
2247
- simdgroup_half8x8 mm;
2248
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2249
- simdgroup_multiply(mm, mslope, mm);
2250
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2251
  } else {
2252
  // mqk = mqk*scale
2253
- simdgroup_multiply(mqk, mscale, mqk);
 
2254
  }
2255
-
2256
- simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2257
  }
2258
  }
2259
 
@@ -2816,8 +2814,7 @@ kernel void kernel_cpy_f32_f16(
2816
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2817
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2818
 
2819
- // TODO: is there a better way to handle -INFINITY?
2820
- dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
2821
  }
2822
  }
2823
 
 
2204
  // pointer to the mask
2205
  device const half * mp = (device const half *) (mask + iq1*nb31);
2206
 
2207
+ float slope = 1.0f;
 
 
 
 
2208
 
2209
  // ALiBi
2210
  if (max_bias > 0.0f) {
 
2213
  const float base = h < n_head_log2 ? m0 : m1;
2214
  const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2215
 
2216
+ slope = pow(base, exph);
2217
  }
2218
 
2219
  // loop over the KV cache
 
2238
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2239
  }
2240
 
2241
+ simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2242
+
2243
+ const short tx = tiisg%4;
2244
+ const short ty = tiisg/4;
2245
+
2246
  if (mask != q) {
2247
  // mqk = mqk*scale + mask*slope
2248
+ ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
2249
+ ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
 
 
2250
  } else {
2251
  // mqk = mqk*scale
2252
+ ss[8*cc + ty*TF + 2*tx + 0] *= scale;
2253
+ ss[8*cc + ty*TF + 2*tx + 1] *= scale;
2254
  }
 
 
2255
  }
2256
  }
2257
 
 
2814
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2815
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2816
 
2817
+ dst_data[i00] = src[0];
 
2818
  }
2819
  }
2820