Spaces:
Running
Running
metal : handle F16 inf values, fix FA partial offload (llama/7434)
Browse files- ggml-metal.metal +12 -15
ggml-metal.metal
CHANGED
|
@@ -2204,11 +2204,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2204 |
// pointer to the mask
|
| 2205 |
device const half * mp = (device const half *) (mask + iq1*nb31);
|
| 2206 |
|
| 2207 |
-
|
| 2208 |
-
simdgroup_float8x8 mscale(scale);
|
| 2209 |
-
|
| 2210 |
-
// prepare diagonal slope matrix
|
| 2211 |
-
simdgroup_float8x8 mslope(1.0f);
|
| 2212 |
|
| 2213 |
// ALiBi
|
| 2214 |
if (max_bias > 0.0f) {
|
|
@@ -2217,7 +2213,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2217 |
const float base = h < n_head_log2 ? m0 : m1;
|
| 2218 |
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 2219 |
|
| 2220 |
-
|
| 2221 |
}
|
| 2222 |
|
| 2223 |
// loop over the KV cache
|
|
@@ -2242,18 +2238,20 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2242 |
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
| 2243 |
}
|
| 2244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2245 |
if (mask != q) {
|
| 2246 |
// mqk = mqk*scale + mask*slope
|
| 2247 |
-
|
| 2248 |
-
|
| 2249 |
-
simdgroup_multiply(mm, mslope, mm);
|
| 2250 |
-
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
| 2251 |
} else {
|
| 2252 |
// mqk = mqk*scale
|
| 2253 |
-
|
|
|
|
| 2254 |
}
|
| 2255 |
-
|
| 2256 |
-
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
| 2257 |
}
|
| 2258 |
}
|
| 2259 |
|
|
@@ -2816,8 +2814,7 @@ kernel void kernel_cpy_f32_f16(
|
|
| 2816 |
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
| 2817 |
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
| 2818 |
|
| 2819 |
-
|
| 2820 |
-
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
|
| 2821 |
}
|
| 2822 |
}
|
| 2823 |
|
|
|
|
| 2204 |
// pointer to the mask
|
| 2205 |
device const half * mp = (device const half *) (mask + iq1*nb31);
|
| 2206 |
|
| 2207 |
+
float slope = 1.0f;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2208 |
|
| 2209 |
// ALiBi
|
| 2210 |
if (max_bias > 0.0f) {
|
|
|
|
| 2213 |
const float base = h < n_head_log2 ? m0 : m1;
|
| 2214 |
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 2215 |
|
| 2216 |
+
slope = pow(base, exph);
|
| 2217 |
}
|
| 2218 |
|
| 2219 |
// loop over the KV cache
|
|
|
|
| 2238 |
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
| 2239 |
}
|
| 2240 |
|
| 2241 |
+
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
| 2242 |
+
|
| 2243 |
+
const short tx = tiisg%4;
|
| 2244 |
+
const short ty = tiisg/4;
|
| 2245 |
+
|
| 2246 |
if (mask != q) {
|
| 2247 |
// mqk = mqk*scale + mask*slope
|
| 2248 |
+
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
| 2249 |
+
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
|
|
|
|
|
|
| 2250 |
} else {
|
| 2251 |
// mqk = mqk*scale
|
| 2252 |
+
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
| 2253 |
+
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
| 2254 |
}
|
|
|
|
|
|
|
| 2255 |
}
|
| 2256 |
}
|
| 2257 |
|
|
|
|
| 2814 |
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
| 2815 |
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
| 2816 |
|
| 2817 |
+
dst_data[i00] = src[0];
|
|
|
|
| 2818 |
}
|
| 2819 |
}
|
| 2820 |
|