Spaces:
Sleeping
Sleeping
metal : add BF16 support (llama/8439)
Browse files* ggml : add initial BF16 support
ggml-ci
* metal : add mul_mat_id BF16 support
ggml-ci
* metal : check for bfloat support on the Metal device
ggml-ci
* metal : better var names [no ci]
* metal : do not build bfloat kernels when not supported
ggml-ci
* metal : try to fix BF16 support check
ggml-ci
* metal : this should correctly check bfloat support
- ggml/src/ggml-metal.m +259 -179
- ggml/src/ggml-metal.metal +54 -4
ggml/src/ggml-metal.m
CHANGED
|
@@ -36,16 +36,18 @@ static struct ggml_backend_metal_device_context {
|
|
| 36 |
id<MTLDevice> mtl_device;
|
| 37 |
int mtl_device_ref_count;
|
| 38 |
|
| 39 |
-
bool
|
| 40 |
-
bool
|
|
|
|
| 41 |
|
| 42 |
char name[128];
|
| 43 |
} g_ggml_ctx_dev_main = {
|
| 44 |
-
/*.mtl_device
|
| 45 |
-
/*.mtl_device_ref_count
|
| 46 |
-
/*.
|
| 47 |
-
/*.
|
| 48 |
-
/*.
|
|
|
|
| 49 |
};
|
| 50 |
|
| 51 |
// acquire
|
|
@@ -55,10 +57,13 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
|
| 55 |
if (ctx->mtl_device == nil) {
|
| 56 |
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
| 57 |
|
| 58 |
-
ctx->
|
| 59 |
-
ctx->
|
| 60 |
|
| 61 |
-
ctx->
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
| 64 |
}
|
|
@@ -120,6 +125,7 @@ enum ggml_metal_kernel_type {
|
|
| 120 |
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
| 121 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
| 122 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
|
|
|
|
| 123 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
|
| 124 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
|
| 125 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
|
|
@@ -146,10 +152,14 @@ enum ggml_metal_kernel_type {
|
|
| 146 |
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
| 147 |
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
| 148 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
| 149 |
-
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
| 150 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
| 151 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
| 152 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
|
| 154 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
|
| 155 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
|
@@ -170,10 +180,11 @@ enum ggml_metal_kernel_type {
|
|
| 170 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
| 171 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
| 172 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
| 173 |
-
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
| 174 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
| 175 |
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
|
| 176 |
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
|
|
|
|
|
|
|
| 177 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
|
| 178 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
|
| 179 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
|
|
@@ -195,6 +206,7 @@ enum ggml_metal_kernel_type {
|
|
| 195 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
| 196 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
| 197 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
|
|
|
| 198 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
| 199 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
|
| 200 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
|
|
@@ -216,6 +228,7 @@ enum ggml_metal_kernel_type {
|
|
| 216 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
| 217 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
| 218 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
|
|
|
| 219 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
| 220 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
|
| 221 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
|
|
@@ -300,8 +313,11 @@ enum ggml_metal_kernel_type {
|
|
| 300 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
| 301 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
| 302 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
|
|
|
| 303 |
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
| 304 |
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
|
|
|
|
|
|
| 305 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
| 306 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
| 307 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
|
@@ -480,7 +496,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 480 |
// dictionary of preprocessor macros
|
| 481 |
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
| 482 |
|
| 483 |
-
MTLCompileOptions* options = [MTLCompileOptions new];
|
| 484 |
options.preprocessorMacros = prep;
|
| 485 |
|
| 486 |
//[options setFastMathEnabled:false];
|
|
@@ -530,9 +546,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 530 |
}
|
| 531 |
}
|
| 532 |
|
| 533 |
-
GGML_LOG_INFO("%s: simdgroup reduction
|
| 534 |
-
GGML_LOG_INFO("%s: simdgroup matrix mul.
|
| 535 |
-
GGML_LOG_INFO("%s:
|
|
|
|
| 536 |
|
| 537 |
ctx->capture_next_compute = false;
|
| 538 |
ctx->capture_started = false;
|
|
@@ -578,8 +595,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 578 |
GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
|
| 579 |
}
|
| 580 |
|
| 581 |
-
const bool
|
| 582 |
-
const bool
|
|
|
|
| 583 |
|
| 584 |
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
| 585 |
|
|
@@ -607,14 +625,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 607 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
| 608 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
| 609 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
| 610 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16,
|
| 611 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4,
|
| 612 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32,
|
| 613 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4,
|
| 614 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
| 615 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
| 616 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
| 617 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
|
|
|
| 618 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
| 619 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
| 620 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
|
@@ -635,101 +654,108 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 635 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
| 636 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
| 637 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
| 638 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm,
|
| 639 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm,
|
| 640 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
| 641 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
| 642 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
| 643 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32,
|
| 644 |
-
GGML_METAL_ADD_KERNEL(
|
| 645 |
-
GGML_METAL_ADD_KERNEL(
|
| 646 |
-
GGML_METAL_ADD_KERNEL(
|
| 647 |
-
GGML_METAL_ADD_KERNEL(
|
| 648 |
-
GGML_METAL_ADD_KERNEL(
|
| 649 |
-
GGML_METAL_ADD_KERNEL(
|
| 650 |
-
GGML_METAL_ADD_KERNEL(
|
| 651 |
-
GGML_METAL_ADD_KERNEL(
|
| 652 |
-
GGML_METAL_ADD_KERNEL(
|
| 653 |
-
GGML_METAL_ADD_KERNEL(
|
| 654 |
-
GGML_METAL_ADD_KERNEL(
|
| 655 |
-
GGML_METAL_ADD_KERNEL(
|
| 656 |
-
GGML_METAL_ADD_KERNEL(
|
| 657 |
-
GGML_METAL_ADD_KERNEL(
|
| 658 |
-
GGML_METAL_ADD_KERNEL(
|
| 659 |
-
GGML_METAL_ADD_KERNEL(
|
| 660 |
-
GGML_METAL_ADD_KERNEL(
|
| 661 |
-
GGML_METAL_ADD_KERNEL(
|
| 662 |
-
GGML_METAL_ADD_KERNEL(
|
| 663 |
-
GGML_METAL_ADD_KERNEL(
|
| 664 |
-
GGML_METAL_ADD_KERNEL(
|
| 665 |
-
GGML_METAL_ADD_KERNEL(
|
| 666 |
-
GGML_METAL_ADD_KERNEL(
|
| 667 |
-
GGML_METAL_ADD_KERNEL(
|
| 668 |
-
|
| 669 |
-
GGML_METAL_ADD_KERNEL(
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
GGML_METAL_ADD_KERNEL(
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
GGML_METAL_ADD_KERNEL(
|
| 677 |
-
GGML_METAL_ADD_KERNEL(
|
| 678 |
-
GGML_METAL_ADD_KERNEL(
|
| 679 |
-
GGML_METAL_ADD_KERNEL(
|
| 680 |
-
GGML_METAL_ADD_KERNEL(
|
| 681 |
-
GGML_METAL_ADD_KERNEL(
|
| 682 |
-
GGML_METAL_ADD_KERNEL(
|
| 683 |
-
GGML_METAL_ADD_KERNEL(
|
| 684 |
-
GGML_METAL_ADD_KERNEL(
|
| 685 |
-
GGML_METAL_ADD_KERNEL(
|
| 686 |
-
GGML_METAL_ADD_KERNEL(
|
| 687 |
-
GGML_METAL_ADD_KERNEL(
|
| 688 |
-
GGML_METAL_ADD_KERNEL(
|
| 689 |
-
GGML_METAL_ADD_KERNEL(
|
| 690 |
-
GGML_METAL_ADD_KERNEL(
|
| 691 |
-
GGML_METAL_ADD_KERNEL(
|
| 692 |
-
GGML_METAL_ADD_KERNEL(
|
| 693 |
-
GGML_METAL_ADD_KERNEL(
|
| 694 |
-
GGML_METAL_ADD_KERNEL(
|
| 695 |
-
GGML_METAL_ADD_KERNEL(
|
| 696 |
-
GGML_METAL_ADD_KERNEL(
|
| 697 |
-
GGML_METAL_ADD_KERNEL(
|
| 698 |
-
GGML_METAL_ADD_KERNEL(
|
| 699 |
-
GGML_METAL_ADD_KERNEL(
|
| 700 |
-
GGML_METAL_ADD_KERNEL(
|
| 701 |
-
GGML_METAL_ADD_KERNEL(
|
| 702 |
-
GGML_METAL_ADD_KERNEL(
|
| 703 |
-
GGML_METAL_ADD_KERNEL(
|
| 704 |
-
GGML_METAL_ADD_KERNEL(
|
| 705 |
-
GGML_METAL_ADD_KERNEL(
|
| 706 |
-
GGML_METAL_ADD_KERNEL(
|
| 707 |
-
GGML_METAL_ADD_KERNEL(
|
| 708 |
-
GGML_METAL_ADD_KERNEL(
|
| 709 |
-
GGML_METAL_ADD_KERNEL(
|
| 710 |
-
GGML_METAL_ADD_KERNEL(
|
| 711 |
-
GGML_METAL_ADD_KERNEL(
|
| 712 |
-
GGML_METAL_ADD_KERNEL(
|
| 713 |
-
GGML_METAL_ADD_KERNEL(
|
| 714 |
-
GGML_METAL_ADD_KERNEL(
|
| 715 |
-
GGML_METAL_ADD_KERNEL(
|
| 716 |
-
GGML_METAL_ADD_KERNEL(
|
| 717 |
-
GGML_METAL_ADD_KERNEL(
|
| 718 |
-
GGML_METAL_ADD_KERNEL(
|
| 719 |
-
GGML_METAL_ADD_KERNEL(
|
| 720 |
-
GGML_METAL_ADD_KERNEL(
|
| 721 |
-
GGML_METAL_ADD_KERNEL(
|
| 722 |
-
GGML_METAL_ADD_KERNEL(
|
| 723 |
-
GGML_METAL_ADD_KERNEL(
|
| 724 |
-
GGML_METAL_ADD_KERNEL(
|
| 725 |
-
GGML_METAL_ADD_KERNEL(
|
| 726 |
-
GGML_METAL_ADD_KERNEL(
|
| 727 |
-
GGML_METAL_ADD_KERNEL(
|
| 728 |
-
GGML_METAL_ADD_KERNEL(
|
| 729 |
-
GGML_METAL_ADD_KERNEL(
|
| 730 |
-
GGML_METAL_ADD_KERNEL(
|
| 731 |
-
GGML_METAL_ADD_KERNEL(
|
| 732 |
-
GGML_METAL_ADD_KERNEL(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 733 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
| 734 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
| 735 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
|
@@ -745,58 +771,61 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 745 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
| 746 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
| 747 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
| 748 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64,
|
| 749 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80,
|
| 750 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96,
|
| 751 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112,
|
| 752 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128,
|
| 753 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256,
|
| 754 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64,
|
| 755 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80,
|
| 756 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96,
|
| 757 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112,
|
| 758 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128,
|
| 759 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256,
|
| 760 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64,
|
| 761 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80,
|
| 762 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96,
|
| 763 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112,
|
| 764 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128,
|
| 765 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256,
|
| 766 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64,
|
| 767 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80,
|
| 768 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96,
|
| 769 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112,
|
| 770 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128,
|
| 771 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256,
|
| 772 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64,
|
| 773 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80,
|
| 774 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96,
|
| 775 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112,
|
| 776 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128,
|
| 777 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256,
|
| 778 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64,
|
| 779 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80,
|
| 780 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96,
|
| 781 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112,
|
| 782 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128,
|
| 783 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256,
|
| 784 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128,
|
| 785 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128,
|
| 786 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128,
|
| 787 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128,
|
| 788 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128,
|
| 789 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128,
|
| 790 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256,
|
| 791 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256,
|
| 792 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256,
|
| 793 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256,
|
| 794 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256,
|
| 795 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256,
|
| 796 |
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 797 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
| 798 |
-
GGML_METAL_ADD_KERNEL(
|
|
|
|
| 799 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
|
|
|
|
|
|
|
|
|
| 800 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
| 801 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
| 802 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
|
@@ -886,15 +915,18 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
|
|
| 886 |
}
|
| 887 |
|
| 888 |
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 892 |
}
|
| 893 |
}
|
| 894 |
|
| 895 |
-
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
|
| 896 |
-
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
|
| 897 |
-
|
| 898 |
switch (op->op) {
|
| 899 |
case GGML_OP_UNARY:
|
| 900 |
switch (ggml_get_unary_op(op)) {
|
|
@@ -932,7 +964,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 932 |
case GGML_OP_SOFT_MAX:
|
| 933 |
case GGML_OP_RMS_NORM:
|
| 934 |
case GGML_OP_GROUP_NORM:
|
| 935 |
-
return
|
| 936 |
case GGML_OP_NORM:
|
| 937 |
case GGML_OP_ROPE:
|
| 938 |
return true;
|
|
@@ -952,13 +984,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 952 |
if (op->src[1]->type != op->src[2]->type) {
|
| 953 |
return false;
|
| 954 |
}
|
| 955 |
-
return
|
| 956 |
case GGML_OP_SSM_CONV:
|
| 957 |
case GGML_OP_SSM_SCAN:
|
| 958 |
return true;
|
| 959 |
case GGML_OP_MUL_MAT:
|
| 960 |
case GGML_OP_MUL_MAT_ID:
|
| 961 |
-
return
|
| 962 |
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
|
| 963 |
case GGML_OP_CPY:
|
| 964 |
case GGML_OP_DUP:
|
|
@@ -969,6 +1001,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 969 |
switch (op->type) {
|
| 970 |
case GGML_TYPE_F32:
|
| 971 |
case GGML_TYPE_F16:
|
|
|
|
| 972 |
case GGML_TYPE_Q8_0:
|
| 973 |
case GGML_TYPE_Q4_0:
|
| 974 |
case GGML_TYPE_Q4_1:
|
|
@@ -981,10 +1014,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 981 |
}
|
| 982 |
case GGML_TYPE_F16:
|
| 983 |
switch (op->type) {
|
| 984 |
-
|
| 985 |
-
|
| 986 |
return true;
|
| 987 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 988 |
return false;
|
| 989 |
}
|
| 990 |
default:
|
|
@@ -1855,6 +1896,7 @@ static void ggml_metal_encode_node(
|
|
| 1855 |
switch (src0->type) {
|
| 1856 |
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
| 1857 |
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
|
|
|
| 1858 |
default: break;
|
| 1859 |
}
|
| 1860 |
|
|
@@ -1863,6 +1905,7 @@ static void ggml_metal_encode_node(
|
|
| 1863 |
switch (src0->type) {
|
| 1864 |
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
|
| 1865 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
|
|
|
|
| 1866 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
|
| 1867 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
|
| 1868 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
|
@@ -1940,6 +1983,25 @@ static void ggml_metal_encode_node(
|
|
| 1940 |
nrows = 4;
|
| 1941 |
}
|
| 1942 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1943 |
case GGML_TYPE_Q4_0:
|
| 1944 |
{
|
| 1945 |
nth0 = 8;
|
|
@@ -2158,12 +2220,12 @@ static void ggml_metal_encode_node(
|
|
| 2158 |
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
| 2159 |
ne00 % 32 == 0 && ne00 >= 64 &&
|
| 2160 |
dst_rows > dst_rows_min) {
|
| 2161 |
-
|
| 2162 |
// some Metal matrix data types require aligned pointers
|
| 2163 |
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
| 2164 |
switch (src0->type) {
|
| 2165 |
-
case GGML_TYPE_F32:
|
| 2166 |
-
case GGML_TYPE_F16:
|
|
|
|
| 2167 |
default: break;
|
| 2168 |
}
|
| 2169 |
|
|
@@ -2172,6 +2234,7 @@ static void ggml_metal_encode_node(
|
|
| 2172 |
switch (src0->type) {
|
| 2173 |
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
| 2174 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
|
|
|
| 2175 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
| 2176 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
| 2177 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
|
@@ -2241,6 +2304,13 @@ static void ggml_metal_encode_node(
|
|
| 2241 |
nth1 = 1;
|
| 2242 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
| 2243 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2244 |
case GGML_TYPE_Q4_0:
|
| 2245 |
{
|
| 2246 |
nth0 = 8;
|
|
@@ -2438,6 +2508,7 @@ static void ggml_metal_encode_node(
|
|
| 2438 |
switch (src0->type) {
|
| 2439 |
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
|
| 2440 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
|
|
|
|
| 2441 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
|
| 2442 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
|
| 2443 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
|
|
@@ -3237,6 +3308,7 @@ static void ggml_metal_encode_node(
|
|
| 3237 |
switch (dstt) {
|
| 3238 |
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
| 3239 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
|
|
|
| 3240 |
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
| 3241 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
| 3242 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
|
@@ -3254,6 +3326,14 @@ static void ggml_metal_encode_node(
|
|
| 3254 |
default: GGML_ABORT("not implemented");
|
| 3255 |
};
|
| 3256 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3257 |
default: GGML_ABORT("not implemented");
|
| 3258 |
}
|
| 3259 |
|
|
|
|
| 36 |
id<MTLDevice> mtl_device;
|
| 37 |
int mtl_device_ref_count;
|
| 38 |
|
| 39 |
+
bool has_simdgroup_reduction;
|
| 40 |
+
bool has_simdgroup_mm;
|
| 41 |
+
bool has_bfloat;
|
| 42 |
|
| 43 |
char name[128];
|
| 44 |
} g_ggml_ctx_dev_main = {
|
| 45 |
+
/*.mtl_device =*/ nil,
|
| 46 |
+
/*.mtl_device_ref_count =*/ 0,
|
| 47 |
+
/*.has_simdgroup_reduction =*/ false,
|
| 48 |
+
/*.has_simdgroup_mm =*/ false,
|
| 49 |
+
/*.has_bfloat =*/ false,
|
| 50 |
+
/*.name =*/ "",
|
| 51 |
};
|
| 52 |
|
| 53 |
// acquire
|
|
|
|
| 57 |
if (ctx->mtl_device == nil) {
|
| 58 |
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
| 59 |
|
| 60 |
+
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
| 61 |
+
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
| 62 |
|
| 63 |
+
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
| 64 |
+
|
| 65 |
+
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
| 66 |
+
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
| 67 |
|
| 68 |
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
| 69 |
}
|
|
|
|
| 125 |
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
| 126 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
| 127 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
|
| 128 |
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
|
| 129 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
|
| 130 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
|
| 131 |
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
|
|
|
|
| 152 |
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
| 153 |
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
| 154 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
|
|
| 155 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
| 156 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
| 157 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
| 158 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
| 159 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
| 160 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
|
| 161 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
|
| 162 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
|
| 163 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
|
| 164 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
|
| 165 |
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
|
|
|
| 180 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
| 181 |
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
| 182 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
|
|
|
| 183 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
| 184 |
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
|
| 185 |
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
|
| 186 |
+
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
| 187 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
|
| 188 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
|
| 189 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
|
| 190 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
|
|
|
|
| 206 |
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
| 207 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
| 208 |
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
| 209 |
+
GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
|
| 210 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
| 211 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
|
| 212 |
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
|
|
|
|
| 228 |
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
| 229 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
| 230 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
| 231 |
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
|
| 232 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
| 233 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
|
| 234 |
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
|
|
|
|
| 313 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
| 314 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
| 315 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
| 316 |
+
GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
|
| 317 |
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
| 318 |
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
| 319 |
+
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
|
| 320 |
+
GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
|
| 321 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
| 322 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
| 323 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
|
|
|
| 496 |
// dictionary of preprocessor macros
|
| 497 |
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
| 498 |
|
| 499 |
+
MTLCompileOptions * options = [MTLCompileOptions new];
|
| 500 |
options.preprocessorMacros = prep;
|
| 501 |
|
| 502 |
//[options setFastMathEnabled:false];
|
|
|
|
| 546 |
}
|
| 547 |
}
|
| 548 |
|
| 549 |
+
GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
|
| 550 |
+
GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
|
| 551 |
+
GGML_LOG_INFO("%s: bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
|
| 552 |
+
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
|
| 553 |
|
| 554 |
ctx->capture_next_compute = false;
|
| 555 |
ctx->capture_started = false;
|
|
|
|
| 595 |
GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
|
| 596 |
}
|
| 597 |
|
| 598 |
+
const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
|
| 599 |
+
const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
|
| 600 |
+
const bool has_bfloat = ctx_dev->has_bfloat;
|
| 601 |
|
| 602 |
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
| 603 |
|
|
|
|
| 625 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
| 626 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
| 627 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
| 628 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
| 629 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
| 630 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
| 631 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
|
| 632 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
| 633 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
| 634 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
| 635 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
| 636 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, has_bfloat);
|
| 637 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
| 638 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
| 639 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
|
|
|
| 654 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
| 655 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
| 656 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
| 657 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
| 658 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
| 659 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
| 660 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
| 661 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
| 662 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
| 663 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && has_bfloat);
|
| 664 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && has_bfloat);
|
| 665 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && has_bfloat);
|
| 666 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && has_bfloat);
|
| 667 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
| 668 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
| 669 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
| 670 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
|
| 671 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
|
| 672 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
|
| 673 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
| 674 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
| 675 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
| 676 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
|
| 677 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
|
| 678 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
|
| 679 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
|
| 680 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
|
| 681 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
|
| 682 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
|
| 683 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
|
| 684 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
|
| 685 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
|
| 686 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
|
| 687 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
|
| 688 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
|
| 689 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
|
| 690 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
|
| 691 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
|
| 692 |
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
|
| 693 |
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
|
| 694 |
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
|
| 695 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && has_bfloat);
|
| 696 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
|
| 697 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
|
| 698 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
|
| 699 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
|
| 700 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
|
| 701 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
|
| 702 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
|
| 703 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
|
| 704 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
|
| 705 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
|
| 706 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
|
| 707 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
|
| 708 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
|
| 709 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
|
| 710 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
|
| 711 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
|
| 712 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
|
| 713 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
|
| 714 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
|
| 715 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
|
| 716 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
|
| 717 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && has_bfloat);
|
| 718 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
|
| 719 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
|
| 720 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
|
| 721 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
|
| 722 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
|
| 723 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
|
| 724 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
|
| 725 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
|
| 726 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
|
| 727 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
|
| 728 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
|
| 729 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
|
| 730 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
|
| 731 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
|
| 732 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
|
| 733 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
|
| 734 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
| 735 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
| 736 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
| 737 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
|
| 738 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
|
| 739 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && has_bfloat);
|
| 740 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
|
| 741 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
|
| 742 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
|
| 743 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
|
| 744 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
|
| 745 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
|
| 746 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
|
| 747 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
|
| 748 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
|
| 749 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
|
| 750 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
|
| 751 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
|
| 752 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
|
| 753 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
|
| 754 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
|
| 755 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
|
| 756 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
|
| 757 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
|
| 758 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
|
| 759 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
| 760 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
| 761 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
|
|
|
| 771 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
| 772 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
| 773 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
| 774 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
|
| 775 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
|
| 776 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
|
| 777 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
|
| 778 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
|
| 779 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
| 780 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
| 781 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
| 782 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
| 783 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
|
| 784 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
|
| 785 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
| 786 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
| 787 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
| 788 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
| 789 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
|
| 790 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
|
| 791 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
| 792 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
| 793 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
| 794 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
| 795 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
|
| 796 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
|
| 797 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
| 798 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
| 799 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
| 800 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
| 801 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
|
| 802 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
|
| 803 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
| 804 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
| 805 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
| 806 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
| 807 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
|
| 808 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
|
| 809 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
| 810 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
| 811 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
| 812 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
|
| 813 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
|
| 814 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
|
| 815 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
|
| 816 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
|
| 817 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
|
| 818 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
|
| 819 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
| 820 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
| 821 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
|
|
|
| 822 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
| 823 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 824 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, has_bfloat);
|
| 825 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
| 826 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
| 827 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, has_bfloat);
|
| 828 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, has_bfloat);
|
| 829 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
| 830 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
| 831 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
|
|
|
| 915 |
}
|
| 916 |
|
| 917 |
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
|
| 918 |
+
const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
|
| 919 |
+
const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
|
| 920 |
+
const bool has_bfloat = ctx_dev->has_bfloat;
|
| 921 |
+
|
| 922 |
+
if (!has_bfloat) {
|
| 923 |
+
for (size_t i = 0, n = 3; i < n; ++i) {
|
| 924 |
+
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
| 925 |
+
return false;
|
| 926 |
+
}
|
| 927 |
}
|
| 928 |
}
|
| 929 |
|
|
|
|
|
|
|
|
|
|
| 930 |
switch (op->op) {
|
| 931 |
case GGML_OP_UNARY:
|
| 932 |
switch (ggml_get_unary_op(op)) {
|
|
|
|
| 964 |
case GGML_OP_SOFT_MAX:
|
| 965 |
case GGML_OP_RMS_NORM:
|
| 966 |
case GGML_OP_GROUP_NORM:
|
| 967 |
+
return has_simdgroup_reduction;
|
| 968 |
case GGML_OP_NORM:
|
| 969 |
case GGML_OP_ROPE:
|
| 970 |
return true;
|
|
|
|
| 984 |
if (op->src[1]->type != op->src[2]->type) {
|
| 985 |
return false;
|
| 986 |
}
|
| 987 |
+
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
| 988 |
case GGML_OP_SSM_CONV:
|
| 989 |
case GGML_OP_SSM_SCAN:
|
| 990 |
return true;
|
| 991 |
case GGML_OP_MUL_MAT:
|
| 992 |
case GGML_OP_MUL_MAT_ID:
|
| 993 |
+
return has_simdgroup_reduction &&
|
| 994 |
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
|
| 995 |
case GGML_OP_CPY:
|
| 996 |
case GGML_OP_DUP:
|
|
|
|
| 1001 |
switch (op->type) {
|
| 1002 |
case GGML_TYPE_F32:
|
| 1003 |
case GGML_TYPE_F16:
|
| 1004 |
+
case GGML_TYPE_BF16:
|
| 1005 |
case GGML_TYPE_Q8_0:
|
| 1006 |
case GGML_TYPE_Q4_0:
|
| 1007 |
case GGML_TYPE_Q4_1:
|
|
|
|
| 1014 |
}
|
| 1015 |
case GGML_TYPE_F16:
|
| 1016 |
switch (op->type) {
|
| 1017 |
+
case GGML_TYPE_F32:
|
| 1018 |
+
case GGML_TYPE_F16:
|
| 1019 |
return true;
|
| 1020 |
+
default:
|
| 1021 |
+
return false;
|
| 1022 |
+
}
|
| 1023 |
+
case GGML_TYPE_BF16:
|
| 1024 |
+
switch (op->type) {
|
| 1025 |
+
case GGML_TYPE_F32:
|
| 1026 |
+
case GGML_TYPE_BF16:
|
| 1027 |
+
return true;
|
| 1028 |
+
default:
|
| 1029 |
return false;
|
| 1030 |
}
|
| 1031 |
default:
|
|
|
|
| 1896 |
switch (src0->type) {
|
| 1897 |
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
| 1898 |
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
| 1899 |
+
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
| 1900 |
default: break;
|
| 1901 |
}
|
| 1902 |
|
|
|
|
| 1905 |
switch (src0->type) {
|
| 1906 |
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
|
| 1907 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
|
| 1908 |
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
|
| 1909 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
|
| 1910 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
|
| 1911 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
|
|
|
| 1983 |
nrows = 4;
|
| 1984 |
}
|
| 1985 |
} break;
|
| 1986 |
+
case GGML_TYPE_BF16:
|
| 1987 |
+
{
|
| 1988 |
+
nth0 = 32;
|
| 1989 |
+
nth1 = 1;
|
| 1990 |
+
if (src1t == GGML_TYPE_F32) {
|
| 1991 |
+
if (ne11 * ne12 < 4) {
|
| 1992 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
| 1993 |
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
| 1994 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
| 1995 |
+
nrows = ne11;
|
| 1996 |
+
} else {
|
| 1997 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
|
| 1998 |
+
nrows = 4;
|
| 1999 |
+
}
|
| 2000 |
+
} else {
|
| 2001 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
|
| 2002 |
+
nrows = 4;
|
| 2003 |
+
}
|
| 2004 |
+
} break;
|
| 2005 |
case GGML_TYPE_Q4_0:
|
| 2006 |
{
|
| 2007 |
nth0 = 8;
|
|
|
|
| 2220 |
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
| 2221 |
ne00 % 32 == 0 && ne00 >= 64 &&
|
| 2222 |
dst_rows > dst_rows_min) {
|
|
|
|
| 2223 |
// some Metal matrix data types require aligned pointers
|
| 2224 |
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
| 2225 |
switch (src0->type) {
|
| 2226 |
+
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
| 2227 |
+
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
| 2228 |
+
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
| 2229 |
default: break;
|
| 2230 |
}
|
| 2231 |
|
|
|
|
| 2234 |
switch (src0->type) {
|
| 2235 |
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
| 2236 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
| 2237 |
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
|
| 2238 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
| 2239 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
| 2240 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
|
|
|
| 2304 |
nth1 = 1;
|
| 2305 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
| 2306 |
} break;
|
| 2307 |
+
case GGML_TYPE_BF16:
|
| 2308 |
+
{
|
| 2309 |
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
| 2310 |
+
nth0 = 32;
|
| 2311 |
+
nth1 = 1;
|
| 2312 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
|
| 2313 |
+
} break;
|
| 2314 |
case GGML_TYPE_Q4_0:
|
| 2315 |
{
|
| 2316 |
nth0 = 8;
|
|
|
|
| 2508 |
switch (src0->type) {
|
| 2509 |
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
|
| 2510 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
|
| 2511 |
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
|
| 2512 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
|
| 2513 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
|
| 2514 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
|
|
|
|
| 3308 |
switch (dstt) {
|
| 3309 |
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
| 3310 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
| 3311 |
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
|
| 3312 |
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
| 3313 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
| 3314 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
|
|
|
| 3326 |
default: GGML_ABORT("not implemented");
|
| 3327 |
};
|
| 3328 |
} break;
|
| 3329 |
+
case GGML_TYPE_BF16:
|
| 3330 |
+
{
|
| 3331 |
+
switch (dstt) {
|
| 3332 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
|
| 3333 |
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
|
| 3334 |
+
default: GGML_ASSERT(false && "not implemented");
|
| 3335 |
+
};
|
| 3336 |
+
} break;
|
| 3337 |
default: GGML_ABORT("not implemented");
|
| 3338 |
}
|
| 3339 |
|
ggml/src/ggml-metal.metal
CHANGED
|
@@ -12,6 +12,20 @@ using namespace metal;
|
|
| 12 |
|
| 13 |
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
constexpr constant static float kvalues_iq4nl_f[16] = {
|
| 16 |
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
| 17 |
};
|
|
@@ -27,6 +41,13 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|
| 27 |
reg = (type4x4)(*src);
|
| 28 |
}
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
template <typename type4x4>
|
| 31 |
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
| 32 |
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
|
@@ -2041,6 +2062,10 @@ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
|
|
| 2041 |
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
|
| 2042 |
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
|
| 2043 |
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2044 |
|
| 2045 |
template<typename T, typename T4>
|
| 2046 |
kernel void kernel_mul_mv_1row(
|
|
@@ -2110,6 +2135,9 @@ kernel void kernel_mul_mv_1row(
|
|
| 2110 |
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
|
| 2111 |
|
| 2112 |
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
|
|
|
|
|
|
|
|
|
|
| 2113 |
|
| 2114 |
// Assumes row size (ne00) is a multiple of 4
|
| 2115 |
template<typename T, typename T4>
|
|
@@ -2169,6 +2197,9 @@ kernel void kernel_mul_mv_l4(
|
|
| 2169 |
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
|
| 2170 |
|
| 2171 |
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
|
|
|
|
|
|
|
|
|
|
| 2172 |
|
| 2173 |
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
| 2174 |
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
|
@@ -3565,10 +3596,17 @@ kernel void kernel_cpy(
|
|
| 3565 |
|
| 3566 |
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
|
| 3567 |
|
| 3568 |
-
template [[host_name("kernel_cpy_f32_f32")]]
|
| 3569 |
-
template [[host_name("kernel_cpy_f32_f16")]]
|
| 3570 |
-
|
| 3571 |
-
template [[host_name("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3572 |
|
| 3573 |
kernel void kernel_cpy_f32_q8_0(
|
| 3574 |
device const float * src0,
|
|
@@ -6473,6 +6511,9 @@ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
|
|
| 6473 |
|
| 6474 |
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
|
| 6475 |
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
|
|
|
|
|
|
|
|
|
|
| 6476 |
|
| 6477 |
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
| 6478 |
|
|
@@ -6504,6 +6545,9 @@ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, de
|
|
| 6504 |
|
| 6505 |
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
| 6506 |
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
|
|
|
|
|
|
|
|
|
| 6507 |
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
| 6508 |
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
| 6509 |
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
|
@@ -6532,6 +6576,9 @@ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
|
|
| 6532 |
|
| 6533 |
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
| 6534 |
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
|
|
|
|
|
|
|
|
|
| 6535 |
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
| 6536 |
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
| 6537 |
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
|
@@ -6755,6 +6802,9 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
|
|
| 6755 |
|
| 6756 |
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
| 6757 |
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
|
|
|
|
|
|
|
|
|
|
| 6758 |
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
| 6759 |
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
| 6760 |
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
|
|
|
| 12 |
|
| 13 |
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
| 14 |
|
| 15 |
+
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
| 16 |
+
//
|
| 17 |
+
// cmd:
|
| 18 |
+
// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal.metal
|
| 19 |
+
// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal.metal
|
| 20 |
+
//
|
| 21 |
+
#if __METAL_VERSION__ < 310
|
| 22 |
+
#define GGML_METAL_NO_BFLOAT
|
| 23 |
+
#endif
|
| 24 |
+
|
| 25 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 26 |
+
typedef matrix<bfloat, 4, 4> bfloat4x4;
|
| 27 |
+
#endif
|
| 28 |
+
|
| 29 |
constexpr constant static float kvalues_iq4nl_f[16] = {
|
| 30 |
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
| 31 |
};
|
|
|
|
| 41 |
reg = (type4x4)(*src);
|
| 42 |
}
|
| 43 |
|
| 44 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 45 |
+
template <typename type4x4>
|
| 46 |
+
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
| 47 |
+
reg = (type4x4)(*src);
|
| 48 |
+
}
|
| 49 |
+
#endif
|
| 50 |
+
|
| 51 |
template <typename type4x4>
|
| 52 |
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
| 53 |
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
|
|
|
| 2062 |
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
|
| 2063 |
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
|
| 2064 |
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
|
| 2065 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 2066 |
+
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
|
| 2067 |
+
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
|
| 2068 |
+
#endif
|
| 2069 |
|
| 2070 |
template<typename T, typename T4>
|
| 2071 |
kernel void kernel_mul_mv_1row(
|
|
|
|
| 2135 |
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
|
| 2136 |
|
| 2137 |
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
|
| 2138 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 2139 |
+
template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
|
| 2140 |
+
#endif
|
| 2141 |
|
| 2142 |
// Assumes row size (ne00) is a multiple of 4
|
| 2143 |
template<typename T, typename T4>
|
|
|
|
| 2197 |
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
|
| 2198 |
|
| 2199 |
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
|
| 2200 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 2201 |
+
template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
|
| 2202 |
+
#endif
|
| 2203 |
|
| 2204 |
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
| 2205 |
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
|
|
|
| 3596 |
|
| 3597 |
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
|
| 3598 |
|
| 3599 |
+
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
| 3600 |
+
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
| 3601 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 3602 |
+
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
|
| 3603 |
+
#endif
|
| 3604 |
+
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
|
| 3605 |
+
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
|
| 3606 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 3607 |
+
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
|
| 3608 |
+
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
|
| 3609 |
+
#endif
|
| 3610 |
|
| 3611 |
kernel void kernel_cpy_f32_q8_0(
|
| 3612 |
device const float * src0,
|
|
|
|
| 6511 |
|
| 6512 |
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
|
| 6513 |
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
|
| 6514 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 6515 |
+
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
|
| 6516 |
+
#endif
|
| 6517 |
|
| 6518 |
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
| 6519 |
|
|
|
|
| 6545 |
|
| 6546 |
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
| 6547 |
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
| 6548 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 6549 |
+
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
| 6550 |
+
#endif
|
| 6551 |
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
| 6552 |
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
| 6553 |
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
|
|
|
| 6576 |
|
| 6577 |
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
| 6578 |
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
| 6579 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 6580 |
+
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
|
| 6581 |
+
#endif
|
| 6582 |
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
| 6583 |
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
| 6584 |
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
|
|
|
| 6802 |
|
| 6803 |
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
| 6804 |
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
|
| 6805 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 6806 |
+
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
|
| 6807 |
+
#endif
|
| 6808 |
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
| 6809 |
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
| 6810 |
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|