Spaces:
Running
Running
metal : add quantized FA support (llama/10149)
Browse files* metal : add quantized FA (vec) support
ggml-ci
* metal : add quantized FA (non-vec) support
* metal : fix support check
ggml-ci
* metal : clean-up
* metal : clean-up (cont)
* metal : fix shared memory calc + reduce smem + comments
* metal : float-correctness
* metal : minor [no ci]
- ggml/src/ggml-metal.m +265 -37
- ggml/src/ggml-metal.metal +303 -155
ggml/src/ggml-metal.m
CHANGED
|
@@ -255,9 +255,49 @@ enum ggml_metal_kernel_type {
|
|
| 255 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
| 256 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
| 257 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
| 262 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
| 263 |
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
|
@@ -710,9 +750,49 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 710 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
|
| 711 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
|
| 712 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
|
| 713 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 714 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
|
| 715 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 716 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 717 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
| 718 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
|
@@ -869,13 +949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 869 |
case GGML_OP_LEAKY_RELU:
|
| 870 |
return true;
|
| 871 |
case GGML_OP_FLASH_ATTN_EXT:
|
| 872 |
-
if (op->src[1]->type !=
|
| 873 |
-
return false;
|
| 874 |
-
}
|
| 875 |
-
if (op->src[2]->type != GGML_TYPE_F16) {
|
| 876 |
-
return false;
|
| 877 |
-
}
|
| 878 |
-
if (op->src[0]->ne[0] == 256) {
|
| 879 |
return false;
|
| 880 |
}
|
| 881 |
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
|
@@ -2822,6 +2896,7 @@ static void ggml_metal_encode_node(
|
|
| 2822 |
GGML_ASSERT(ne11 % 32 == 0);
|
| 2823 |
|
| 2824 |
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
| 2825 |
|
| 2826 |
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
| 2827 |
|
|
@@ -2869,26 +2944,154 @@ static void ggml_metal_encode_node(
|
|
| 2869 |
bool use_vec_kernel = false;
|
| 2870 |
|
| 2871 |
if (ne01 >= 4 || (ne00%128 != 0)) {
|
| 2872 |
-
switch (
|
| 2873 |
-
case
|
| 2874 |
-
|
| 2875 |
-
|
| 2876 |
-
|
| 2877 |
-
|
| 2878 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2879 |
default:
|
| 2880 |
-
|
| 2881 |
-
|
| 2882 |
-
|
| 2883 |
-
|
| 2884 |
-
|
| 2885 |
}
|
| 2886 |
} else {
|
| 2887 |
use_vec_kernel = true;
|
| 2888 |
|
| 2889 |
switch (ne00) {
|
| 2890 |
-
case 128:
|
| 2891 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2892 |
default:
|
| 2893 |
{
|
| 2894 |
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
@@ -2942,10 +3145,16 @@ static void ggml_metal_encode_node(
|
|
| 2942 |
GGML_ASSERT(nqptg % 8 == 0);
|
| 2943 |
GGML_ASSERT(ncpsg % 32 == 0);
|
| 2944 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2945 |
int64_t nsgmax = 2;
|
| 2946 |
|
| 2947 |
while (true) {
|
| 2948 |
-
const size_t smem =
|
| 2949 |
if (smem > device.maxThreadgroupMemoryLength) {
|
| 2950 |
break;
|
| 2951 |
}
|
|
@@ -2956,16 +3165,15 @@ static void ggml_metal_encode_node(
|
|
| 2956 |
// simdgroups per threadgroup (a.k.a. warps)
|
| 2957 |
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
| 2958 |
|
| 2959 |
-
const size_t smem =
|
| 2960 |
|
| 2961 |
-
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
| 2962 |
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
| 2963 |
-
|
| 2964 |
-
|
| 2965 |
-
|
| 2966 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
| 2967 |
} else {
|
| 2968 |
-
//
|
| 2969 |
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
| 2970 |
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
| 2971 |
|
|
@@ -2973,8 +3181,28 @@ static void ggml_metal_encode_node(
|
|
| 2973 |
GGML_ASSERT(nqptg % 1 == 0);
|
| 2974 |
GGML_ASSERT(ncpsg % 32 == 0);
|
| 2975 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2976 |
// simdgroups per threadgroup (a.k.a. warps)
|
| 2977 |
-
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
| 2978 |
|
| 2979 |
int64_t nsg = 1;
|
| 2980 |
while (nsg <= nsgt) {
|
|
@@ -2982,12 +3210,12 @@ static void ggml_metal_encode_node(
|
|
| 2982 |
}
|
| 2983 |
nsg /= 2;
|
| 2984 |
|
| 2985 |
-
const size_t smem = (
|
| 2986 |
|
| 2987 |
-
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
| 2988 |
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
| 2989 |
-
[encoder setThreadgroupMemoryLength:
|
| 2990 |
-
|
| 2991 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
| 2992 |
}
|
| 2993 |
} break;
|
|
|
|
| 255 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
| 256 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
| 257 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
| 258 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
| 259 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
| 260 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
| 261 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
| 262 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
|
| 263 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
|
| 264 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
| 265 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
| 266 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
| 267 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
| 268 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
|
| 269 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
|
| 270 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
| 271 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
| 272 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
| 273 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
| 274 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
|
| 275 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
|
| 276 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
| 277 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
| 278 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
| 279 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
| 280 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
|
| 281 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
|
| 282 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
| 283 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
| 284 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
| 285 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
| 286 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
|
| 287 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
|
| 288 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
| 289 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
| 290 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
| 291 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
|
| 292 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
|
| 293 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
|
| 294 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
|
| 295 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
| 296 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
|
| 297 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
|
| 298 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
| 299 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
| 300 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
| 301 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
| 302 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
| 303 |
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
|
|
|
| 750 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
|
| 751 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
|
| 752 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
|
| 753 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
|
| 754 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, support_simdgroup_mm);
|
| 755 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, support_simdgroup_mm);
|
| 756 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, support_simdgroup_mm);
|
| 757 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, support_simdgroup_mm);
|
| 758 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, support_simdgroup_mm);
|
| 759 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, support_simdgroup_mm);
|
| 760 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, support_simdgroup_mm);
|
| 761 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, support_simdgroup_mm);
|
| 762 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, support_simdgroup_mm);
|
| 763 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, support_simdgroup_mm);
|
| 764 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, support_simdgroup_mm);
|
| 765 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, support_simdgroup_mm);
|
| 766 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, support_simdgroup_mm);
|
| 767 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, support_simdgroup_mm);
|
| 768 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, support_simdgroup_mm);
|
| 769 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, support_simdgroup_mm);
|
| 770 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, support_simdgroup_mm);
|
| 771 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, support_simdgroup_mm);
|
| 772 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, support_simdgroup_mm);
|
| 773 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, support_simdgroup_mm);
|
| 774 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, support_simdgroup_mm);
|
| 775 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, support_simdgroup_mm);
|
| 776 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, support_simdgroup_mm);
|
| 777 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, support_simdgroup_mm);
|
| 778 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, support_simdgroup_mm);
|
| 779 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, support_simdgroup_mm);
|
| 780 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, support_simdgroup_mm);
|
| 781 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, support_simdgroup_mm);
|
| 782 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_mm);
|
| 783 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, support_simdgroup_mm);
|
| 784 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
|
| 785 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
|
| 786 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
|
| 787 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction);
|
| 788 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction);
|
| 789 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction);
|
| 790 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
|
| 791 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction);
|
| 792 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction);
|
| 793 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
|
| 794 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
|
| 795 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
|
| 796 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 797 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
| 798 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
|
|
|
| 949 |
case GGML_OP_LEAKY_RELU:
|
| 950 |
return true;
|
| 951 |
case GGML_OP_FLASH_ATTN_EXT:
|
| 952 |
+
if (op->src[1]->type != op->src[2]->type) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 953 |
return false;
|
| 954 |
}
|
| 955 |
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
|
|
|
| 2896 |
GGML_ASSERT(ne11 % 32 == 0);
|
| 2897 |
|
| 2898 |
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 2899 |
+
GGML_ASSERT(src1->type == src2->type);
|
| 2900 |
|
| 2901 |
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
| 2902 |
|
|
|
|
| 2944 |
bool use_vec_kernel = false;
|
| 2945 |
|
| 2946 |
if (ne01 >= 4 || (ne00%128 != 0)) {
|
| 2947 |
+
switch (src1->type) {
|
| 2948 |
+
case GGML_TYPE_F16:
|
| 2949 |
+
{
|
| 2950 |
+
switch (ne00) {
|
| 2951 |
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
| 2952 |
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
| 2953 |
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
| 2954 |
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
| 2955 |
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
| 2956 |
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
| 2957 |
+
default:
|
| 2958 |
+
{
|
| 2959 |
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
| 2960 |
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
| 2961 |
+
GGML_ABORT("add template specialization for this size");
|
| 2962 |
+
}
|
| 2963 |
+
}
|
| 2964 |
+
} break;
|
| 2965 |
+
case GGML_TYPE_Q4_0:
|
| 2966 |
+
{
|
| 2967 |
+
switch (ne00) {
|
| 2968 |
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
| 2969 |
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
|
| 2970 |
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
|
| 2971 |
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
|
| 2972 |
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
|
| 2973 |
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
|
| 2974 |
+
default:
|
| 2975 |
+
{
|
| 2976 |
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
| 2977 |
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
| 2978 |
+
GGML_ABORT("add template specialization for this size");
|
| 2979 |
+
}
|
| 2980 |
+
}
|
| 2981 |
+
} break;
|
| 2982 |
+
case GGML_TYPE_Q4_1:
|
| 2983 |
+
{
|
| 2984 |
+
switch (ne00) {
|
| 2985 |
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
| 2986 |
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
|
| 2987 |
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
|
| 2988 |
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
|
| 2989 |
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
|
| 2990 |
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
|
| 2991 |
+
default:
|
| 2992 |
+
{
|
| 2993 |
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
| 2994 |
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
| 2995 |
+
GGML_ABORT("add template specialization for this size");
|
| 2996 |
+
}
|
| 2997 |
+
}
|
| 2998 |
+
} break;
|
| 2999 |
+
case GGML_TYPE_Q5_0:
|
| 3000 |
+
{
|
| 3001 |
+
switch (ne00) {
|
| 3002 |
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
| 3003 |
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
|
| 3004 |
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
|
| 3005 |
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
|
| 3006 |
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
|
| 3007 |
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
|
| 3008 |
+
default:
|
| 3009 |
+
{
|
| 3010 |
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
| 3011 |
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
| 3012 |
+
GGML_ABORT("add template specialization for this size");
|
| 3013 |
+
}
|
| 3014 |
+
}
|
| 3015 |
+
} break;
|
| 3016 |
+
case GGML_TYPE_Q5_1:
|
| 3017 |
+
{
|
| 3018 |
+
switch (ne00) {
|
| 3019 |
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
| 3020 |
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
|
| 3021 |
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
|
| 3022 |
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
|
| 3023 |
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
|
| 3024 |
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
|
| 3025 |
+
default:
|
| 3026 |
+
{
|
| 3027 |
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
| 3028 |
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
| 3029 |
+
GGML_ABORT("add template specialization for this size");
|
| 3030 |
+
}
|
| 3031 |
+
}
|
| 3032 |
+
} break;
|
| 3033 |
+
case GGML_TYPE_Q8_0:
|
| 3034 |
+
{
|
| 3035 |
+
switch (ne00) {
|
| 3036 |
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
| 3037 |
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
|
| 3038 |
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
|
| 3039 |
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
|
| 3040 |
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
|
| 3041 |
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
|
| 3042 |
+
default:
|
| 3043 |
+
{
|
| 3044 |
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
| 3045 |
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
| 3046 |
+
GGML_ABORT("add template specialization for this size");
|
| 3047 |
+
}
|
| 3048 |
+
}
|
| 3049 |
+
} break;
|
| 3050 |
default:
|
| 3051 |
+
{
|
| 3052 |
+
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
| 3053 |
+
GGML_LOG_ERROR("add template specialization for this type\n");
|
| 3054 |
+
GGML_ABORT("add template specialization for this type");
|
| 3055 |
+
}
|
| 3056 |
}
|
| 3057 |
} else {
|
| 3058 |
use_vec_kernel = true;
|
| 3059 |
|
| 3060 |
switch (ne00) {
|
| 3061 |
+
case 128:
|
| 3062 |
+
{
|
| 3063 |
+
switch (src1->type) {
|
| 3064 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
| 3065 |
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
|
| 3066 |
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
|
| 3067 |
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
|
| 3068 |
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
|
| 3069 |
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
|
| 3070 |
+
default:
|
| 3071 |
+
{
|
| 3072 |
+
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
| 3073 |
+
GGML_LOG_ERROR("add template specialization for this type\n");
|
| 3074 |
+
GGML_ABORT("add template specialization for this type");
|
| 3075 |
+
}
|
| 3076 |
+
}
|
| 3077 |
+
} break;
|
| 3078 |
+
case 256:
|
| 3079 |
+
{
|
| 3080 |
+
switch (src1->type) {
|
| 3081 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
| 3082 |
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
|
| 3083 |
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
|
| 3084 |
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
|
| 3085 |
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
|
| 3086 |
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
|
| 3087 |
+
default:
|
| 3088 |
+
{
|
| 3089 |
+
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
| 3090 |
+
GGML_LOG_ERROR("add template specialization for this type\n");
|
| 3091 |
+
GGML_ABORT("add template specialization for this type");
|
| 3092 |
+
}
|
| 3093 |
+
}
|
| 3094 |
+
} break;
|
| 3095 |
default:
|
| 3096 |
{
|
| 3097 |
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
| 3145 |
GGML_ASSERT(nqptg % 8 == 0);
|
| 3146 |
GGML_ASSERT(ncpsg % 32 == 0);
|
| 3147 |
|
| 3148 |
+
// 16*32*(nsg)
|
| 3149 |
+
// the shared memory needed for the simdgroups to load the KV cache
|
| 3150 |
+
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
| 3151 |
+
//
|
| 3152 |
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
| 3153 |
+
|
| 3154 |
int64_t nsgmax = 2;
|
| 3155 |
|
| 3156 |
while (true) {
|
| 3157 |
+
const size_t smem = FATTN_SMEM(nsgmax);
|
| 3158 |
if (smem > device.maxThreadgroupMemoryLength) {
|
| 3159 |
break;
|
| 3160 |
}
|
|
|
|
| 3165 |
// simdgroups per threadgroup (a.k.a. warps)
|
| 3166 |
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
| 3167 |
|
| 3168 |
+
const size_t smem = FATTN_SMEM(nsg);
|
| 3169 |
|
| 3170 |
+
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
| 3171 |
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
| 3172 |
+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
| 3173 |
+
#undef FATTN_SMEM
|
|
|
|
| 3174 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
| 3175 |
} else {
|
| 3176 |
+
// half4x4 kernel
|
| 3177 |
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
| 3178 |
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
| 3179 |
|
|
|
|
| 3181 |
GGML_ASSERT(nqptg % 1 == 0);
|
| 3182 |
GGML_ASSERT(ncpsg % 32 == 0);
|
| 3183 |
|
| 3184 |
+
// ne00 + 2*ncpsg*(nsg)
|
| 3185 |
+
// for each query, we load it as f16 in shared memory (ne00)
|
| 3186 |
+
// and store the attention scores (nqptg x ncpsg) as f32
|
| 3187 |
+
//
|
| 3188 |
+
// 2*ne00*(nsg)
|
| 3189 |
+
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
| 3190 |
+
//
|
| 3191 |
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
|
| 3192 |
+
|
| 3193 |
+
int64_t nsgmax = 2;
|
| 3194 |
+
|
| 3195 |
+
while (true) {
|
| 3196 |
+
const size_t smem = FATTN_SMEM(nsgmax);
|
| 3197 |
+
if (smem > device.maxThreadgroupMemoryLength) {
|
| 3198 |
+
break;
|
| 3199 |
+
}
|
| 3200 |
+
nsgmax *= 2;
|
| 3201 |
+
}
|
| 3202 |
+
nsgmax /= 2;
|
| 3203 |
+
|
| 3204 |
// simdgroups per threadgroup (a.k.a. warps)
|
| 3205 |
+
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
| 3206 |
|
| 3207 |
int64_t nsg = 1;
|
| 3208 |
while (nsg <= nsgt) {
|
|
|
|
| 3210 |
}
|
| 3211 |
nsg /= 2;
|
| 3212 |
|
| 3213 |
+
const size_t smem = FATTN_SMEM(nsg);
|
| 3214 |
|
| 3215 |
+
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
| 3216 |
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
| 3217 |
+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
| 3218 |
+
#undef FATTN_SMEM
|
| 3219 |
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
| 3220 |
}
|
| 3221 |
} break;
|
ggml/src/ggml-metal.metal
CHANGED
|
@@ -2723,46 +2723,10 @@ kernel void kernel_leaky_relu_f32(
|
|
| 2723 |
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
|
| 2724 |
}
|
| 2725 |
|
| 2726 |
-
typedef void (flash_attn_ext_f16_t)(
|
| 2727 |
-
device const char * q,
|
| 2728 |
-
device const char * k,
|
| 2729 |
-
device const char * v,
|
| 2730 |
-
device const char * mask,
|
| 2731 |
-
device float * dst,
|
| 2732 |
-
constant int64_t & ne01,
|
| 2733 |
-
constant int64_t & ne02,
|
| 2734 |
-
constant int64_t & ne03,
|
| 2735 |
-
constant uint64_t & nb01,
|
| 2736 |
-
constant uint64_t & nb02,
|
| 2737 |
-
constant uint64_t & nb03,
|
| 2738 |
-
constant int64_t & ne11,
|
| 2739 |
-
constant int64_t & ne12,
|
| 2740 |
-
constant int64_t & ne13,
|
| 2741 |
-
constant uint64_t & nb11,
|
| 2742 |
-
constant uint64_t & nb12,
|
| 2743 |
-
constant uint64_t & nb13,
|
| 2744 |
-
constant uint64_t & nb21,
|
| 2745 |
-
constant uint64_t & nb22,
|
| 2746 |
-
constant uint64_t & nb23,
|
| 2747 |
-
constant uint64_t & nb31,
|
| 2748 |
-
constant int64_t & ne1,
|
| 2749 |
-
constant int64_t & ne2,
|
| 2750 |
-
constant float & scale,
|
| 2751 |
-
constant float & max_bias,
|
| 2752 |
-
constant float & m0,
|
| 2753 |
-
constant float & m1,
|
| 2754 |
-
constant uint32_t & n_head_log2,
|
| 2755 |
-
constant float & logit_softcap,
|
| 2756 |
-
threadgroup half * shared,
|
| 2757 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2758 |
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 2759 |
-
uint3 ntg[[threads_per_threadgroup]],
|
| 2760 |
-
ushort tiisg[[thread_index_in_simdgroup]],
|
| 2761 |
-
ushort sgitg[[simdgroup_index_in_threadgroup]]);
|
| 2762 |
-
|
| 2763 |
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
| 2764 |
-
|
| 2765 |
-
|
|
|
|
| 2766 |
device const char * q,
|
| 2767 |
device const char * k,
|
| 2768 |
device const char * v,
|
|
@@ -2800,15 +2764,15 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2800 |
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 2801 |
const short nsg = ntg.y; // number of simdgroups
|
| 2802 |
|
| 2803 |
-
const
|
| 2804 |
-
const
|
| 2805 |
-
const
|
| 2806 |
|
| 2807 |
-
const short D4
|
| 2808 |
-
const short D8
|
| 2809 |
-
|
| 2810 |
-
const short NW
|
| 2811 |
-
const short SH
|
| 2812 |
|
| 2813 |
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
| 2814 |
const short TF = T/2; // shared memory size per query in (float)
|
|
@@ -2818,6 +2782,9 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2818 |
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
| 2819 |
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
| 2820 |
|
|
|
|
|
|
|
|
|
|
| 2821 |
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
| 2822 |
simdgroup_half8x8 lo[D8];
|
| 2823 |
|
|
@@ -2849,25 +2816,28 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2849 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 2850 |
|
| 2851 |
{
|
| 2852 |
-
float S[Q] = { [0 ... Q-1] = 0.
|
| 2853 |
float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
|
| 2854 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2855 |
// assume K and V are same shape
|
| 2856 |
const short ne22 = ne12;
|
| 2857 |
const short ne23 = ne13;
|
| 2858 |
|
| 2859 |
-
// broadcast
|
| 2860 |
const short rk2 = ne02/ne12;
|
| 2861 |
const short rk3 = ne03/ne13;
|
| 2862 |
|
| 2863 |
-
const short rv2 = ne02/ne22;
|
| 2864 |
-
const short rv3 = ne03/ne23;
|
| 2865 |
-
|
| 2866 |
-
// k indices
|
| 2867 |
const short ik2 = iq2/rk2;
|
| 2868 |
const short ik3 = iq3/rk3;
|
| 2869 |
|
| 2870 |
-
// v
|
|
|
|
|
|
|
|
|
|
| 2871 |
const short iv2 = iq2/rv2;
|
| 2872 |
const short iv3 = iq3/rv3;
|
| 2873 |
|
|
@@ -2906,13 +2876,59 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2906 |
for (short cc = 0; cc < C/8; ++cc) {
|
| 2907 |
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
|
| 2908 |
|
| 2909 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2910 |
|
| 2911 |
-
|
| 2912 |
-
simdgroup_half8x8 mk;
|
| 2913 |
-
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
| 2914 |
|
| 2915 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2916 |
}
|
| 2917 |
|
| 2918 |
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
|
@@ -2977,16 +2993,61 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2977 |
// O = O + (Q*K^T)*V
|
| 2978 |
{
|
| 2979 |
for (short cc = 0; cc < C/8; ++cc) {
|
| 2980 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2981 |
|
| 2982 |
-
|
| 2983 |
-
|
| 2984 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2985 |
|
| 2986 |
-
|
| 2987 |
-
simdgroup_load(mv, ss + 8*cc, TF, 0, false);
|
| 2988 |
|
| 2989 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2990 |
}
|
| 2991 |
}
|
| 2992 |
}
|
|
@@ -3003,7 +3064,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 3003 |
|
| 3004 |
// reduce the warps sequentially
|
| 3005 |
for (short sg = 1; sg < nsg; ++sg) {
|
| 3006 |
-
float S = { 0.
|
| 3007 |
float M = { -FLT_MAX/2 };
|
| 3008 |
|
| 3009 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
@@ -3082,15 +3143,54 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 3082 |
}
|
| 3083 |
}
|
| 3084 |
|
| 3085 |
-
|
| 3086 |
-
|
| 3087 |
-
template [[host_name("
|
| 3088 |
-
template [[host_name("
|
| 3089 |
-
template [[host_name("
|
| 3090 |
-
|
| 3091 |
-
|
| 3092 |
-
template
|
| 3093 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3094 |
device const char * q,
|
| 3095 |
device const char * k,
|
| 3096 |
device const char * v,
|
|
@@ -3128,36 +3228,27 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 3128 |
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3129 |
const short nsg = ntg.y; // number of simdgroups
|
| 3130 |
|
| 3131 |
-
const
|
| 3132 |
-
const
|
| 3133 |
-
const
|
| 3134 |
|
| 3135 |
-
const short D4
|
| 3136 |
-
const short
|
| 3137 |
-
const short
|
|
|
|
|
|
|
| 3138 |
|
| 3139 |
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
| 3140 |
|
| 3141 |
-
|
| 3142 |
-
|
| 3143 |
-
//
|
| 3144 |
-
|
| 3145 |
-
|
| 3146 |
-
|
| 3147 |
-
const float base = h < n_head_log2 ? m0 : m1;
|
| 3148 |
-
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 3149 |
-
|
| 3150 |
-
slope = pow(base, exp);
|
| 3151 |
-
}
|
| 3152 |
-
|
| 3153 |
-
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
| 3154 |
-
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
| 3155 |
-
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
| 3156 |
-
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
|
| 3157 |
-
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
|
| 3158 |
|
| 3159 |
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
| 3160 |
-
|
| 3161 |
|
| 3162 |
// load heads from Q to shared memory
|
| 3163 |
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
|
@@ -3171,8 +3262,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 3171 |
}
|
| 3172 |
|
| 3173 |
// zero out lo
|
| 3174 |
-
for (short i =
|
| 3175 |
-
lo[i
|
| 3176 |
}
|
| 3177 |
|
| 3178 |
// zero out shared memory SH
|
|
@@ -3183,38 +3274,52 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 3183 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3184 |
|
| 3185 |
{
|
| 3186 |
-
float S =
|
| 3187 |
-
float M =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3188 |
|
| 3189 |
// assume K and V are same shape
|
| 3190 |
const short ne22 = ne12;
|
| 3191 |
const short ne23 = ne13;
|
| 3192 |
|
| 3193 |
-
// broadcast
|
| 3194 |
const short rk2 = ne02/ne12;
|
| 3195 |
const short rk3 = ne03/ne13;
|
| 3196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3197 |
const short rv2 = ne02/ne22;
|
| 3198 |
const short rv3 = ne03/ne23;
|
| 3199 |
|
| 3200 |
-
|
| 3201 |
-
const short
|
| 3202 |
-
const short ik3 = iq3 / rk3;
|
| 3203 |
-
|
| 3204 |
-
// v indices
|
| 3205 |
-
const short iv2 = iq2 / rv2;
|
| 3206 |
-
const short iv3 = iq3 / rv3;
|
| 3207 |
|
| 3208 |
// load the queries from shared memory into local memory
|
| 3209 |
-
|
| 3210 |
|
| 3211 |
-
for (short ii = 0; ii <
|
| 3212 |
-
|
| 3213 |
-
mq[ii/NW] = (float4) sq4[i];
|
| 3214 |
}
|
| 3215 |
|
| 3216 |
// pointer to the mask
|
| 3217 |
-
device const
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3218 |
|
| 3219 |
// loop over the KV cache
|
| 3220 |
// each simdgroup handles blocks of Q rows and C columns
|
|
@@ -3226,47 +3331,54 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 3226 |
|
| 3227 |
// Q*K^T
|
| 3228 |
{
|
| 3229 |
-
|
| 3230 |
for (short cc = 0; cc < C/4; ++cc) {
|
| 3231 |
-
|
| 3232 |
|
| 3233 |
-
device const
|
| 3234 |
|
| 3235 |
#pragma unroll
|
| 3236 |
-
for (short ii = 0; ii <
|
| 3237 |
-
const short i = ii +
|
| 3238 |
|
| 3239 |
float4x4 mk;
|
| 3240 |
-
|
| 3241 |
-
mk[1] = (float4) pk4[i + 1*(nb11/8)];
|
| 3242 |
-
mk[2] = (float4) pk4[i + 2*(nb11/8)];
|
| 3243 |
-
mk[3] = (float4) pk4[i + 3*(nb11/8)];
|
| 3244 |
|
| 3245 |
-
mqk +=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3246 |
}
|
| 3247 |
|
| 3248 |
-
// reduce
|
| 3249 |
-
|
| 3250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3251 |
mqk += simd_shuffle_down(mqk, 4);
|
| 3252 |
mqk += simd_shuffle_down(mqk, 2);
|
| 3253 |
mqk += simd_shuffle_down(mqk, 1);
|
| 3254 |
|
| 3255 |
// mqk = mqk*scale + mask*slope
|
| 3256 |
-
if (
|
| 3257 |
mqk *= scale;
|
| 3258 |
|
| 3259 |
if (logit_softcap != 0.0f) {
|
| 3260 |
mqk = logit_softcap*precise::tanh(mqk);
|
| 3261 |
}
|
| 3262 |
|
| 3263 |
-
mqk += (mask != q) ? ((
|
| 3264 |
|
| 3265 |
-
|
| 3266 |
}
|
| 3267 |
}
|
| 3268 |
}
|
| 3269 |
|
|
|
|
|
|
|
| 3270 |
// online softmax
|
| 3271 |
{
|
| 3272 |
const short p = tiisg;
|
|
@@ -3286,29 +3398,32 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 3286 |
|
| 3287 |
// O = diag(ms)*O
|
| 3288 |
#pragma unroll
|
| 3289 |
-
for (short ii = 0; ii <
|
| 3290 |
-
lo[ii/
|
| 3291 |
}
|
| 3292 |
}
|
| 3293 |
|
|
|
|
|
|
|
| 3294 |
// O = O + (Q*K^T)*V
|
| 3295 |
{
|
| 3296 |
#pragma unroll
|
| 3297 |
for (short cc = 0; cc < C/4; ++cc) {
|
| 3298 |
-
device const
|
|
|
|
|
|
|
| 3299 |
|
| 3300 |
#pragma unroll
|
| 3301 |
-
for (short ii = 0; ii <
|
| 3302 |
-
const short i = ii +
|
|
|
|
|
|
|
|
|
|
| 3303 |
|
| 3304 |
-
lo[ii/
|
| 3305 |
-
lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
|
| 3306 |
-
lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
|
| 3307 |
-
lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
|
| 3308 |
}
|
| 3309 |
}
|
| 3310 |
}
|
| 3311 |
-
|
| 3312 |
}
|
| 3313 |
|
| 3314 |
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
|
@@ -3318,10 +3433,32 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 3318 |
}
|
| 3319 |
}
|
| 3320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3321 |
// store results to shared memory
|
| 3322 |
-
for (short
|
| 3323 |
-
|
| 3324 |
-
sr4[i] = lo[ii/NW];
|
| 3325 |
}
|
| 3326 |
|
| 3327 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
@@ -3348,30 +3485,41 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 3348 |
}
|
| 3349 |
|
| 3350 |
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
| 3351 |
-
for (short
|
| 3352 |
-
|
| 3353 |
-
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
|
| 3354 |
}
|
| 3355 |
}
|
| 3356 |
|
| 3357 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3358 |
}
|
| 3359 |
|
| 3360 |
-
device
|
| 3361 |
|
| 3362 |
// final rescale with 1/S and store to global memory
|
| 3363 |
if (sgitg == 0) {
|
| 3364 |
const float S = ss[0];
|
| 3365 |
|
| 3366 |
-
for (short
|
| 3367 |
-
|
| 3368 |
-
dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
|
| 3369 |
}
|
| 3370 |
}
|
| 3371 |
}
|
| 3372 |
|
| 3373 |
-
|
| 3374 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3375 |
|
| 3376 |
template<typename T0, typename T1>
|
| 3377 |
kernel void kernel_cpy(
|
|
|
|
| 2723 |
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
|
| 2724 |
}
|
| 2725 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2726 |
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
| 2727 |
+
// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
|
| 2728 |
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
|
| 2729 |
+
kernel void kernel_flash_attn_ext(
|
| 2730 |
device const char * q,
|
| 2731 |
device const char * k,
|
| 2732 |
device const char * v,
|
|
|
|
| 2764 |
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 2765 |
const short nsg = ntg.y; // number of simdgroups
|
| 2766 |
|
| 2767 |
+
const int iq3 = tgpig[2];
|
| 2768 |
+
const int iq2 = tgpig[1];
|
| 2769 |
+
const int iq1 = tgpig[0]*Q;
|
| 2770 |
|
| 2771 |
+
const short D4 = D/4;
|
| 2772 |
+
const short D8 = D/8;
|
| 2773 |
+
const short D16 = D/16;
|
| 2774 |
+
const short NW = N_SIMDWIDTH;
|
| 2775 |
+
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
| 2776 |
|
| 2777 |
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
| 2778 |
const short TF = T/2; // shared memory size per query in (float)
|
|
|
|
| 2782 |
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
| 2783 |
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
| 2784 |
|
| 2785 |
+
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
|
| 2786 |
+
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
|
| 2787 |
+
|
| 2788 |
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
| 2789 |
simdgroup_half8x8 lo[D8];
|
| 2790 |
|
|
|
|
| 2816 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 2817 |
|
| 2818 |
{
|
| 2819 |
+
float S[Q] = { [0 ... Q-1] = 0.0f };
|
| 2820 |
float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
|
| 2821 |
|
| 2822 |
+
// thread indices inside the simdgroup
|
| 2823 |
+
const short tx = tiisg%4;
|
| 2824 |
+
const short ty = tiisg/4;
|
| 2825 |
+
|
| 2826 |
// assume K and V are same shape
|
| 2827 |
const short ne22 = ne12;
|
| 2828 |
const short ne23 = ne13;
|
| 2829 |
|
| 2830 |
+
// broadcast k
|
| 2831 |
const short rk2 = ne02/ne12;
|
| 2832 |
const short rk3 = ne03/ne13;
|
| 2833 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2834 |
const short ik2 = iq2/rk2;
|
| 2835 |
const short ik3 = iq3/rk3;
|
| 2836 |
|
| 2837 |
+
// broadcast v
|
| 2838 |
+
const short rv2 = ne02/ne22;
|
| 2839 |
+
const short rv3 = ne03/ne23;
|
| 2840 |
+
|
| 2841 |
const short iv2 = iq2/rv2;
|
| 2842 |
const short iv3 = iq3/rv3;
|
| 2843 |
|
|
|
|
| 2876 |
for (short cc = 0; cc < C/8; ++cc) {
|
| 2877 |
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
|
| 2878 |
|
| 2879 |
+
// this is compile-time check, so it does not have runtime overhead
|
| 2880 |
+
if (is_same<block_q, half4x4>::value) {
|
| 2881 |
+
// we can read directly from global memory
|
| 2882 |
+
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
| 2883 |
+
|
| 2884 |
+
for (short i = 0; i < D8; ++i) {
|
| 2885 |
+
simdgroup_half8x8 mk;
|
| 2886 |
+
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
| 2887 |
+
|
| 2888 |
+
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
| 2889 |
+
}
|
| 2890 |
+
} else {
|
| 2891 |
+
for (short ii = 0; ii < D16; ii += 4) {
|
| 2892 |
+
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
|
| 2893 |
+
|
| 2894 |
+
if (D16%4 == 0) {
|
| 2895 |
+
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
| 2896 |
+
half4x4 tmp;
|
| 2897 |
+
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
| 2898 |
+
skv4[4*ty + tx] = tmp;
|
| 2899 |
|
| 2900 |
+
simdgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
| 2901 |
|
| 2902 |
+
#pragma unroll
|
| 2903 |
+
for (short k = 0; k < 4; ++k) {
|
| 2904 |
+
simdgroup_half8x8 mk;
|
| 2905 |
+
|
| 2906 |
+
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
|
| 2907 |
+
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
| 2908 |
+
|
| 2909 |
+
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
|
| 2910 |
+
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
| 2911 |
+
}
|
| 2912 |
+
} else {
|
| 2913 |
+
if (ii + tx < D16) {
|
| 2914 |
+
half4x4 tmp;
|
| 2915 |
+
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
| 2916 |
+
skv4[4*ty + tx] = tmp;
|
| 2917 |
+
}
|
| 2918 |
+
|
| 2919 |
+
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 2920 |
+
|
| 2921 |
+
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
| 2922 |
+
simdgroup_half8x8 mk;
|
| 2923 |
+
|
| 2924 |
+
simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
|
| 2925 |
+
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
| 2926 |
+
|
| 2927 |
+
simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
|
| 2928 |
+
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
| 2929 |
+
}
|
| 2930 |
+
}
|
| 2931 |
+
}
|
| 2932 |
}
|
| 2933 |
|
| 2934 |
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
|
|
|
| 2993 |
// O = O + (Q*K^T)*V
|
| 2994 |
{
|
| 2995 |
for (short cc = 0; cc < C/8; ++cc) {
|
| 2996 |
+
simdgroup_float8x8 ms;
|
| 2997 |
+
simdgroup_load(ms, ss + 8*cc, TF, 0, false);
|
| 2998 |
+
|
| 2999 |
+
if (is_same<block_q, half4x4>::value) {
|
| 3000 |
+
// we can read directly from global memory
|
| 3001 |
+
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
| 3002 |
+
#pragma unroll
|
| 3003 |
+
for (short i = 0; i < D8; ++i) {
|
| 3004 |
+
simdgroup_half8x8 mv;
|
| 3005 |
+
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
|
| 3006 |
|
| 3007 |
+
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
| 3008 |
+
}
|
| 3009 |
+
} else {
|
| 3010 |
+
for (short ii = 0; ii < D16; ii += 4) {
|
| 3011 |
+
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
|
| 3012 |
+
|
| 3013 |
+
if (D16%4 == 0) {
|
| 3014 |
+
// no need for bound checks
|
| 3015 |
+
half4x4 tmp;
|
| 3016 |
+
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
| 3017 |
+
skv4[4*ty + tx] = tmp;
|
| 3018 |
|
| 3019 |
+
simdgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
| 3020 |
|
| 3021 |
+
#pragma unroll
|
| 3022 |
+
for (short k = 0; k < 4; ++k) {
|
| 3023 |
+
simdgroup_half8x8 mv;
|
| 3024 |
+
|
| 3025 |
+
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
|
| 3026 |
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
| 3027 |
+
|
| 3028 |
+
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
|
| 3029 |
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
| 3030 |
+
}
|
| 3031 |
+
} else {
|
| 3032 |
+
if (ii + tx < D16) {
|
| 3033 |
+
half4x4 tmp;
|
| 3034 |
+
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
| 3035 |
+
skv4[4*ty + tx] = tmp;
|
| 3036 |
+
}
|
| 3037 |
+
|
| 3038 |
+
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3039 |
+
|
| 3040 |
+
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
| 3041 |
+
simdgroup_half8x8 mv;
|
| 3042 |
+
|
| 3043 |
+
simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
|
| 3044 |
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
| 3045 |
+
|
| 3046 |
+
simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
|
| 3047 |
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
| 3048 |
+
}
|
| 3049 |
+
}
|
| 3050 |
+
}
|
| 3051 |
}
|
| 3052 |
}
|
| 3053 |
}
|
|
|
|
| 3064 |
|
| 3065 |
// reduce the warps sequentially
|
| 3066 |
for (short sg = 1; sg < nsg; ++sg) {
|
| 3067 |
+
float S = { 0.0f };
|
| 3068 |
float M = { -FLT_MAX/2 };
|
| 3069 |
|
| 3070 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
| 3143 |
}
|
| 3144 |
}
|
| 3145 |
|
| 3146 |
+
typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
|
| 3147 |
+
|
| 3148 |
+
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
|
| 3149 |
+
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
|
| 3150 |
+
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
|
| 3151 |
+
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
|
| 3152 |
+
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
|
| 3153 |
+
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
|
| 3154 |
+
|
| 3155 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
|
| 3156 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
|
| 3157 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
|
| 3158 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
|
| 3159 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
|
| 3160 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
|
| 3161 |
+
|
| 3162 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
|
| 3163 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
|
| 3164 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
|
| 3165 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
|
| 3166 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
|
| 3167 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
|
| 3168 |
+
|
| 3169 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
|
| 3170 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
|
| 3171 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
|
| 3172 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
|
| 3173 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
|
| 3174 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
|
| 3175 |
+
|
| 3176 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
|
| 3177 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
|
| 3178 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
|
| 3179 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
|
| 3180 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
|
| 3181 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
|
| 3182 |
+
|
| 3183 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
|
| 3184 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
|
| 3185 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
|
| 3186 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
|
| 3187 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
|
| 3188 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
|
| 3189 |
+
|
| 3190 |
+
// NOTE: can use half instead of float precision for some extra perf
|
| 3191 |
+
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
|
| 3192 |
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
|
| 3193 |
+
kernel void kernel_flash_attn_ext_vec(
|
| 3194 |
device const char * q,
|
| 3195 |
device const char * k,
|
| 3196 |
device const char * v,
|
|
|
|
| 3228 |
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3229 |
const short nsg = ntg.y; // number of simdgroups
|
| 3230 |
|
| 3231 |
+
const int iq3 = tgpig[2];
|
| 3232 |
+
const int iq2 = tgpig[1];
|
| 3233 |
+
const int iq1 = tgpig[0];
|
| 3234 |
|
| 3235 |
+
const short D4 = D/4;
|
| 3236 |
+
const short D16 = D/16;
|
| 3237 |
+
const short NW = N_SIMDWIDTH;
|
| 3238 |
+
const short NW4 = NW/4;
|
| 3239 |
+
const short SH = C; // shared memory per simdgroup in (half)
|
| 3240 |
|
| 3241 |
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
| 3242 |
|
| 3243 |
+
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
| 3244 |
+
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
| 3245 |
+
threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
|
| 3246 |
+
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
|
| 3247 |
+
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
|
| 3248 |
+
threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3249 |
|
| 3250 |
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
| 3251 |
+
float4x4 lo[D16/NW4];
|
| 3252 |
|
| 3253 |
// load heads from Q to shared memory
|
| 3254 |
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
|
|
|
| 3262 |
}
|
| 3263 |
|
| 3264 |
// zero out lo
|
| 3265 |
+
for (short i = 0; i < D16/NW4; i += NW4) {
|
| 3266 |
+
lo[i] = float4x4(0.0f);
|
| 3267 |
}
|
| 3268 |
|
| 3269 |
// zero out shared memory SH
|
|
|
|
| 3274 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3275 |
|
| 3276 |
{
|
| 3277 |
+
float S = 0.0f;
|
| 3278 |
+
float M = -FLT_MAX/2;
|
| 3279 |
+
|
| 3280 |
+
// thread indices inside the simdgroup
|
| 3281 |
+
const short tx = tiisg%8;
|
| 3282 |
+
const short ty = tiisg/8;
|
| 3283 |
|
| 3284 |
// assume K and V are same shape
|
| 3285 |
const short ne22 = ne12;
|
| 3286 |
const short ne23 = ne13;
|
| 3287 |
|
| 3288 |
+
// broadcast k
|
| 3289 |
const short rk2 = ne02/ne12;
|
| 3290 |
const short rk3 = ne03/ne13;
|
| 3291 |
|
| 3292 |
+
const short ik2 = iq2/rk2;
|
| 3293 |
+
const short ik3 = iq3/rk3;
|
| 3294 |
+
|
| 3295 |
+
// broadcast v
|
| 3296 |
const short rv2 = ne02/ne22;
|
| 3297 |
const short rv3 = ne03/ne23;
|
| 3298 |
|
| 3299 |
+
const short iv2 = iq2/rv2;
|
| 3300 |
+
const short iv3 = iq3/rv3;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3301 |
|
| 3302 |
// load the queries from shared memory into local memory
|
| 3303 |
+
float4x4 mq[D16/NW4];
|
| 3304 |
|
| 3305 |
+
for (short ii = 0; ii < D16; ii += NW4) {
|
| 3306 |
+
mq[ii/NW4] = (float4x4) sq44[ii + tx];
|
|
|
|
| 3307 |
}
|
| 3308 |
|
| 3309 |
// pointer to the mask
|
| 3310 |
+
device const half * mp = (device const half *) (mask + iq1*nb31);
|
| 3311 |
+
|
| 3312 |
+
float slope = 1.0f;
|
| 3313 |
+
|
| 3314 |
+
// ALiBi
|
| 3315 |
+
if (max_bias > 0.0f) {
|
| 3316 |
+
const uint32_t h = iq2;
|
| 3317 |
+
|
| 3318 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 3319 |
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 3320 |
+
|
| 3321 |
+
slope = pow(base, exp);
|
| 3322 |
+
}
|
| 3323 |
|
| 3324 |
// loop over the KV cache
|
| 3325 |
// each simdgroup handles blocks of Q rows and C columns
|
|
|
|
| 3331 |
|
| 3332 |
// Q*K^T
|
| 3333 |
{
|
| 3334 |
+
// each simdgroup processes 1 query and 4 keys
|
| 3335 |
for (short cc = 0; cc < C/4; ++cc) {
|
| 3336 |
+
float mqk = 0.0;
|
| 3337 |
|
| 3338 |
+
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
|
| 3339 |
|
| 3340 |
#pragma unroll
|
| 3341 |
+
for (short ii = 0; ii < D16; ii += NW4) {
|
| 3342 |
+
const short i = ii + tx;
|
| 3343 |
|
| 3344 |
float4x4 mk;
|
| 3345 |
+
dequantize_func(pk + i/nl, i%nl, mk);
|
|
|
|
|
|
|
|
|
|
| 3346 |
|
| 3347 |
+
mqk +=
|
| 3348 |
+
dot(mq[ii/NW4][0], mk[0]) +
|
| 3349 |
+
dot(mq[ii/NW4][1], mk[1]) +
|
| 3350 |
+
dot(mq[ii/NW4][2], mk[2]) +
|
| 3351 |
+
dot(mq[ii/NW4][3], mk[3]);
|
| 3352 |
}
|
| 3353 |
|
| 3354 |
+
// simdgroup reduce
|
| 3355 |
+
// [ 0 .. 7] -> [ 0]
|
| 3356 |
+
// [ 8 .. 15] -> [ 8]
|
| 3357 |
+
// [16 .. 23] -> [16]
|
| 3358 |
+
// [24 .. 31] -> [24]
|
| 3359 |
+
//mqk += simd_shuffle_down(mqk, 16);
|
| 3360 |
+
//mqk += simd_shuffle_down(mqk, 8);
|
| 3361 |
mqk += simd_shuffle_down(mqk, 4);
|
| 3362 |
mqk += simd_shuffle_down(mqk, 2);
|
| 3363 |
mqk += simd_shuffle_down(mqk, 1);
|
| 3364 |
|
| 3365 |
// mqk = mqk*scale + mask*slope
|
| 3366 |
+
if (tx == 0) {
|
| 3367 |
mqk *= scale;
|
| 3368 |
|
| 3369 |
if (logit_softcap != 0.0f) {
|
| 3370 |
mqk = logit_softcap*precise::tanh(mqk);
|
| 3371 |
}
|
| 3372 |
|
| 3373 |
+
mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
|
| 3374 |
|
| 3375 |
+
ss[4*cc + ty] = mqk;
|
| 3376 |
}
|
| 3377 |
}
|
| 3378 |
}
|
| 3379 |
|
| 3380 |
+
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3381 |
+
|
| 3382 |
// online softmax
|
| 3383 |
{
|
| 3384 |
const short p = tiisg;
|
|
|
|
| 3398 |
|
| 3399 |
// O = diag(ms)*O
|
| 3400 |
#pragma unroll
|
| 3401 |
+
for (short ii = 0; ii < D16; ii += NW4) {
|
| 3402 |
+
lo[ii/NW4] *= ms;
|
| 3403 |
}
|
| 3404 |
}
|
| 3405 |
|
| 3406 |
+
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3407 |
+
|
| 3408 |
// O = O + (Q*K^T)*V
|
| 3409 |
{
|
| 3410 |
#pragma unroll
|
| 3411 |
for (short cc = 0; cc < C/4; ++cc) {
|
| 3412 |
+
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
|
| 3413 |
+
|
| 3414 |
+
const float4x4 lss(ss[4*cc + ty]);
|
| 3415 |
|
| 3416 |
#pragma unroll
|
| 3417 |
+
for (short ii = 0; ii < D16; ii += NW4) {
|
| 3418 |
+
const short i = ii + tx;
|
| 3419 |
+
|
| 3420 |
+
float4x4 mv;
|
| 3421 |
+
dequantize_func(pv4 + i/nl, i%nl, mv);
|
| 3422 |
|
| 3423 |
+
lo[ii/NW4] += mv*lss;
|
|
|
|
|
|
|
|
|
|
| 3424 |
}
|
| 3425 |
}
|
| 3426 |
}
|
|
|
|
| 3427 |
}
|
| 3428 |
|
| 3429 |
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
|
|
|
| 3433 |
}
|
| 3434 |
}
|
| 3435 |
|
| 3436 |
+
// simdgroup reduce
|
| 3437 |
+
// [ 0, 8, 16, 24] -> [ 0]
|
| 3438 |
+
// [ 1, 9, 17, 25] -> [ 1]
|
| 3439 |
+
// [ 2, 10, 18, 26] -> [ 2]
|
| 3440 |
+
// [ 3, 11, 19, 27] -> [ 3]
|
| 3441 |
+
// [ 4, 12, 20, 28] -> [ 4]
|
| 3442 |
+
// [ 5, 13, 21, 29] -> [ 5]
|
| 3443 |
+
// [ 6, 14, 22, 30] -> [ 6]
|
| 3444 |
+
// [ 7, 15, 23, 31] -> [ 7]
|
| 3445 |
+
for (short ii = 0; ii < D16; ii += NW4) {
|
| 3446 |
+
lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16);
|
| 3447 |
+
lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 8);
|
| 3448 |
+
|
| 3449 |
+
lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16);
|
| 3450 |
+
lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 8);
|
| 3451 |
+
|
| 3452 |
+
lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16);
|
| 3453 |
+
lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 8);
|
| 3454 |
+
|
| 3455 |
+
lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16);
|
| 3456 |
+
lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 8);
|
| 3457 |
+
}
|
| 3458 |
+
|
| 3459 |
// store results to shared memory
|
| 3460 |
+
for (short i = tiisg; i < D16; i += NW4) {
|
| 3461 |
+
sr44[i] = lo[i/NW4];
|
|
|
|
| 3462 |
}
|
| 3463 |
|
| 3464 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
| 3485 |
}
|
| 3486 |
|
| 3487 |
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
| 3488 |
+
for (short i = tiisg; i < D16; i += NW) {
|
| 3489 |
+
sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
|
|
|
|
| 3490 |
}
|
| 3491 |
}
|
| 3492 |
|
| 3493 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3494 |
}
|
| 3495 |
|
| 3496 |
+
device float4x4 * dst44 = (device float4x4 *) dst;
|
| 3497 |
|
| 3498 |
// final rescale with 1/S and store to global memory
|
| 3499 |
if (sgitg == 0) {
|
| 3500 |
const float S = ss[0];
|
| 3501 |
|
| 3502 |
+
for (short i = tiisg; i < D16; i += NW) {
|
| 3503 |
+
dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
|
|
|
|
| 3504 |
}
|
| 3505 |
}
|
| 3506 |
}
|
| 3507 |
|
| 3508 |
+
typedef decltype(kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
|
| 3509 |
+
|
| 3510 |
+
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 128>;
|
| 3511 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
|
| 3512 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
|
| 3513 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
|
| 3514 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
|
| 3515 |
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
|
| 3516 |
+
|
| 3517 |
+
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 256>;
|
| 3518 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
|
| 3519 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
|
| 3520 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
|
| 3521 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
|
| 3522 |
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
|
| 3523 |
|
| 3524 |
template<typename T0, typename T1>
|
| 3525 |
kernel void kernel_cpy(
|