metal : refactor mat-vec code (llama/12569)
* metal : refactor mat-vec code
ggml-ci
* metal : rename all_sum -> sum_all
ggml-ci
* metal : fix comments [no ci]
* metal : fix nr constant [no ci]
* metal : mv q6_K support nr0 > 1
ggml-ci
* metal : reduce register pressure
ggml-ci
* metal : fix typo [no ci]
* metal : reduce register pressure
ggml-ci
- ggml/src/ggml-metal/ggml-metal-impl.h +64 -0
- ggml/src/ggml-metal/ggml-metal.m +127 -171
- ggml/src/ggml-metal/ggml-metal.metal +330 -327
ggml/src/ggml-metal/ggml-metal-impl.h
CHANGED

@@ -1,6 +1,70 @@
 #ifndef GGML_METAL_IMPL
 #define GGML_METAL_IMPL
 
+// kernel parameters for mat-vec threadgroups
+//
+// N_R0: number of src0 rows to process per simdgroup
+// N_SG: number of simdgroups per threadgroup
+//
+// TODO: for optimal performance, become function of the device and work size
+
+#define N_R0_Q4_0 4
+#define N_SG_Q4_0 2
+
+#define N_R0_Q4_1 4
+#define N_SG_Q4_1 2
+
+#define N_R0_Q5_0 4
+#define N_SG_Q5_0 2
+
+#define N_R0_Q5_1 4
+#define N_SG_Q5_1 2
+
+#define N_R0_Q8_0 4
+#define N_SG_Q8_0 2
+
+#define N_R0_Q2_K 4
+#define N_SG_Q2_K 2
+
+#define N_R0_Q3_K 2
+#define N_SG_Q3_K 2
+
+#define N_R0_Q4_K 4
+#define N_SG_Q4_K 2
+
+#define N_R0_Q5_K 2
+#define N_SG_Q5_K 2
+
+#define N_R0_Q6_K 1
+#define N_SG_Q6_K 2
+
+#define N_R0_IQ1_S 4
+#define N_SG_IQ1_S 2
+
+#define N_R0_IQ1_M 4
+#define N_SG_IQ1_M 2
+
+#define N_R0_IQ2_XXS 4
+#define N_SG_IQ2_XXS 2
+
+#define N_R0_IQ2_XS 4
+#define N_SG_IQ2_XS 2
+
+#define N_R0_IQ2_S 4
+#define N_SG_IQ2_S 2
+
+#define N_R0_IQ3_XXS 4
+#define N_SG_IQ3_XXS 2
+
+#define N_R0_IQ3_S 4
+#define N_SG_IQ3_S 2
+
+#define N_R0_IQ4_NL 2
+#define N_SG_IQ4_NL 2
+
+#define N_R0_IQ4_XS 2
+#define N_SG_IQ4_XS 2
+
 // kernel argument structs
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
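A minimal C sketch (not part of the commit, with a hypothetical row count) of how these N_R0/N_SG pairs translate into the Metal dispatch geometry in ggml-metal.m: each threadgroup runs N_SG simdgroups of 32 threads, each simdgroup accumulates N_R0 rows of src0, so a threadgroup covers N_R0*N_SG rows and the grid size is a ceiling division of the row count:

    #include <stdio.h>

    #define N_R0_Q4_0 4 // src0 rows per simdgroup (from ggml-metal-impl.h)
    #define N_SG_Q4_0 2 // simdgroups per threadgroup

    int main(void) {
        const int ne01 = 4096;             // rows of src0 (example value)
        const int nr0  = N_R0_Q4_0;
        const int nsg  = N_SG_Q4_0;

        // same ceiling division as the dispatch in ggml-metal.m
        const int tg_x = (ne01 + nr0*nsg - 1)/(nr0*nsg);

        printf("threadgroups.x = %d, threads per threadgroup = 32 x %d\n", tg_x, nsg);
        return 0;
    }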
ggml/src/ggml-metal/ggml-metal.m
CHANGED

@@ -2561,171 +2561,180 @@ static void ggml_metal_encode_node(
     [encoder setThreadgroupMemoryLength:8192 atIndex:0];
     [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
 } else {
-    int nth0 = 32;
-    int nth1 = 1;
-    int nrows = 1;
-    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
     id<MTLComputePipelineState> pipeline = nil;
 
+    int nsg = 0; // number of simdgroups
+    int nr0 = 0; // number of src0 rows per simdgroup
+    int nr1 = 1; // number of src1 rows per threadgroup
+
+    size_t smem = 0; // shared memory
+
     // use custom matrix x vector kernel
     switch (src0t) {
         case GGML_TYPE_F32:
             {
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
+                nsg = 1;
+                nr0 = 1;
+                nr1 = 4;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
-                nrows = 4;
             } break;
         case GGML_TYPE_F16:
             {
-                nth0 = 32;
-                nth1 = 1;
+                nsg = 1;
+                nr0 = 1;
                 if (src1t == GGML_TYPE_F32) {
                     if (ne11 * ne12 < 4) {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
-                        nrows = ne11;
+                        nr1 = ne11;
                     } else {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
-                        nrows = 4;
+                        nr1 = 4;
                     }
                 } else {
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
-                    nrows = 4;
+                    nr1 = 4;
                }
             } break;
         case GGML_TYPE_BF16:
             {
-                nth0 = 32;
-                nth1 = 1;
+                nsg = 1;
+                nr0 = 1;
                 if (src1t == GGML_TYPE_F32) {
                     if (ne11 * ne12 < 4) {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
-                        nrows = ne11;
+                        nr1 = ne11;
                     } else {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
-                        nrows = 4;
+                        nr1 = 4;
                     }
                 } else {
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
-                    nrows = 4;
+                    nr1 = 4;
                 }
             } break;
         case GGML_TYPE_Q4_0:
             {
-                nth0 = 8;
-                nth1 = 8;
+                nsg = N_SG_Q4_0;
+                nr0 = N_R0_Q4_0;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
             } break;
         case GGML_TYPE_Q4_1:
             {
-                nth0 = 8;
-                nth1 = 8;
+                nsg = N_SG_Q4_1;
+                nr0 = N_R0_Q4_1;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
             } break;
         case GGML_TYPE_Q5_0:
             {
-                nth0 = 8;
-                nth1 = 8;
+                nsg = N_SG_Q5_0;
+                nr0 = N_R0_Q5_0;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
             } break;
         case GGML_TYPE_Q5_1:
             {
-                nth0 = 8;
-                nth1 = 8;
+                nsg = N_SG_Q5_1;
+                nr0 = N_R0_Q5_1;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
             } break;
         case GGML_TYPE_Q8_0:
             {
-                nth0 = 8;
-                nth1 = 8;
+                nsg = N_SG_Q8_0;
+                nr0 = N_R0_Q8_0;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
             } break;
         case GGML_TYPE_Q2_K:
             {
-                nth0 = 2;
-                nth1 = 32;
+                nsg = N_SG_Q2_K;
+                nr0 = N_R0_Q2_K;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
             } break;
         case GGML_TYPE_Q3_K:
             {
-                nth0 = 2;
-                nth1 = 32;
+                nsg = N_SG_Q3_K;
+                nr0 = N_R0_Q3_K;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
             } break;
         case GGML_TYPE_Q4_K:
             {
-                nth0 = 4;
-                nth1 = 8;
+                nsg = N_SG_Q4_K;
+                nr0 = N_R0_Q4_K;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
             } break;
         case GGML_TYPE_Q5_K:
             {
-                nth0 = 2;
-                nth1 = 32;
+                nsg = N_SG_Q5_K;
+                nr0 = N_R0_Q5_K;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
             } break;
         case GGML_TYPE_Q6_K:
             {
-                nth0 = 2;
-                nth1 = 32;
+                nsg = N_SG_Q6_K;
+                nr0 = N_R0_Q6_K;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
             } break;
         case GGML_TYPE_IQ2_XXS:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ2_XXS;
+                nr0 = N_R0_IQ2_XXS;
+                smem = 256*8+128;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
             } break;
         case GGML_TYPE_IQ2_XS:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ2_XS;
+                nr0 = N_R0_IQ2_XS;
+                smem = 512*8+128;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
             } break;
         case GGML_TYPE_IQ3_XXS:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ3_XXS;
+                nr0 = N_R0_IQ3_XXS;
+                smem = 256*4+128;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
             } break;
         case GGML_TYPE_IQ3_S:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ3_S;
+                nr0 = N_R0_IQ3_S;
+                smem = 512*4;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
             } break;
         case GGML_TYPE_IQ2_S:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ2_S;
+                nr0 = N_R0_IQ2_S;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
             } break;
         case GGML_TYPE_IQ1_S:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ1_S;
+                nr0 = N_R0_IQ1_S;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
             } break;
         case GGML_TYPE_IQ1_M:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ1_M;
+                nr0 = N_R0_IQ1_M;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
             } break;
         case GGML_TYPE_IQ4_NL:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ4_NL;
+                nr0 = N_R0_IQ4_NL;
+                smem = 32*sizeof(float);
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
             } break;
         case GGML_TYPE_IQ4_XS:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ4_XS;
+                nr0 = N_R0_IQ4_XS;
+                smem = 32*sizeof(float);
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
             } break;
         default:

@@ -2762,41 +2771,10 @@ static void ggml_metal_encode_node(
     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
     [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
-    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
-        src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
-        src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-        const int mem_size = 32*sizeof(float);
-        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_Q4_K) {
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_Q3_K) {
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_Q5_K) {
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_Q6_K) {
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    } else {
-        const int64_t ny = (ne11 + nrows - 1)/nrows;
-        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+    if (smem > 0) {
+        [encoder setThreadgroupMemoryLength:smem atIndex:0];
     }
+    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 }
 } break;
 case GGML_OP_MUL_MAT_ID:

@@ -2902,146 +2880,155 @@ static void ggml_metal_encode_node(
 
     [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
 } else {
-    int nth0 = 32;
-    int nth1 = 1;
-    int nrows = 1;
-    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
     id<MTLComputePipelineState> pipeline = nil;
 
+    int nsg = 0; // number of simdgroups
+    int nr0 = 0; // number of src0 rows per simdgroup
+    int nr1 = 1; // number of src1 rows per threadgroup
+
+    size_t smem = 0; // shared memory
+
     // use custom matrix x vector kernel
     switch (src0t) {
         case GGML_TYPE_F32:
             {
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
+                nsg = 1;
+                nr0 = 1;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
             } break;
         case GGML_TYPE_F16:
             {
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
-                nth0 = 32;
-                nth1 = 1;
+                nsg = 1;
+                nr0 = 1;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
             } break;
         case GGML_TYPE_BF16:
             {
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
-                nth0 = 32;
-                nth1 = 1;
+                nsg = 1;
+                nr0 = 1;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
             } break;
         case GGML_TYPE_Q4_0:
             {
-                nth0 = 8;
-                nth1 = 8;
+                nsg = N_SG_Q4_0;
+                nr0 = N_R0_Q4_0;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
             } break;
         case GGML_TYPE_Q4_1:
             {
-                nth0 = 8;
-                nth1 = 8;
+                nsg = N_SG_Q4_1;
+                nr0 = N_R0_Q4_1;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
             } break;
         case GGML_TYPE_Q5_0:
             {
-                nth0 = 8;
-                nth1 = 8;
+                nsg = N_SG_Q5_0;
+                nr0 = N_R0_Q5_0;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
             } break;
         case GGML_TYPE_Q5_1:
             {
-                nth0 = 8;
-                nth1 = 8;
+                nsg = N_SG_Q5_1;
+                nr0 = N_R0_Q5_1;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
             } break;
         case GGML_TYPE_Q8_0:
             {
-                nth0 = 8;
-                nth1 = 8;
+                nsg = N_SG_Q8_0;
+                nr0 = N_R0_Q8_0;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
             } break;
         case GGML_TYPE_Q2_K:
             {
-                nth0 = 2;
-                nth1 = 32;
+                nsg = N_SG_Q2_K;
+                nr0 = N_R0_Q2_K;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
             } break;
         case GGML_TYPE_Q3_K:
             {
-                nth0 = 2;
-                nth1 = 32;
+                nsg = N_SG_Q3_K;
+                nr0 = N_R0_Q3_K;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
             } break;
         case GGML_TYPE_Q4_K:
             {
-                nth0 = 4;
-                nth1 = 8;
+                nsg = N_SG_Q4_K;
+                nr0 = N_R0_Q4_K;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
             } break;
         case GGML_TYPE_Q5_K:
             {
-                nth0 = 2;
-                nth1 = 32;
+                nsg = N_SG_Q5_K;
+                nr0 = N_R0_Q5_K;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
             } break;
         case GGML_TYPE_Q6_K:
             {
-                nth0 = 2;
-                nth1 = 32;
+                nsg = N_SG_Q6_K;
+                nr0 = N_R0_Q6_K;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
             } break;
         case GGML_TYPE_IQ2_XXS:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ2_XXS;
+                nr0 = N_R0_IQ2_XXS;
+                smem = 256*8+128;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
             } break;
         case GGML_TYPE_IQ2_XS:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ2_XS;
+                nr0 = N_R0_IQ2_XS;
+                smem = 512*8+128;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
             } break;
         case GGML_TYPE_IQ3_XXS:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ3_XXS;
+                nr0 = N_R0_IQ3_XXS;
+                smem = 256*4+128;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
             } break;
         case GGML_TYPE_IQ3_S:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ3_S;
+                nr0 = N_R0_IQ3_S;
+                smem = 512*4;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
             } break;
         case GGML_TYPE_IQ2_S:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ2_S;
+                nr0 = N_R0_IQ2_S;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
             } break;
         case GGML_TYPE_IQ1_S:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ1_S;
+                nr0 = N_R0_IQ1_S;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
             } break;
         case GGML_TYPE_IQ1_M:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ1_M;
+                nr0 = N_R0_IQ1_M;
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
             } break;
         case GGML_TYPE_IQ4_NL:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ4_NL;
+                nr0 = N_R0_IQ4_NL;
+                smem = 32*sizeof(float);
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
             } break;
         case GGML_TYPE_IQ4_XS:
             {
-                nth0 = 4;
-                nth1 = 16;
+                nsg = N_SG_IQ4_XS;
+                nr0 = N_R0_IQ4_XS;
+                smem = 32*sizeof(float);
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
             } break;
         default:

@@ -3052,7 +3039,7 @@ static void ggml_metal_encode_node(
     };
 
     if (ggml_is_quantized(src0t)) {
-        GGML_ASSERT(ne00 >= nth0*nth1);
+        GGML_ASSERT(ne00 >= nsg*nr0);
     }
 
     ggml_metal_kargs_mul_mv_id args = {

@@ -3085,43 +3072,12 @@ static void ggml_metal_encode_node(
     [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
 
     const int64_t _ne1 = 1;
-    const int tgz = dst_rows;
+    const int64_t ne123 = dst_rows;
 
-    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
-        src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
-        src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-        const int mem_size = 32*sizeof(float);
-        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_Q4_K) {
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_Q3_K) {
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_Q5_K) {
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    }
-    else if (src0t == GGML_TYPE_Q6_K) {
-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-    } else {
-        const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
-        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+    if (smem > 0) {
+        [encoder setThreadgroupMemoryLength:smem atIndex:0];
     }
+    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 }
 } break;
 case GGML_OP_GET_ROWS:
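The hardcoded per-type dispatch chains deleted above are what the new nsg/nr0/nr1/smem variables replace: the threadgroup count now falls out of the same constants for every type. A minimal C sketch (not from the commit) checking that the table-driven formula reproduces the old divisors where the geometry was kept:

    #include <assert.h>

    /* threadgroups along x, as computed by the new dispatch */
    static int tg_count(int ne01, int nr0, int nsg) {
        return (ne01 + nr0*nsg - 1)/(nr0*nsg);
    }

    int main(void) {
        const int ne01 = 1000; // example row count

        // Q2_K: N_R0 = 4, N_SG = 2 reproduces the old (ne01 + 7)/8 branch
        assert(tg_count(ne01, 4, 2) == (ne01 + 7)/8);

        // Q6_K: N_R0 = 1, N_SG = 2 reproduces the old (ne01 + 1)/2 branch
        assert(tg_count(ne01, 1, 2) == (ne01 + 1)/2);

        return 0;
    }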
ggml/src/ggml-metal/ggml-metal.metal
CHANGED

@@ -1439,7 +1439,7 @@ kernel void kernel_rwkv_wkv7_f32(
 
         float4 sa_vec(0.0);
 
-        for (
             float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
             float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
             sa_vec += a_vec * s_vec;

@@ -1853,14 +1853,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il)
     return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
 }
 
-
-#define N_DST 4 // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-//Note: This is a template, but strictly speaking it only applies to
-//      quantizations where the block size is 32. It also does not
-//      guard against the number of rows not being divisible by
-//      N_DST, so this is another explicit assumption of the implementation.
-template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
+template<typename block_q_type, int nr0, int nsg, int nw, typename args_t>
 void mul_vec_q_n_f32_impl(
     args_t args,
     device const char * src0,
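The deleted comment noted that the old template did not guard against row counts not divisible by N_DST. After the refactor the guard is explicit: each simdgroup starts at first_row = (r0*nsg + sgitg)*nr0 and the final store checks first_row + row < ne01. A runnable C sketch (not from the diff) of that row-ownership math:

    #include <stdio.h>

    int main(void) {
        const int nr0  = 4;  // N_R0_Q4_0
        const int nsg  = 2;  // N_SG_Q4_0
        const int ne01 = 10; // deliberately not divisible by nr0*nsg

        const int ntg = (ne01 + nr0*nsg - 1)/(nr0*nsg); // threadgroups along x

        for (int r0 = 0; r0 < ntg; ++r0) {
            for (int sgitg = 0; sgitg < nsg; ++sgitg) {
                const int first_row = (r0*nsg + sgitg)*nr0;
                for (int row = 0; row < nr0; ++row) {
                    if (first_row + row < ne01) { // bounds guard from the kernel
                        printf("tg %d, sg %d -> row %d\n", r0, sgitg, first_row + row);
                    }
                }
            }
        }
        return 0;
    }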
@@ -1876,7 +1869,7 @@ void mul_vec_q_n_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;

@@ -1888,15 +1881,15 @@ void mul_vec_q_n_f32_impl(
     device const float * y = (device const float *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q_type * ax[nr];
-    for (int row = 0; row < nr; ++row) {
+    device const block_q_type * ax[nr0];
+    for (int row = 0; row < nr0; ++row) {
         const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
     }
 
     float yl[16]; // src1 vector cache
-    float sumf[nr] = {0.f};
+    float sumf[nr0] = {0.f};
 
     const short ix = (tiisg/2);
     const short il = (tiisg%2)*8;

@@ -1908,7 +1901,7 @@ void mul_vec_q_n_f32_impl(
         float sumy[2] = { 0.f, 0.f };
 
 #pragma unroll
-        for (int i = 0; i < 8; i += 2) {
+        for (short i = 0; i < 8; i += 2) {
             sumy[0] += yb[i + 0] + yb[i + 1];
             yl[i + 0] = yb[i + 0];
             yl[i + 1] = yb[i + 1]/256.f;

@@ -1919,7 +1912,7 @@ void mul_vec_q_n_f32_impl(
         }
 
 #pragma unroll
-        for (int row = 0; row < nr; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
         }

@@ -1928,7 +1921,7 @@ void mul_vec_q_n_f32_impl(
 
     device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
 
-    for (int row = 0; row < nr; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
 
         if (tiisg == 0 && first_row + row < args.ne01) {

@@ -1945,7 +1938,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -1956,7 +1949,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(

@@ -1967,7 +1960,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(

@@ -1978,12 +1971,12 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 #define NB_Q8_0 8
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q8_0_f32_impl(
     args_t args,
     device const char * src0,

@@ -1993,16 +1986,13 @@ void kernel_mul_mv_q8_0_f32_impl(
     uint3 tgpig,
     ushort tiisg,
     ushort sgitg) {
-    const int nr  = N_DST;
-    const int nsg = N_SIMDGROUP;
-    const int nw  = N_SIMDWIDTH;
-
     const int nb = args.ne00/QK8_0;
 
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0*nsg + sgitg)*nr;
+    const int first_row = (r0*nsg + sgitg)*nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;

@@ -2014,15 +2004,15 @@ void kernel_mul_mv_q8_0_f32_impl(
     device const float * y = (device const float *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q8_0 * ax[nr];
-    for (int row = 0; row < nr; ++row) {
+    device const block_q8_0 * ax[nr0];
+    for (int row = 0; row < nr0; ++row) {
         const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
     }
 
     float yl[NB_Q8_0];
-    float sumf[nr] = {0.f};
+    float sumf[nr0] = {0.f};
 
     const short ix = tiisg/4;
     const short il = tiisg%4;

@@ -2035,7 +2025,7 @@ void kernel_mul_mv_q8_0_f32_impl(
         yl[i] = yb[i];
     }
 
-    for (int row = 0; row < nr; ++row) {
+    for (short row = 0; row < nr0; ++row) {
         device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
         float sumq = 0.f;
         for (short iq = 0; iq < NB_Q8_0; ++iq) {

@@ -2049,7 +2039,7 @@ void kernel_mul_mv_q8_0_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < nr; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
 
         if (tiisg == 0 && first_row + row < args.ne01) {

@@ -2067,7 +2057,7 @@ kernel void kernel_mul_mv_q8_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // mat-vec kernel processing in chunks of float4
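The hunks below rename all_sum to sum_all around the simd_sum reductions: each of the 32 lanes of a simdgroup holds a partial dot product, simd_sum folds them into one value, and only lane 0 (tiisg == 0) writes the result. A C sketch (not from the diff) emulating the reduction with a plain loop:

    #include <stdio.h>

    #define N_SIMDWIDTH 32 // simdgroup width assumed by the kernels

    int main(void) {
        const int ne00 = 128;
        float sumf[N_SIMDWIDTH];

        // each "lane" accumulates a strided slice of the dot product
        for (int tiisg = 0; tiisg < N_SIMDWIDTH; ++tiisg) {
            sumf[tiisg] = 0.0f;
            for (int i = tiisg; i < ne00; i += N_SIMDWIDTH) {
                sumf[tiisg] += 1.0f*1.0f; // stand-in for x[i]*y[i]
            }
        }

        // simd_sum equivalent: fold all lanes into one value
        float sum_all = 0.0f;
        for (int t = 0; t < N_SIMDWIDTH; ++t) {
            sum_all += sumf[t];
        }

        printf("sum_all = %.1f (expected %d)\n", sum_all, ne00);
        return 0;
    }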
@@ -2404,9 +2394,9 @@ void kernel_mul_mv_impl(
             sumf += (T0) x[i] * (T1) y[i];
         }
 
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
         if (tiisg == 0) {
-            dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+            dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
         }
     }
 } else {

@@ -2427,10 +2417,10 @@ void kernel_mul_mv_impl(
             sumf += dot((float4) x4[i], (float4) y4[i]);
         }
 
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
         if (tiisg == 0) {
-            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
-            dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+            dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
         }
     }
 }

@@ -2492,9 +2482,9 @@ kernel void kernel_mul_mv_1row(
         for (int i = tiisg; i < args.ne00; i += 32) {
             sumf += (float) x[i] * (float) y[i];
         }
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
         if (tiisg == 0) {
-            dst_f32[r0] = all_sum;
+            dst_f32[r0] = sum_all;
         }
     } else {
         device const T4 * x4 = (device const T4 *) x;

@@ -2504,11 +2494,11 @@ kernel void kernel_mul_mv_1row(
             sumf += dot((float4) x4[i], y4[i]);
         }
 
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
 
         if (tiisg == 0) {
-            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
-            dst_f32[r0] = all_sum;
+            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+            dst_f32[r0] = sum_all;
         }
     }
 }

@@ -2553,9 +2543,9 @@ kernel void kernel_mul_mv_l4(
         sumf += dot((float4) x4[i], y4[i]);
     }
 
-    float all_sum = simd_sum(sumf);
+    float sum_all = simd_sum(sumf);
     if (tiisg == 0) {
-        dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+        dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
     }
 }
 }
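The hunks above show the float4-chunked dot product with its scalar tail: the main loop consumes ne00 in chunks of four, and lane 0 adds elements 4*(ne00/4) .. ne00-1 before storing. A C sketch (not from the diff) of the same pattern, scalarized:

    #include <stdio.h>

    static float dot_chunked(const float * x, const float * y, int ne00) {
        float sumf = 0.0f;

        // whole float4 chunks (scalarized here)
        for (int i = 0; i < ne00/4; ++i) {
            for (int k = 0; k < 4; ++k) {
                sumf += x[4*i + k]*y[4*i + k];
            }
        }

        // tail elements that do not fill a float4, as in the kernels
        for (int i = 4*(ne00/4); i < ne00; ++i) {
            sumf += x[i]*y[i];
        }

        return sumf;
    }

    int main(void) {
        const float x[7] = {1, 2, 3, 4, 5, 6, 7};
        const float y[7] = {1, 1, 1, 1, 1, 1, 1};
        printf("%.1f\n", dot_chunked(x, y, 7)); // 28.0
        return 0;
    }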
@@ -4321,7 +4311,7 @@ kernel void kernel_cpy_f32_iq4_nl(
     float amax = 0.0f; // absolute max
     float max  = 0.0f;
 
-    for (int j = 0; j <
         const float v = src[j];
         if (amax < fabs(v)) {
             amax = fabs(v);

@@ -4429,7 +4419,7 @@ kernel void kernel_concat(
         }
     }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q2_K_f32_impl(
     args_t args,
     device const char * src0,

@@ -4445,7 +4435,7 @@ void kernel_mul_mv_q2_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -4457,20 +4447,19 @@ void kernel_mul_mv_q2_K_f32_impl(
     device const float * y = (device const float *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
-    const int is = (8*ir)/16;// 0 or 1
+    const short ix = tiisg/8;  // 0...3
+    const short it = tiisg%8;  // 0...7
+    const short iq = it/4;     // 0 or 1
+    const short ir = it%4;     // 0...3
+    const short is = (8*ir)/16;// 0 or 1
 
     device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
 
     for (int ib = ix; ib < nb; ib += 4) {
-
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
+        for (short i = 0; i < 8; ++i) {
             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
             yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
             yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];

@@ -4481,7 +4470,7 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half * dh = &x[ib].d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
             for (int i = 0; i < 8; i += 2) {

@@ -4512,10 +4501,10 @@ void kernel_mul_mv_q2_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0; ++row) {
+        const float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }

@@ -4530,10 +4519,10 @@ kernel void kernel_mul_mv_q2_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q3_K_f32_impl(
     args_t args,
     device const char * src0,

@@ -4550,7 +4539,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;

@@ -4566,13 +4555,12 @@ void kernel_mul_mv_q3_K_f32_impl(
     //const uint16_t kmask1 = 0x3030;
     //const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int ip  = tid/4;          // 0 or 1
-    const int il  = 2*((tid%4)/2);  // 0 or 2
-    const int ir  = tid%2;
-    const int n   = 8;
-    const int l0  = n*ir;
+    const short tid = tiisg/4;
+    const short ix  = tiisg%4;
+    const short ip  = tid/4;          // 0 or 1
+    const short il  = 2*((tid%4)/2);  // 0 or 2
+    const short ir  = tid%2;
+    const short l0  = 8*ir;
 
     // One would think that the Metal compiler would figure out that ip and il can only have
     // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it

@@ -4597,8 +4585,8 @@ void kernel_mul_mv_q3_K_f32_impl(
     const uint16_t s_shift1 = 4*ip;
     const uint16_t s_shift2 = s_shift1 + il;
 
-    const int q_offset = 32*ip + l0;
-    const int y_offset = 128*ip + 32*il + l0;
+    const short q_offset = 32*ip + l0;
+    const short y_offset = 128*ip + 32*il + l0;
 
     device const float * y1 = yy + ix*QK_K + y_offset;

@@ -4606,10 +4594,11 @@ void kernel_mul_mv_q3_K_f32_impl(
     thread uint16_t * scales16 = (thread uint16_t *)&scales32;
     thread const int8_t * scales = (thread const int8_t *)&scales32;
 
-    float sumf1[2] = {0.f};
-    float sumf2[2] = {0.f};
+    float sumf1[nr0] = {0.f};
+    float sumf2[nr0] = {0.f};
+
     for (int i = ix; i < nb; i += 4) {
-        for (int l = 0; l < 8; ++l) {
+        for (short l = 0; l < 8; ++l) {
             yl[l+ 0] = y1[l+ 0];
             yl[l+ 8] = y1[l+16];
             yl[l+16] = y1[l+32];

@@ -4621,7 +4610,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         device const uint16_t * a = (device const uint16_t *)(x[i].scales);
         device const half * dh = &x[i].d;
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             const float d_all = (float)dh[0];
 
             scales16[0] = a[4];

@@ -4632,7 +4621,7 @@ void kernel_mul_mv_q3_K_f32_impl(
             scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
 
             float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
-            for (int l = 0; l < n; l += 2) {
+            for (short l = 0; l < 8; l += 2) {
                 const int32_t qs = q[l/2];
                 s1 += yl[l+0] * (qs & qm[il/2][0]);
                 s2 += yl[l+1] * (qs & qm[il/2][1]);

@@ -4647,7 +4636,7 @@ void kernel_mul_mv_q3_K_f32_impl(
             sumf2[row] += d2 * (scales[2] - 32);
 
             s1 = s2 = s3 = s4 = s5 = s6 = 0;
-            for (int l = 0; l < n; l += 2) {
+            for (short l = 0; l < 8; l += 2) {
                 const int32_t qs = q[l/2+8];
                 s1 += yl[l+8] * (qs & qm[il/2][0]);
                 s2 += yl[l+9] * (qs & qm[il/2][1]);

@@ -4670,7 +4659,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         y1 += 4 * QK_K;
     }
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
         sumf1[row] = simd_sum(sumf);
     }

@@ -4678,7 +4667,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
     if (tiisg == 0) {
-        for (int row = 0; row < 2; ++row) {
+        for (int row = 0; row < nr0; ++row) {
             dst_f32[first_row + row] = sumf1[row];
         }
     }

@@ -4694,10 +4683,10 @@ kernel void kernel_mul_mv_q3_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q4_K_f32_impl(
     args_t args,
     device const char * src0,

@@ -4707,22 +4696,22 @@ void kernel_mul_mv_q4_K_f32_impl(
     uint3 tgpig,
     ushort tiisg,
     ushort sgitg) {
-
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
+    const short ix = tiisg/8;  // 0...3
+    const short it = tiisg%8;  // 0...7
+    const short iq = it/4;     // 0 or 1
+    const short ir = it%4;     // 0...3
 
     const int nb = args.ne00/QK_K;
 
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-
-    const int first_row = r0 * N_DST;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;

@@ -4735,7 +4724,8 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     float yl[16];
     float yh[16];
-    float sumf[N_DST]={0.f}, all_sum;
+
+    float sumf[nr0]={0.f};
 
     device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;

@@ -4744,7 +4734,8 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     for (int ib = ix; ib < nb; ib += 4) {
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
+
+        for (short i = 0; i < 8; ++i) {
             yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
             yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
             yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];

@@ -4755,7 +4746,7 @@ void kernel_mul_mv_q4_K_f32_impl(
         device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half * dh = &x[ib].d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             sc16[0] = sc[0] & kmask1;
             sc16[1] = sc[2] & kmask1;
             sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);

@@ -4765,19 +4756,21 @@ void kernel_mul_mv_q4_K_f32_impl(
 
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
-            for (int i = 0; i < 8; i += 2) {
-                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
-                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
-                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
-                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
-                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
-                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
-                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
-                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
             }
 
             float dall = dh[0];
             float dmin = dh[1];
 
             sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
                                  (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
                                  (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +

@@ -4794,10 +4787,10 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0; ++row) {
+        const float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }

@@ -4812,10 +4805,10 @@ kernel void kernel_mul_mv_q4_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q5_K_f32_impl(
     args_t args,
     device const char * src0,

@@ -4832,7 +4825,7 @@ void kernel_mul_mv_q5_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;

@@ -4843,7 +4836,7 @@ void kernel_mul_mv_q5_K_f32_impl(
     device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
     device const float * yy = (device const float *) (src1 + offset1);
 
-    float sumf[2]={0.f};
+    float sumf[nr0]={0.f};
 
     float yl[16], yh[16];

@@ -4851,15 +4844,14 @@ void kernel_mul_mv_q5_K_f32_impl(
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int iq  = tid/4;
-    const int ir  = tid%4;
-    const int n   = 8;
+    const short tid = tiisg/4;
+    const short ix  = tiisg%4;
+    const short iq  = tid/4;
+    const short ir  = tid%4;
 
-    const int l0 = n*ir;
-    const int q_offset = 32*iq + l0;
-    const int y_offset = 64*iq + l0;
+    const short l0 = 8*ir;
+    const short q_offset = 32*iq + l0;
+    const short y_offset = 64*iq + l0;
 
     const uint8_t hm1 = 1u << (2*iq);
     const uint8_t hm2 = hm1 << 1;

@@ -4879,14 +4871,14 @@ void kernel_mul_mv_q5_K_f32_impl(
 
         device const float * y2 = y1 + 128;
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < 8; ++l) {
+        for (short l = 0; l < 8; ++l) {
             yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
             yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
             yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
             yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
         }
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             device const uint8_t * q2 = q1 + 64;
 
             sc16[0] = a[0] & kmask1;

@@ -4896,7 +4888,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
             float4 acc1 = {0.f};
             float4 acc2 = {0.f};
-            for (int l = 0; l < 8; ++l) {
+            for (short l = 0; l < 8; ++l) {
                 uint8_t h = qh[l];
                 acc1[0] += yl[l+0] * (q1[l] & 0x0F);
                 acc1[1] += yl[l+8] * (q1[l] & 0xF0);

@@ -4926,7 +4918,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = tot;

@@ -4944,10 +4936,10 @@ kernel void kernel_mul_mv_q5_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q6_K_f32_impl(
     args_t args,
     device const char * src0,
@@ -4969,62 +4961,77 @@ void kernel_mul_mv_q6_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int row = 2*r0 + sgitg;
-
-    if (row >= args.ne0) {
-        return;
-    }
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
-    const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-    const uint64_t offset1 =  r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
     device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
     device const float * yy = (device const float *) (src1 + offset1);
 
-    float sumf = 0;
+    float sumf[nr0] = { 0.f };
 
-    const int tid  = tiisg/2;
-    const int ix   = tiisg%2;
-    const int ip   = tid/8;         // 0 or 1
-    const int il   = tid%8;
-    const int n    = 4;
-    const int l0   = n*il;
-    const int is   = 8*ip + l0/16;
+    const short tid = tiisg/2;
+    const short ix  = tiisg%2;
+    const short ip  = tid/8;         // 0 or 1
+    const short il  = tid%8;
+    const short l0  = 4*il;
+    const short is  = 8*ip + l0/16;
 
-    const int y_offset   = 128*ip + l0;
-    const int q_offset_l =  64*ip + l0;
-    const int q_offset_h =  32*ip + l0;
+    const short y_offset   = 128*ip + l0;
+    const short q_offset_l =  64*ip + l0;
+    const short q_offset_h =  32*ip + l0;
 
     for (int i = ix; i < nb; i += 2) {
         device const uint8_t * q1 = x[i].ql + q_offset_l;
         device const uint8_t * q2 = q1 + 32;
         device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t  * sc = x[i].scales + is;
 
         device const float * y = yy + i * QK_K + y_offset;
 
-        const float dall = x[i].d;
-
-        float4 sums = {0.f, 0.f, 0.f, 0.f};
-
-        for (int l = 0; l < n; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
-            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
         }
 
-        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
     }
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    const float tot = simd_sum(sumf);
-    if (tiisg == 0) {
-        dst_f32[row] = tot;
-    }
 }

@@ -5038,12 +5045,12 @@ kernel void kernel_mul_mv_q6_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // ======================= "True" 2-bit
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq2_xxs_f32_impl(
     args_t args,
     device const char * src0,
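This is the hunk behind the commit-message bullet "metal : mv q6_K support nr0 > 1": the old kernel pinned one row per simdgroup (row = 2*r0 + sgitg) with a single scalar accumulator, while the refactored kernel derives its row range from the nr0 template parameter like the other quant kernels, at the cost of nr0 accumulator registers per thread; N_R0_Q6_K = 1 keeps the old geometry as the default. A C sketch (not from the diff) showing that nr0 = 1 reproduces the old row assignment:

    #include <stdio.h>

    int main(void) {
        const int r0 = 3, sgitg = 1, nsg = 2;

        // old: exactly one row per simdgroup
        const int row_old = 2*r0 + sgitg;

        // new: nr0 consecutive rows from first_row; nr0 = 1 matches the old mapping
        const int nr0 = 1;
        const int first_row = (r0*nsg + sgitg)*nr0;

        printf("old row = %d, new first_row = %d\n", row_old, first_row); // both 7
        return 0;
    }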
@@ -5059,7 +5066,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;

@@ -5071,7 +5078,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     device const float * y = (device const float *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);

@@ -5092,8 +5099,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }

@@ -5104,18 +5110,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
         device const uint16_t * q2 = xr->qs + 4 * ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             device const uint8_t * aux8 = (device const uint8_t *)q2;
             const uint32_t aux32 = q2[2] | (q2[3] << 16);
             const float d = db * (0.5f + (aux32 >> 28));
 
             float sum = 0;
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
                 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }

@@ -5130,10 +5135,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0; ++row) {
+        const float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
         }
     }
 }

@@ -5148,10 +5153,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq2_xs_f32_impl(
     args_t args,
     device const char * src0,

@@ -5167,7 +5172,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;

@@ -5179,7 +5184,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     device const float * y = (device const float *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);

@@ -5200,8 +5205,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }

@@ -5213,8 +5217,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
         device const uint8_t * sc = xr->scales + ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const uint8_t ls1 = sc[0] & 0xf;
             const uint8_t ls2 = sc[0] >> 4;

@@ -5222,17 +5225,17 @@ void kernel_mul_mv_iq2_xs_f32_impl(
             const float d2 = db * (0.5f + ls2);
 
             float sum1 = 0, sum2 = 0;
-            for (int l = 0; l < 2; ++l) {
+            for (short l = 0; l < 2; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                 const uint8_t signs = ssigns[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
-            for (int l = 2; l < 4; ++l) {
+            for (short l = 2; l < 4; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
|
| 5234 |
const uint8_t signs = ssigns[(q2[l] >> 9)];
|
| 5235 |
-
for (
|
| 5236 |
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
| 5237 |
}
|
| 5238 |
}
|
|
@@ -5248,10 +5251,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
| 5248 |
|
| 5249 |
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
| 5250 |
|
| 5251 |
-
for (int row = 0; row <
|
| 5252 |
-
|
| 5253 |
if (tiisg == 0) {
|
| 5254 |
-
dst_f32[first_row + row] =
|
| 5255 |
}
|
| 5256 |
}
|
| 5257 |
}
|
|
@@ -5267,10 +5270,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
|
| 5267 |
ushort tiisg[[thread_index_in_simdgroup]],
|
| 5268 |
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5269 |
|
| 5270 |
-
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
| 5271 |
}
|
| 5272 |
|
| 5273 |
-
template
|
| 5274 |
void kernel_mul_mv_iq3_xxs_f32_impl(
|
| 5275 |
args_t args,
|
| 5276 |
device const char * src0,
|
|
@@ -5286,7 +5289,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
| 5286 |
const int r1 = tgpig.y;
|
| 5287 |
const int im = tgpig.z;
|
| 5288 |
|
| 5289 |
-
const int first_row = (r0 *
|
| 5290 |
|
| 5291 |
const uint i12 = im%args.ne12;
|
| 5292 |
const uint i13 = im/args.ne12;
|
|
@@ -5298,7 +5301,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
| 5298 |
device const float * y = (device const float *) (src1 + offset1);
|
| 5299 |
|
| 5300 |
float yl[32];
|
| 5301 |
-
float sumf[
|
| 5302 |
|
| 5303 |
const int nb32 = nb * (QK_K / 32);
|
| 5304 |
|
|
@@ -5319,7 +5322,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
| 5319 |
device const float * y4 = y + 32 * ix;
|
| 5320 |
|
| 5321 |
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
| 5322 |
-
for (
|
| 5323 |
yl[i] = y4[i];
|
| 5324 |
}
|
| 5325 |
|
|
@@ -5331,17 +5334,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
| 5331 |
device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
|
| 5332 |
device const half * dh = &xr->d;
|
| 5333 |
|
| 5334 |
-
for (
|
| 5335 |
const float db = dh[0];
|
| 5336 |
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
| 5337 |
const float d = db * (0.5f + (aux32 >> 28));
|
| 5338 |
|
| 5339 |
float2 sum = {0};
|
| 5340 |
-
for (
|
| 5341 |
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
|
| 5342 |
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
|
| 5343 |
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
|
| 5344 |
-
for (
|
| 5345 |
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
| 5346 |
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
| 5347 |
}
|
|
@@ -5358,10 +5361,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
| 5358 |
|
| 5359 |
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
| 5360 |
|
| 5361 |
-
for (int row = 0; row <
|
| 5362 |
-
|
| 5363 |
if (tiisg == 0) {
|
| 5364 |
-
dst_f32[first_row + row] =
|
| 5365 |
}
|
| 5366 |
}
|
| 5367 |
}
|
|
@@ -5377,10 +5380,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
| 5377 |
ushort tiisg[[thread_index_in_simdgroup]],
|
| 5378 |
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5379 |
|
| 5380 |
-
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
| 5381 |
}
|
| 5382 |
|
| 5383 |
-
template<typename args_t>
|
| 5384 |
void kernel_mul_mv_iq3_s_f32_impl(
|
| 5385 |
args_t args,
|
| 5386 |
device const char * src0,
|
|
@@ -5396,7 +5399,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 5396 |
const int r1 = tgpig.y;
|
| 5397 |
const int im = tgpig.z;
|
| 5398 |
|
| 5399 |
-
const int first_row = (r0 *
|
| 5400 |
|
| 5401 |
const uint i12 = im%args.ne12;
|
| 5402 |
const uint i13 = im/args.ne12;
|
|
@@ -5408,7 +5411,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 5408 |
device const float * y = (device const float *) (src1 + offset1);
|
| 5409 |
|
| 5410 |
float yl[32];
|
| 5411 |
-
float sumf[
|
| 5412 |
|
| 5413 |
const int nb32 = nb * (QK_K / 32);
|
| 5414 |
|
|
@@ -5425,8 +5428,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 5425 |
device const float * y4 = y + 32 * ix;
|
| 5426 |
|
| 5427 |
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
| 5428 |
-
|
| 5429 |
-
for (int i = 0; i < 32; ++i) {
|
| 5430 |
yl[i] = y4[i];
|
| 5431 |
}
|
| 5432 |
|
|
@@ -5440,18 +5442,17 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 5440 |
device const uint8_t * signs = xr->signs + 4 * ib;
|
| 5441 |
device const half * dh = &xr->d;
|
| 5442 |
|
| 5443 |
-
for (
|
| 5444 |
-
|
| 5445 |
const float db = dh[0];
|
| 5446 |
const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
|
| 5447 |
|
| 5448 |
float2 sum = {0};
|
| 5449 |
-
for (
|
| 5450 |
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
|
| 5451 |
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
|
| 5452 |
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
|
| 5453 |
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
|
| 5454 |
-
for (
|
| 5455 |
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
| 5456 |
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
| 5457 |
}
|
|
@@ -5470,10 +5471,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 5470 |
|
| 5471 |
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
| 5472 |
|
| 5473 |
-
for (int row = 0; row <
|
| 5474 |
-
|
| 5475 |
if (tiisg == 0) {
|
| 5476 |
-
dst_f32[first_row + row] =
|
| 5477 |
}
|
| 5478 |
}
|
| 5479 |
}
|
|
@@ -5489,10 +5490,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
|
|
| 5489 |
ushort tiisg[[thread_index_in_simdgroup]],
|
| 5490 |
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5491 |
|
| 5492 |
-
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
| 5493 |
}
|
| 5494 |
|
| 5495 |
-
template
|
| 5496 |
void kernel_mul_mv_iq2_s_f32_impl(
|
| 5497 |
args_t args,
|
| 5498 |
device const char * src0,
|
|
@@ -5508,7 +5509,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
| 5508 |
const int r1 = tgpig.y;
|
| 5509 |
const int im = tgpig.z;
|
| 5510 |
|
| 5511 |
-
const int first_row = (r0 *
|
| 5512 |
|
| 5513 |
const uint i12 = im%args.ne12;
|
| 5514 |
const uint i13 = im/args.ne12;
|
|
@@ -5520,7 +5521,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
| 5520 |
device const float * y = (device const float *) (src1 + offset1);
|
| 5521 |
|
| 5522 |
float yl[32];
|
| 5523 |
-
float sumf[
|
| 5524 |
|
| 5525 |
const int nb32 = nb * (QK_K / 32);
|
| 5526 |
|
|
@@ -5532,13 +5533,12 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
| 5532 |
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5533 |
//}
|
| 5534 |
|
| 5535 |
-
const
|
| 5536 |
|
| 5537 |
device const float * y4 = y + 32 * ix;
|
| 5538 |
|
| 5539 |
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
| 5540 |
-
|
| 5541 |
-
for (int i = 0; i < 32; ++i) {
|
| 5542 |
yl[i] = y4[i];
|
| 5543 |
}
|
| 5544 |
|
|
@@ -5552,19 +5552,18 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
| 5552 |
device const uint8_t * signs = qs + QK_K/8;
|
| 5553 |
device const half * dh = &xr->d;
|
| 5554 |
|
| 5555 |
-
for (
|
| 5556 |
-
|
| 5557 |
const float db = dh[0];
|
| 5558 |
const float d1 = db * (0.5f + (sc[0] & 0xf));
|
| 5559 |
const float d2 = db * (0.5f + (sc[0] >> 4));
|
| 5560 |
|
| 5561 |
float2 sum = {0};
|
| 5562 |
-
for (
|
| 5563 |
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
| 5564 |
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
| 5565 |
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
| 5566 |
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
| 5567 |
-
for (
|
| 5568 |
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
|
| 5569 |
sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
|
| 5570 |
}
|
|
@@ -5583,10 +5582,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
| 5583 |
|
| 5584 |
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
| 5585 |
|
| 5586 |
-
for (int row = 0; row <
|
| 5587 |
-
|
| 5588 |
if (tiisg == 0) {
|
| 5589 |
-
dst_f32[first_row + row] =
|
| 5590 |
}
|
| 5591 |
}
|
| 5592 |
}
|
|
@@ -5602,10 +5601,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
|
|
| 5602 |
ushort tiisg[[thread_index_in_simdgroup]],
|
| 5603 |
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5604 |
|
| 5605 |
-
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
| 5606 |
}
|
| 5607 |
|
| 5608 |
-
template<typename args_t>
|
| 5609 |
void kernel_mul_mv_iq1_s_f32_impl(
|
| 5610 |
args_t args,
|
| 5611 |
device const char * src0,
|
|
@@ -5621,7 +5620,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 5621 |
const int r1 = tgpig.y;
|
| 5622 |
const int im = tgpig.z;
|
| 5623 |
|
| 5624 |
-
const int first_row = (r0 *
|
| 5625 |
|
| 5626 |
const uint i12 = im%args.ne12;
|
| 5627 |
const uint i13 = im/args.ne12;
|
|
@@ -5633,18 +5632,17 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 5633 |
device const float * y = (device const float *) (src1 + offset1);
|
| 5634 |
|
| 5635 |
float yl[32];
|
| 5636 |
-
float sumf[
|
| 5637 |
|
| 5638 |
const int nb32 = nb * (QK_K / 32);
|
| 5639 |
|
| 5640 |
-
const
|
| 5641 |
|
| 5642 |
device const float * y4 = y + 32 * ix;
|
| 5643 |
|
| 5644 |
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
| 5645 |
-
|
| 5646 |
float sumy = 0;
|
| 5647 |
-
for (
|
| 5648 |
yl[i] = y4[i];
|
| 5649 |
sumy += yl[i];
|
| 5650 |
}
|
|
@@ -5657,15 +5655,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 5657 |
device const uint16_t * qh = xr->qh + ib;
|
| 5658 |
device const half * dh = &xr->d;
|
| 5659 |
|
| 5660 |
-
for (
|
| 5661 |
-
|
| 5662 |
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
| 5663 |
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
|
| 5664 |
constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
|
| 5665 |
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
|
| 5666 |
|
| 5667 |
float sum = 0;
|
| 5668 |
-
for (
|
| 5669 |
sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
|
| 5670 |
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
|
| 5671 |
+ yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
|
|
@@ -5683,15 +5680,28 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 5683 |
|
| 5684 |
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
| 5685 |
|
| 5686 |
-
for (int row = 0; row <
|
| 5687 |
-
|
| 5688 |
if (tiisg == 0) {
|
| 5689 |
-
dst_f32[first_row + row] =
|
| 5690 |
}
|
| 5691 |
}
|
| 5692 |
}
|
| 5693 |
|
| 5694 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5695 |
void kernel_mul_mv_iq1_m_f32_impl(
|
| 5696 |
args_t args,
|
| 5697 |
device const char * src0,
|
|
@@ -5703,11 +5713,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
| 5703 |
ushort sgitg) {
|
| 5704 |
|
| 5705 |
const int nb = args.ne00/QK_K;
|
|
|
|
| 5706 |
const int r0 = tgpig.x;
|
| 5707 |
const int r1 = tgpig.y;
|
| 5708 |
const int im = tgpig.z;
|
| 5709 |
|
| 5710 |
-
const int first_row = (r0 *
|
| 5711 |
|
| 5712 |
const uint i12 = im%args.ne12;
|
| 5713 |
const uint i13 = im/args.ne12;
|
|
@@ -5719,20 +5730,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
| 5719 |
device const float * y = (device const float *) (src1 + offset1);
|
| 5720 |
|
| 5721 |
float yl[32];
|
| 5722 |
-
float sumf[
|
| 5723 |
|
| 5724 |
const int nb32 = nb * (QK_K / 32);
|
| 5725 |
|
| 5726 |
-
const
|
| 5727 |
|
| 5728 |
device const float * y4 = y + 32 * ix;
|
| 5729 |
|
| 5730 |
iq1m_scale_t scale;
|
| 5731 |
|
| 5732 |
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
| 5733 |
-
|
| 5734 |
float4 sumy = {0.f};
|
| 5735 |
-
for (
|
| 5736 |
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
| 5737 |
yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
|
| 5738 |
yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
|
|
@@ -5747,7 +5757,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
| 5747 |
device const uint8_t * qh = xr->qh + 2 * ib;
|
| 5748 |
device const uint16_t * sc = (device const uint16_t *)xr->scales;
|
| 5749 |
|
| 5750 |
-
for (
|
| 5751 |
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
| 5752 |
|
| 5753 |
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
|
@@ -5756,7 +5766,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
| 5756 |
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
|
| 5757 |
|
| 5758 |
float2 sum = {0.f};
|
| 5759 |
-
for (
|
| 5760 |
sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
|
| 5761 |
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
|
| 5762 |
sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
|
|
@@ -5778,15 +5788,28 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
| 5778 |
|
| 5779 |
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
| 5780 |
|
| 5781 |
-
for (int row = 0; row <
|
| 5782 |
-
|
| 5783 |
if (tiisg == 0) {
|
| 5784 |
-
dst_f32[first_row + row] =
|
| 5785 |
}
|
| 5786 |
}
|
| 5787 |
}
|
| 5788 |
|
| 5789 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5790 |
void kernel_mul_mv_iq4_nl_f32_impl(
|
| 5791 |
args_t args,
|
| 5792 |
device const char * src0,
|
|
@@ -5799,10 +5822,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
| 5799 |
|
| 5800 |
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
| 5801 |
const int nb = args.ne00/QK4_NL;
|
|
|
|
| 5802 |
const int r0 = tgpig.x;
|
| 5803 |
const int r1 = tgpig.y;
|
| 5804 |
const int im = tgpig.z;
|
| 5805 |
-
|
|
|
|
| 5806 |
|
| 5807 |
const uint i12 = im%args.ne12;
|
| 5808 |
const uint i13 = im/args.ne12;
|
|
@@ -5813,14 +5838,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
| 5813 |
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
|
| 5814 |
device const float * y = (device const float *) (src1 + offset1);
|
| 5815 |
|
| 5816 |
-
const
|
| 5817 |
-
const
|
| 5818 |
|
| 5819 |
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
| 5820 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5821 |
|
| 5822 |
float4 yl[4];
|
| 5823 |
-
float sumf[
|
| 5824 |
|
| 5825 |
device const float * yb = y + ix * QK4_NL + it * 8;
|
| 5826 |
|
|
@@ -5830,12 +5855,13 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
| 5830 |
float4 qf1, qf2;
|
| 5831 |
|
| 5832 |
for (int ib = ix; ib < nb; ib += 16) {
|
| 5833 |
-
|
| 5834 |
device const float4 * y4 = (device const float4 *)yb;
|
| 5835 |
-
yl[0] = y4[0];
|
| 5836 |
-
|
| 5837 |
-
|
|
|
|
| 5838 |
|
|
|
|
| 5839 |
device const block_iq4_nl & xb = x[row*nb + ib];
|
| 5840 |
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
| 5841 |
|
|
@@ -5860,7 +5886,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
| 5860 |
acc1 += acc2;
|
| 5861 |
|
| 5862 |
sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
| 5863 |
-
|
| 5864 |
}
|
| 5865 |
|
| 5866 |
yb += 16 * QK4_NL;
|
|
@@ -5868,15 +5893,29 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
| 5868 |
|
| 5869 |
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
| 5870 |
|
| 5871 |
-
for (int row = 0; row <
|
| 5872 |
-
|
| 5873 |
if (tiisg == 0) {
|
| 5874 |
-
dst_f32[first_row + row] =
|
| 5875 |
}
|
| 5876 |
}
|
| 5877 |
}
|
| 5878 |
|
| 5879 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5880 |
void kernel_mul_mv_iq4_xs_f32_impl(
|
| 5881 |
args_t args,
|
| 5882 |
device const char * src0,
|
|
@@ -5892,7 +5931,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
| 5892 |
const int r0 = tgpig.x;
|
| 5893 |
const int r1 = tgpig.y;
|
| 5894 |
const int im = tgpig.z;
|
| 5895 |
-
const int first_row = (r0 *
|
| 5896 |
|
| 5897 |
const uint i12 = im%args.ne12;
|
| 5898 |
const uint i13 = im/args.ne12;
|
|
@@ -5903,16 +5942,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
| 5903 |
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
|
| 5904 |
device const float * y = (device const float *) (src1 + offset1);
|
| 5905 |
|
| 5906 |
-
const
|
| 5907 |
-
const
|
| 5908 |
-
const
|
| 5909 |
-
const
|
| 5910 |
|
| 5911 |
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
| 5912 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5913 |
|
| 5914 |
float4 yl[4];
|
| 5915 |
-
float sumf[
|
| 5916 |
|
| 5917 |
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
| 5918 |
|
|
@@ -5923,9 +5962,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
| 5923 |
|
| 5924 |
for (int ibl = ix; ibl < nb; ibl += 2) {
|
| 5925 |
device const float4 * y4 = (device const float4 *)yb;
|
| 5926 |
-
yl[0] = y4[0];
|
|
|
|
|
|
|
|
|
|
| 5927 |
|
| 5928 |
-
for (
|
| 5929 |
device const block_iq4_xs & xb = x[row*nb + ibl];
|
| 5930 |
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
| 5931 |
|
|
@@ -5949,7 +5991,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
| 5949 |
|
| 5950 |
const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
|
| 5951 |
sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
| 5952 |
-
|
| 5953 |
}
|
| 5954 |
|
| 5955 |
yb += 2 * QK_K;
|
|
@@ -5957,54 +5998,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
| 5957 |
|
| 5958 |
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
| 5959 |
|
| 5960 |
-
for (int row = 0; row <
|
| 5961 |
-
|
| 5962 |
if (tiisg == 0) {
|
| 5963 |
-
dst_f32[first_row + row] =
|
| 5964 |
}
|
| 5965 |
}
|
| 5966 |
}
|
| 5967 |
|
| 5968 |
-
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
| 5969 |
-
kernel void kernel_mul_mv_iq1_s_f32(
|
| 5970 |
-
constant ggml_metal_kargs_mul_mv & args,
|
| 5971 |
-
device const char * src0,
|
| 5972 |
-
device const char * src1,
|
| 5973 |
-
device char * dst,
|
| 5974 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5975 |
-
ushort tiisg[[thread_index_in_simdgroup]],
|
| 5976 |
-
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5977 |
-
|
| 5978 |
-
kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
| 5979 |
-
}
|
| 5980 |
-
|
| 5981 |
-
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
| 5982 |
-
kernel void kernel_mul_mv_iq1_m_f32(
|
| 5983 |
-
constant ggml_metal_kargs_mul_mv & args,
|
| 5984 |
-
device const char * src0,
|
| 5985 |
-
device const char * src1,
|
| 5986 |
-
device char * dst,
|
| 5987 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5988 |
-
ushort tiisg[[thread_index_in_simdgroup]],
|
| 5989 |
-
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5990 |
-
|
| 5991 |
-
kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
| 5992 |
-
}
|
| 5993 |
-
|
| 5994 |
-
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
| 5995 |
-
kernel void kernel_mul_mv_iq4_nl_f32(
|
| 5996 |
-
constant ggml_metal_kargs_mul_mv & args,
|
| 5997 |
-
device const char * src0,
|
| 5998 |
-
device const char * src1,
|
| 5999 |
-
device char * dst,
|
| 6000 |
-
threadgroup char * shmem [[threadgroup(0)]],
|
| 6001 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 6002 |
-
ushort tiisg[[thread_index_in_simdgroup]],
|
| 6003 |
-
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6004 |
-
|
| 6005 |
-
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
| 6006 |
-
}
|
| 6007 |
-
|
| 6008 |
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
|
| 6009 |
kernel void kernel_mul_mv_iq4_xs_f32(
|
| 6010 |
constant ggml_metal_kargs_mul_mv & args,
|
|
@@ -6016,7 +6017,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
| 6016 |
ushort tiisg[[thread_index_in_simdgroup]],
|
| 6017 |
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 6018 |
|
| 6019 |
-
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
| 6020 |
}
|
| 6021 |
|
| 6022 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
@@ -6660,25 +6661,27 @@ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t
|
|
| 6660 |
#if defined(GGML_METAL_USE_BF16)
|
| 6661 |
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
|
| 6662 |
#endif
|
| 6663 |
-
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl
|
| 6664 |
-
|
| 6665 |
-
template [[host_name("
|
| 6666 |
-
template [[host_name("
|
| 6667 |
-
template [[host_name("
|
| 6668 |
-
template [[host_name("
|
| 6669 |
-
|
| 6670 |
-
template [[host_name("
|
| 6671 |
-
template [[host_name("
|
| 6672 |
-
template [[host_name("
|
| 6673 |
-
template [[host_name("
|
| 6674 |
-
template [[host_name("
|
| 6675 |
-
template [[host_name("
|
| 6676 |
-
template [[host_name("
|
| 6677 |
-
template [[host_name("
|
| 6678 |
-
template [[host_name("
|
| 6679 |
-
template [[host_name("
|
| 6680 |
-
template [[host_name("
|
| 6681 |
-
template [[host_name("
|
|
|
|
|
|
|
| 6682 |
|
| 6683 |
kernel void kernel_pool_2d_max_f32(
|
| 6684 |
device const float * src0,
|
|
|
|
@@ 1439 @@
      float4 sa_vec(0.0);

+     for (uint j = 0; j < head_size; j += 4) {
          float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
          float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
          sa_vec += a_vec * s_vec;

@@ 1853 @@
      return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
  }

+ template<typename block_q_type, int nr0, int nsg, int nw, typename args_t>
  void mul_vec_q_n_f32_impl(
          args_t args,
          device const char * src0,

@@ 1869 @@
      const int r1 = tgpig.y;
      const int im = tgpig.z;

+     const int first_row = (r0 * nsg + sgitg) * nr0;

      const uint i12 = im%args.ne12;
      const uint i13 = im/args.ne12;

@@ 1881 @@
      device const float * y = (device const float *) (src1 + offset1);

      // pointers to src0 rows
+     device const block_q_type * ax[nr0];
+     for (int row = 0; row < nr0; ++row) {
          const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

          ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
      }

      float yl[16]; // src1 vector cache
+     float sumf[nr0] = {0.f};

      const short ix = (tiisg/2);
      const short il = (tiisg%2)*8;

@@ 1901 @@
      float sumy[2] = { 0.f, 0.f };

#pragma unroll
+     for (short i = 0; i < 8; i += 2) {
          sumy[0] += yb[i + 0] + yb[i + 1];
          yl[i + 0] = yb[i + 0];
          yl[i + 1] = yb[i + 1]/256.f;

@@ 1912 @@
      }

#pragma unroll
+     for (short row = 0; row < nr0; row++) {
          sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
      }

@@ 1921 @@
      device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;

+     for (int row = 0; row < nr0; ++row) {
          const float tot = simd_sum(sumf[row]);

          if (tiisg == 0 && first_row + row < args.ne01) {

@@ 1938 @@
      uint3 tgpig[[threadgroup_position_in_grid]],
      ushort tiisg[[thread_index_in_simdgroup]],
      ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
  }

  kernel void kernel_mul_mv_q4_1_f32(

@@ 1949 @@
      uint3 tgpig[[threadgroup_position_in_grid]],
      ushort tiisg[[thread_index_in_simdgroup]],
      ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
  }

  kernel void kernel_mul_mv_q5_0_f32(

@@ 1960 @@
      uint3 tgpig[[threadgroup_position_in_grid]],
      ushort tiisg[[thread_index_in_simdgroup]],
      ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
  }

  kernel void kernel_mul_mv_q5_1_f32(

@@ 1971 @@
      uint3 tgpig[[threadgroup_position_in_grid]],
      ushort tiisg[[thread_index_in_simdgroup]],
      ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
  }

#define NB_Q8_0 8
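The N_R0_*/N_SG_* constants plug into these `nr0`/`nsg` template parameters, so every quantization type shares one accumulation shape: each simdgroup keeps `nr0` partial sums per lane and finishes with one `simd_sum` per row. As an illustration only, here is a minimal CPU sketch in C++ of that shape; `NW` and `NR0` are illustrative stand-ins, not values taken from the commit:

    #include <array>
    #include <cstdio>
    #include <vector>

    // Hedged CPU model of the mat-vec layout: NW lanes of one simdgroup
    // cooperatively process NR0 consecutive rows of src0 against one src1 vector.
    constexpr int NW  = 32; // simdgroup width (N_SIMDWIDTH in the kernels)
    constexpr int NR0 = 4;  // rows per simdgroup (the role of the N_R0_* constants)

    int main() {
        const int ne00 = 256;                    // row length
        std::vector<float> x(NR0 * ne00, 0.5f);  // NR0 rows of src0
        std::vector<float> y(ne00, 2.0f);        // the src1 vector

        // per-lane partial sums, one accumulator per row (sumf[nr0] in the kernel)
        std::array<std::array<float, NR0>, NW> sumf{};

        for (int lane = 0; lane < NW; ++lane) {
            // each lane strides over the row, like ix/il carve up the blocks
            for (int i = lane; i < ne00; i += NW) {
                for (int row = 0; row < NR0; ++row) {
                    sumf[lane][row] += x[row * ne00 + i] * y[i];
                }
            }
        }

        // the role of simd_sum(): reduce per-lane partials across the simdgroup,
        // then lane 0 alone writes the result (the "if (tiisg == 0)" store)
        for (int row = 0; row < NR0; ++row) {
            float tot = 0.f;
            for (int lane = 0; lane < NW; ++lane) {
                tot += sumf[lane][row];
            }
            std::printf("row %d: %f\n", row, tot); // expect 256 * 0.5 * 2 = 256
        }
        return 0;
    }

Keeping `nr0` rows live at once lets a single load of the shared `yl` values feed `nr0` dot products, at the price of `nr0` accumulator registers per lane, which is why the row counts are tuned per quantization type.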
@@ 1979 @@
+ template<int nr0, int nsg, int nw, typename args_t>
  void kernel_mul_mv_q8_0_f32_impl(
          args_t args,
          device const char * src0,

@@ 1986 @@
          uint3 tgpig,
          ushort tiisg,
          ushort sgitg) {
      const int nb = args.ne00/QK8_0;
+
      const int r0 = tgpig.x;
      const int r1 = tgpig.y;
      const int im = tgpig.z;

+     const int first_row = (r0 * nsg + sgitg) * nr0;

      const uint i12 = im%args.ne12;
      const uint i13 = im/args.ne12;

@@ 2004 @@
      device const float * y = (device const float *) (src1 + offset1);

      // pointers to src0 rows
+     device const block_q8_0 * ax[nr0];
+     for (int row = 0; row < nr0; ++row) {
          const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

          ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
      }

      float yl[NB_Q8_0];
+     float sumf[nr0] = { 0.f };

      const short ix = tiisg/4;
      const short il = tiisg%4;

@@ 2025 @@
          yl[i] = yb[i];
      }

+     for (short row = 0; row < nr0; row++) {
          device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
          float sumq = 0.f;
          for (short iq = 0; iq < NB_Q8_0; ++iq) {

@@ 2039 @@
      device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+     for (int row = 0; row < nr0; ++row) {
          const float tot = simd_sum(sumf[row]);

          if (tiisg == 0 && first_row + row < args.ne01) {

@@ 2057 @@
      uint3 tgpig[[threadgroup_position_in_grid]],
      ushort tiisg[[thread_index_in_simdgroup]],
      ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+     kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
  }
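For Q8_0 the dot product needs no dequantized copy of the row: the int8 quants multiply `y` directly and the per-block scale is applied once per block. A hedged CPU sketch of that structure (the `block_q8` layout below is a simplification; the real block stores a half-precision `d` and `QK8_0` quants):

    #include <cstdint>
    #include <cstdio>

    // Simplified Q8_0-style block: 32 int8 quants plus one scale d.
    struct block_q8 {
        float  d;        // per-block scale (half precision in the Metal code)
        int8_t qs[32];   // quantized values
    };

    // dot of one quantized row with a float vector: d * sum(q[i] * y[i])
    float dot_q8_row(const block_q8 * x, const float * y, int nblocks) {
        float sumf = 0.f;
        for (int ib = 0; ib < nblocks; ++ib) {
            float sumq = 0.f;
            for (int iq = 0; iq < 32; ++iq) {
                sumq += (float) x[ib].qs[iq] * y[32*ib + iq];
            }
            sumf += x[ib].d * sumq;  // scale applied once per block
        }
        return sumf;
    }

    int main() {
        block_q8 blk{0.1f, {}};
        for (int i = 0; i < 32; ++i) blk.qs[i] = (int8_t)(i - 16);
        float y[32];
        for (int i = 0; i < 32; ++i) y[i] = 1.0f;
        std::printf("%f\n", dot_q8_row(&blk, y, 1)); // 0.1 * sum(-16..15) = -1.6
        return 0;
    }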
  // mat-vec kernel processing in chunks of float4

@@ 2394 @@
          sumf += (T0) x[i] * (T1) y[i];
      }

+     float sum_all = simd_sum(sumf);
      if (tiisg == 0) {
+         dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
      }
  }
  } else {

@@ 2417 @@
          sumf += dot((float4) x4[i], (float4) y4[i]);
      }

+     float sum_all = simd_sum(sumf);
      if (tiisg == 0) {
+         for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+         dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
      }
  }
  }

@@ 2482 @@
      for (int i = tiisg; i < args.ne00; i += 32) {
          sumf += (float) x[i] * (float) y[i];
      }
+     float sum_all = simd_sum(sumf);
      if (tiisg == 0) {
+         dst_f32[r0] = sum_all;
      }
  } else {
      device const T4 * x4 = (device const T4 *) x;

@@ 2494 @@
          sumf += dot((float4) x4[i], y4[i]);
      }

+     float sum_all = simd_sum(sumf);

      if (tiisg == 0) {
+         for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+         dst_f32[r0] = sum_all;
      }
  }
  }

@@ 2543 @@
          sumf += dot((float4) x4[i], y4[i]);
      }

+     float sum_all = simd_sum(sumf);
      if (tiisg == 0) {
+         dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
      }
  }
  }
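The `all_sum` to `sum_all` rename also makes the tail handling easier to see: the vectorized loop covers `4*(ne00/4)` elements, and lane 0 patches in the leftover scalars before storing. A small standalone C++ model of that chunk-plus-tail pattern:

    #include <cstdio>
    #include <vector>

    // Model of the float4-chunked dot product above: the main loop consumes
    // four floats at a time; the ne00 % 4 leftovers are fixed up by a scalar
    // loop (done by lane 0 after simd_sum in the actual kernel).
    float dot_chunked(const float * x, const float * y, int ne00) {
        float sum_all = 0.f;
        for (int i = 0; i < ne00/4; ++i) {          // float4 chunks
            for (int k = 0; k < 4; ++k) {
                sum_all += x[4*i + k] * y[4*i + k];
            }
        }
        for (int i = 4*(ne00/4); i < ne00; ++i) {   // scalar tail
            sum_all += x[i] * y[i];
        }
        return sum_all;
    }

    int main() {
        std::vector<float> x(103, 1.0f), y(103, 3.0f);
        std::printf("%f\n", dot_chunked(x.data(), y.data(), 103)); // 309
        return 0;
    }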
@@ 4311 @@
      float amax = 0.0f; // absolute max
      float max  = 0.0f;

+     for (int j = 0; j < QK4_NL; j++) {
          const float v = src[j];
          if (amax < fabs(v)) {
              amax = fabs(v);

@@ 4419 @@
      }
  }

+ template<int nr0, int nsg, int nw, typename args_t>
  void kernel_mul_mv_q2_K_f32_impl(
          args_t args,
          device const char * src0,

@@ 4435 @@
      const int r1 = tgpig.y;
      const int im = tgpig.z;

+     const int first_row = (r0 * nsg + sgitg) * nr0;

      const uint i12 = im%args.ne12;
      const uint i13 = im/args.ne12;

@@ 4447 @@
      device const float * y = (device const float *) (src1 + offset1);

      float yl[32];
+     float sumf[nr0]={0.f};

+     const short ix = tiisg/8;  // 0...3
+     const short it = tiisg%8;  // 0...7
+     const short iq = it/4;     // 0 or 1
+     const short ir = it%4;     // 0...3
+     const short is = (8*ir)/16;// 0 or 1

      device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;

      for (int ib = ix; ib < nb; ib += 4) {
          float4 sumy = {0.f, 0.f, 0.f, 0.f};
+         for (short i = 0; i < 8; ++i) {
              yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
              yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
              yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];

@@ 4470 @@
          device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
          device const half * dh = &x[ib].d;

+         for (short row = 0; row < nr0; row++) {
              float4 acc1 = {0.f, 0.f, 0.f, 0.f};
              float4 acc2 = {0.f, 0.f, 0.f, 0.f};
              for (int i = 0; i < 8; i += 2) {

@@ 4501 @@
      device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+     for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+         float sum_all = simd_sum(sumf[row]);
          if (tiisg == 0) {
+             dst_f32[first_row + row] = sum_all;
          }
      }
  }

@@ 4519 @@
      ushort tiisg[[thread_index_in_simdgroup]],
      ushort sgitg[[simdgroup_index_in_threadgroup]]) {

+     kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
  }
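All refactored kernels now share this writeback idiom: the row loop itself is clipped against `ne0` when `ne01` is not a multiple of `nr0*nsg`, instead of each kernel bailing out early. A sketch of the guard, with plain arrays standing in for the device pointers:

    #include <cstdio>

    // Bounds-guarded store: the last simdgroup may own rows past the end of
    // dst, so the loop condition clips the writes; there is no early return.
    void store_rows(float * dst, const float * sums, int first_row, int nr0, int ne0) {
        for (int row = 0; row < nr0 && first_row + row < ne0; ++row) {
            dst[first_row + row] = sums[row];  // lane 0 performs this in the kernel
        }
    }

    int main() {
        float dst[10] = {0};
        float sums[4] = {1, 2, 3, 4};
        store_rows(dst, sums, 8, 4, 10);       // only rows 8 and 9 are written
        std::printf("%f %f\n", dst[8], dst[9]);
        return 0;
    }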
@@ 4525 @@
+ template<int nr0, int nsg, int nw, typename args_t>
  void kernel_mul_mv_q3_K_f32_impl(
          args_t args,
          device const char * src0,

@@ 4539 @@
      const int r1 = tgpig.y;
      const int im = tgpig.z;

+     const int first_row = (r0 * nsg + sgitg) * nr0;

      const uint i12 = im%args.ne12;
      const uint i13 = im/args.ne12;

@@ 4555 @@
      //const uint16_t kmask1 = 0x3030;
      //const uint16_t kmask2 = 0x0f0f;

+     const short tid = tiisg/4;
+     const short ix = tiisg%4;
+     const short ip = tid/4;          // 0 or 1
+     const short il = 2*((tid%4)/2);  // 0 or 2
+     const short ir = tid%2;
+     const short l0 = 8*ir;

      // One would think that the Metal compiler would figure out that ip and il can only have
      // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it

@@ 4585 @@
      const uint16_t s_shift1 = 4*ip;
      const uint16_t s_shift2 = s_shift1 + il;

+     const short q_offset = 32*ip + l0;
+     const short y_offset = 128*ip + 32*il + l0;

      device const float * y1 = yy + ix*QK_K + y_offset;

@@ 4594 @@
      thread uint16_t * scales16 = (thread uint16_t *)&scales32;
      thread const int8_t * scales = (thread const int8_t *)&scales32;

+     float sumf1[nr0] = {0.f};
+     float sumf2[nr0] = {0.f};
+
      for (int i = ix; i < nb; i += 4) {
+         for (short l = 0; l < 8; ++l) {
              yl[l+ 0] = y1[l+ 0];
              yl[l+ 8] = y1[l+16];
              yl[l+16] = y1[l+32];

@@ 4610 @@
          device const uint16_t * a = (device const uint16_t *)(x[i].scales);
          device const half * dh = &x[i].d;

+         for (short row = 0; row < nr0; ++row) {
              const float d_all = (float)dh[0];

              scales16[0] = a[4];

@@ 4621 @@
              scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;

              float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
+             for (short l = 0; l < 8; l += 2) {
                  const int32_t qs = q[l/2];
                  s1 += yl[l+0] * (qs & qm[il/2][0]);
                  s2 += yl[l+1] * (qs & qm[il/2][1]);

@@ 4636 @@
              sumf2[row] += d2 * (scales[2] - 32);

              s1 = s2 = s3 = s4 = s5 = s6 = 0;
+             for (short l = 0; l < 8; l += 2) {
                  const int32_t qs = q[l/2+8];
                  s1 += yl[l+8] * (qs & qm[il/2][0]);
                  s2 += yl[l+9] * (qs & qm[il/2][1]);

@@ 4659 @@
          y1 += 4 * QK_K;
      }

+     for (int row = 0; row < nr0; ++row) {
          const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
          sumf1[row] = simd_sum(sumf);
      }

@@ 4667 @@
      device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

      if (tiisg == 0) {
+         for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
              dst_f32[first_row + row] = sumf1[row];
          }
      }

@@ 4683 @@
      ushort tiisg[[thread_index_in_simdgroup]],
      ushort sgitg[[simdgroup_index_in_threadgroup]]) {

+     kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
  }

@@ 4689 @@
+ template<int nr0, int nsg, int nw, typename args_t>
  void kernel_mul_mv_q4_K_f32_impl(
          args_t args,
          device const char * src0,

@@ 4696 @@
          uint3 tgpig,
          ushort tiisg,
          ushort sgitg) {
      const uint16_t kmask1 = 0x3f3f;
      const uint16_t kmask2 = 0x0f0f;
      const uint16_t kmask3 = 0xc0c0;

+     const short ix = tiisg/8;  // 0...3
+     const short it = tiisg%8;  // 0...7
+     const short iq = it/4;     // 0 or 1
+     const short ir = it%4;     // 0...3

      const int nb = args.ne00/QK_K;
+
      const int r0 = tgpig.x;
      const int r1 = tgpig.y;
      const int im = tgpig.z;
+
+     const int first_row = (r0 * nsg + sgitg) * nr0;

      const uint i12 = im%args.ne12;
      const uint i13 = im/args.ne12;

@@ 4724 @@
      float yl[16];
      float yh[16];
+
+     float sumf[nr0]={0.f};

      device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;

@@ 4734 @@
      for (int ib = ix; ib < nb; ib += 4) {
          float4 sumy = {0.f, 0.f, 0.f, 0.f};
+
+         for (short i = 0; i < 8; ++i) {
              yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
              yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
              yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];

@@ 4746 @@
          device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
          device const half * dh = &x[ib].d;

+         for (short row = 0; row < nr0; row++) {
              sc16[0] = sc[0] & kmask1;
              sc16[1] = sc[2] & kmask1;
              sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);

@@ 4756 @@
              float4 acc1 = {0.f, 0.f, 0.f, 0.f};
              float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+
+             for (short i = 0; i < 4; ++i) {
+                 acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
+                 acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
+                 acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
+                 acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
+                 acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
+                 acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
+                 acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
+                 acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
              }

              float dall = dh[0];
              float dmin = dh[1];
+
              sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
                                   (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
                                   (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +

@@ 4787 @@
      device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;

+     for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+         float sum_all = simd_sum(sumf[row]);
          if (tiisg == 0) {
+             dst_f32[first_row + row] = sum_all;
          }
      }
  }

@@ 4805 @@
      ushort tiisg[[thread_index_in_simdgroup]],
      ushort sgitg[[simdgroup_index_in_threadgroup]]) {

+     kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
  }
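The q4_K inner loop above multiplies by masked nibbles without shifting them down first; the stray factors of 16 and 256 introduced by the masks are divided out once when the scales are applied. A self-contained check of why that is equivalent:

    #include <cstdint>
    #include <cstdio>

    // (q & 0x00F0) equals the nibble value times 16, and (q & 0xF000) equals
    // the top nibble times 4096 = 16 * 256 -- so the masked products can be
    // accumulated as-is and corrected by one multiply at scale time.
    int main() {
        const uint16_t q = 0xA3C5;          // four packed nibbles
        const float y0 = 2.0f, y1 = 3.0f;

        // shifted reference: extract the nibbles explicitly
        float ref = y0 * ((q >> 4) & 0xF) + y1 * ((q >> 12) & 0xF);

        // masked form: defer the /16 and /256 to a single multiply
        float acc_lo = y0 * (q & 0x00F0);   // nibble * 16
        float acc_hi = y1 * (q & 0xF000);   // nibble * 4096
        float val = (acc_lo + acc_hi * (1.f/256.f)) * (1.f/16.f);

        std::printf("%f %f\n", ref, val);   // both print 54
        return 0;
    }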
@@ 4811 @@
+ template<int nr0, int nsg, int nw, typename args_t>
  void kernel_mul_mv_q5_K_f32_impl(
          args_t args,
          device const char * src0,

@@ 4825 @@
      const int r1 = tgpig.y;
      const int im = tgpig.z;

+     const int first_row = (r0 * nsg + sgitg) * nr0;

      const uint i12 = im%args.ne12;
      const uint i13 = im/args.ne12;

@@ 4836 @@
      device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
      device const float * yy = (device const float *) (src1 + offset1);

+     float sumf[nr0]={0.f};

      float yl[16], yh[16];

@@ 4844 @@
      const uint16_t kmask2 = 0x0f0f;
      const uint16_t kmask3 = 0xc0c0;

+     const short tid = tiisg/4;
+     const short ix = tiisg%4;
+     const short iq = tid/4;
+     const short ir = tid%4;

+     const short l0 = 8*ir;
+     const short q_offset = 32*iq + l0;
+     const short y_offset = 64*iq + l0;

      const uint8_t hm1 = 1u << (2*iq);
      const uint8_t hm2 = hm1 << 1;

@@ 4871 @@
      device const float * y2 = y1 + 128;
      float4 sumy = {0.f, 0.f, 0.f, 0.f};
+     for (short l = 0; l < 8; ++l) {
          yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
          yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
          yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
          yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
      }

+     for (short row = 0; row < nr0; ++row) {
          device const uint8_t * q2 = q1 + 64;

          sc16[0] = a[0] & kmask1;

@@ 4888 @@
          float4 acc1 = {0.f};
          float4 acc2 = {0.f};
+         for (short l = 0; l < 8; ++l) {
              uint8_t h = qh[l];
              acc1[0] += yl[l+0] * (q1[l] & 0x0F);
              acc1[1] += yl[l+8] * (q1[l] & 0xF0);

@@ 4918 @@
      device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+     for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
          const float tot = simd_sum(sumf[row]);
          if (tiisg == 0) {
              dst_f32[first_row + row] = tot;

@@ 4936 @@
      ushort tiisg[[thread_index_in_simdgroup]],
      ushort sgitg[[simdgroup_index_in_threadgroup]]) {

+     kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
  }

@@ 4942 @@
+ template<int nr0, int nsg, int nw, typename args_t>
  void kernel_mul_mv_q6_K_f32_impl(
          args_t args,
          device const char * src0,

@@ 4961 @@
      const int r1 = tgpig.y;
      const int im = tgpig.z;

+     const int first_row = (r0 * nsg + sgitg) * nr0;

      const uint i12 = im%args.ne12;
      const uint i13 = im/args.ne12;

+     const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+     const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

      device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
      device const float * yy = (device const float *) (src1 + offset1);

+     float sumf[nr0] = { 0.f };
+
+     float yl[16];

+     const short tid = tiisg/2;
+     const short ix = tiisg%2;
+     const short ip = tid/8;   // 0 or 1
+     const short il = tid%8;
+     const short l0 = 4*il;
+     const short is = 8*ip + l0/16;

+     const short y_offset   = 128*ip + l0;
+     const short q_offset_l =  64*ip + l0;
+     const short q_offset_h =  32*ip + l0;

      for (int i = ix; i < nb; i += 2) {
          device const uint8_t * q1 = x[i].ql + q_offset_l;
          device const uint8_t * q2 = q1 + 32;
          device const uint8_t * qh = x[i].qh + q_offset_h;
          device const int8_t * sc = x[i].scales + is;
+         device const half * dh = &x[i].d;

          device const float * y = yy + i * QK_K + y_offset;

+         for (short l = 0; l < 4; ++l) {
+             yl[4*l + 0] = y[l +  0];
+             yl[4*l + 1] = y[l + 32];
+             yl[4*l + 2] = y[l + 64];
+             yl[4*l + 3] = y[l + 96];
+         }

+         for (short row = 0; row < nr0; ++row) {
+             const float dall = dh[0];
+
+             float4 sums = {0.f, 0.f, 0.f, 0.f};
+
+             for (short l = 0; l < 4; ++l) {
+                 sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+                 sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+                 sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+                 sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+             }
+
+             sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+             q1 += args.nb01;
+             q2 += args.nb01;
+             qh += args.nb01;
+             sc += args.nb01;
+             dh += args.nb01/2;
+         }
      }

      device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+     for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+         float sum_all = simd_sum(sumf[row]);
+         if (tiisg == 0) {
+             dst_f32[first_row + row] = sum_all;
+         }
      }
  }

@@ 5045 @@
      ushort tiisg[[thread_index_in_simdgroup]],
      ushort sgitg[[simdgroup_index_in_threadgroup]]) {

+     kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
  }
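q6_K previously handled a single row per simdgroup; to support `nr0 > 1` without caching an `ax[nr0]` pointer array, the row loop above advances one set of pointers by `args.nb01`, the byte stride between src0 rows (and by `nb01/2` for the 2-byte `d`). A toy C++ model of that stepping; the buffer layout here is illustrative, not the real block_q6_K layout:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const int nr0  = 2;
        const int nb01 = 16;                        // bytes per row (toy value)
        std::vector<uint8_t> src0(nr0 * nb01);
        for (size_t i = 0; i < src0.size(); ++i) src0[i] = (uint8_t) i;

        const uint8_t  * q1 = src0.data();          // quant bytes of row 0
        const uint16_t * dh = (const uint16_t *) src0.data();

        for (int row = 0; row < nr0; ++row) {
            std::printf("row %d starts with byte %u, half 0x%04x\n",
                        row, q1[0], (unsigned) dh[0]);
            q1 += nb01;                             // next row, byte stride
            dh += nb01/2;                           // same stride in 16-bit units
        }
        return 0;
    }

Stepping pointers in place keeps only one live pointer set per operand, which trades a little arithmetic in the row loop for lower register pressure.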
|
| 5050 |
|
| 5051 |
// ======================= "True" 2-bit
|
| 5052 |
|
| 5053 |
+
template<int nr0, int nsg, int nw, typename args_t>
|
| 5054 |
void kernel_mul_mv_iq2_xxs_f32_impl(
|
| 5055 |
args_t args,
|
| 5056 |
device const char * src0,
|
|
|
|
| 5066 |
const int r1 = tgpig.y;
|
| 5067 |
const int im = tgpig.z;
|
| 5068 |
|
| 5069 |
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
| 5070 |
|
| 5071 |
const uint i12 = im%args.ne12;
|
| 5072 |
const uint i13 = im/args.ne12;
|
|
|
|
| 5078 |
device const float * y = (device const float *) (src1 + offset1);
|
| 5079 |
|
| 5080 |
float yl[32];
|
| 5081 |
+
float sumf[nr0]={0.f};
|
| 5082 |
|
| 5083 |
const int nb32 = nb * (QK_K / 32);
|
| 5084 |
|
|
|
|
| 5099 |
device const float * y4 = y + 32 * ix;
|
| 5100 |
|
| 5101 |
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
| 5102 |
+
for (short i = 0; i < 32; ++i) {
|
|
|
|
| 5103 |
yl[i] = y4[i];
|
| 5104 |
}
|
| 5105 |
|
|
|
|
| 5110 |
device const uint16_t * q2 = xr->qs + 4 * ib;
|
| 5111 |
device const half * dh = &xr->d;
|
| 5112 |
|
| 5113 |
+
for (short row = 0; row < nr0; row++) {
|
|
|
|
| 5114 |
const float db = dh[0];
|
| 5115 |
device const uint8_t * aux8 = (device const uint8_t *)q2;
|
| 5116 |
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
| 5117 |
const float d = db * (0.5f + (aux32 >> 28));
|
| 5118 |
|
| 5119 |
float sum = 0;
|
| 5120 |
+
for (short l = 0; l < 4; ++l) {
|
| 5121 |
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
|
| 5122 |
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
|
| 5123 |
+
for (short j = 0; j < 8; ++j) {
|
| 5124 |
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
| 5125 |
}
|
| 5126 |
}
|
|
|
|
| 5135 |
|
| 5136 |
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
| 5137 |
|
| 5138 |
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
| 5139 |
+
float sum_all = simd_sum(sumf[row]);
|
| 5140 |
if (tiisg == 0) {
|
| 5141 |
+
dst_f32[first_row + row] = sum_all * 0.25f;
|
| 5142 |
}
|
| 5143 |
}
|
| 5144 |
}
|
|
|
|
| 5153 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5154 |
ushort tiisg[[thread_index_in_simdgroup]],
|
| 5155 |
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5156 |
+
kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
| 5157 |
}
|
| 5158 |
|
| 5159 |
+
template<int nr0, int nsg, int nw, typename args_t>
|
| 5160 |
void kernel_mul_mv_iq2_xs_f32_impl(
|
| 5161 |
args_t args,
|
| 5162 |
device const char * src0,
|
|
|
|
| 5172 |
const int r1 = tgpig.y;
|
| 5173 |
const int im = tgpig.z;
|
| 5174 |
|
| 5175 |
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
| 5176 |
|
| 5177 |
const uint i12 = im%args.ne12;
|
| 5178 |
const uint i13 = im/args.ne12;
|
|
|
|
| 5184 |
device const float * y = (device const float *) (src1 + offset1);
|
| 5185 |
|
| 5186 |
float yl[32];
|
| 5187 |
+
float sumf[nr0]={0.f};
|
| 5188 |
|
| 5189 |
const int nb32 = nb * (QK_K / 32);
|
| 5190 |
|
|
|
|
| 5205 |
device const float * y4 = y + 32 * ix;
|
| 5206 |
|
| 5207 |
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
| 5208 |
+
for (short i = 0; i < 32; ++i) {
|
|
|
|
| 5209 |
yl[i] = y4[i];
|
| 5210 |
}
|
| 5211 |
|
|
|
|
| 5217 |
device const uint8_t * sc = xr->scales + ib;
|
| 5218 |
device const half * dh = &xr->d;
|
| 5219 |
|
| 5220 |
+
for (short row = 0; row < nr0; row++) {
|
|
|
|
| 5221 |
const float db = dh[0];
|
| 5222 |
const uint8_t ls1 = sc[0] & 0xf;
|
| 5223 |
const uint8_t ls2 = sc[0] >> 4;
|
|
|
|
| 5225 |
const float d2 = db * (0.5f + ls2);
|
| 5226 |
|
| 5227 |
float sum1 = 0, sum2 = 0;
|
| 5228 |
+
for (short l = 0; l < 2; ++l) {
|
| 5229 |
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
|
| 5230 |
const uint8_t signs = ssigns[(q2[l] >> 9)];
|
| 5231 |
+
for (short j = 0; j < 8; ++j) {
|
| 5232 |
sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
| 5233 |
}
|
| 5234 |
}
|
| 5235 |
+
for (short l = 2; l < 4; ++l) {
|
| 5236 |
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
|
| 5237 |
const uint8_t signs = ssigns[(q2[l] >> 9)];
|
| 5238 |
+
for (short j = 0; j < 8; ++j) {
|
| 5239 |
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
| 5240 |
}
|
| 5241 |
}
|
|
|
|
| 5251 |
|
| 5252 |
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
| 5253 |
|
| 5254 |
+
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
| 5255 |
+
float sum_all = simd_sum(sumf[row]);
|
| 5256 |
if (tiisg == 0) {
|
| 5257 |
+
dst_f32[first_row + row] = sum_all * 0.25f;
|
| 5258 |
}
|
| 5259 |
}
|
| 5260 |
}
|
|
|
|
| 5270 |
ushort tiisg[[thread_index_in_simdgroup]],
|
| 5271 |
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5272 |
|
| 5273 |
+
kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
| 5274 |
}
|
| 5275 |
|
| 5276 |
+
template<int nr0, int nsg, int nw, typename args_t>
|
| 5277 |
void kernel_mul_mv_iq3_xxs_f32_impl(
|
| 5278 |
args_t args,
|
| 5279 |
device const char * src0,
|
|
|
|
| 5289 |
const int r1 = tgpig.y;
|
| 5290 |
const int im = tgpig.z;
|
| 5291 |
|
| 5292 |
+
const int first_row = (r0 * nsg + sgitg) * nr0;
|
| 5293 |
|
| 5294 |
const uint i12 = im%args.ne12;
|
| 5295 |
const uint i13 = im/args.ne12;
...
    device const float * y = (device const float *) (src1 + offset1);

    float yl[32];
+   float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

...
    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+       for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

...
        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
        device const half * dh = &xr->d;

+       for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            const uint32_t aux32 = gas[0] | (gas[1] << 16);
            const float d = db * (0.5f + (aux32 >> 28));

            float2 sum = {0};
+           for (short l = 0; l < 4; ++l) {
                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
+               for (short j = 0; j < 4; ++j) {
                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
...
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+   for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+       float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
+           dst_f32[first_row + row] = sum_all * 0.5f;
        }
    }
}
...
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {

+   kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq3_s_f32_impl(
        args_t args,
        device const char * src0,
...
    const int r1 = tgpig.y;
    const int im = tgpig.z;

+   const int first_row = (r0 * nsg + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
...
    device const float * y = (device const float *) (src1 + offset1);

    float yl[32];
+   float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

...
    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+       for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

...
        device const uint8_t * signs = xr->signs + 4 * ib;
        device const half * dh = &xr->d;

+       for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));

            float2 sum = {0};
+           for (short l = 0; l < 4; ++l) {
                const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
                const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
+               for (short j = 0; j < 4; ++j) {
                    sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
                    sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
                }
...
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+   for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+       float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
+           dst_f32[first_row + row] = sum_all;
        }
    }
}
...
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {

+   kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq2_s_f32_impl(
        args_t args,
        device const char * src0,
...
    const int r1 = tgpig.y;
    const int im = tgpig.z;

+   const int first_row = (r0 * nsg + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
...
    device const float * y = (device const float *) (src1 + offset1);

    float yl[32];
+   float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

...
    //  threadgroup_barrier(mem_flags::mem_threadgroup);
    //}

+   const short ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+       for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

...
        device const uint8_t * signs = qs + QK_K/8;
        device const half * dh = &xr->d;

+       for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            const float d1 = db * (0.5f + (sc[0] & 0xf));
            const float d2 = db * (0.5f + (sc[0] >> 4));

            float2 sum = {0};
+           for (short l = 0; l < 2; ++l) {
                //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
                constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+               for (short j = 0; j < 8; ++j) {
                    sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
                    sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
                }
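
Two details of the iq2_s inner loop above: the sign flip now uses MSL `select(1, -1, cond)` instead of a floating-point ternary, and the codebook is read directly from `iq2s_grid` in constant address space, with the old threadgroup-memory lookups kept as comments, presumably to relieve on-chip memory and register pressure. A toy CPU-side model of the branchless sign application (mask and inputs are made up):

    #include <cstdio>

    // models MSL select(a, b, c): returns b where c is true, a otherwise
    static int select_sign(int a, int b, bool c) { return c ? b : a; }

    int main() {
        const unsigned char kmask[8] = {1, 2, 4, 8, 16, 32, 64, 128};
        const unsigned char signs    = 0xA5; // hypothetical packed sign bits
        const float y[8] = {1, 2, 3, 4, 5, 6, 7, 8};

        float sum = 0.0f;
        for (int j = 0; j < 8; ++j) {
            // negate y[j] where the j-th sign bit is set, with no float constant
            sum += y[j] * select_sign(1, -1, (signs & kmask[j]) != 0);
        }
        printf("sum = %g\n", sum);
        return 0;
    }

With `signs = 0xA5` the flipped terms cancel and the sketch prints sum = 0.
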
...
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+   for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+       float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
+           dst_f32[first_row + row] = sum_all * 0.25f;
        }
    }
}
...
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {

+   kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq1_s_f32_impl(
        args_t args,
        device const char * src0,
...
    const int r1 = tgpig.y;
    const int im = tgpig.z;

+   const int first_row = (r0 * nsg + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
...
    device const float * y = (device const float *) (src1 + offset1);

    float yl[32];
+   float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

+   const short ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        float sumy = 0;
+       for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
            sumy += yl[i];
        }
...
        device const uint16_t * qh = xr->qh + ib;
        device const half * dh = &xr->d;

+       for (short row = 0; row < nr0; row++) {
            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));

            float sum = 0;
+           for (short j = 0; j < 4; ++j) {
                sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                     + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
                     + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
...
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+   for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+       float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
+           dst_f32[first_row + row] = sum_all;
        }
    }
}

+[[host_name("kernel_mul_mv_iq1_s_f32")]]
+kernel void kernel_mul_mv_iq1_s_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
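
The `[[host_name(...)]]` wrapper just added is the shape every entry point takes after this refactor: a thin kernel that instantiates the templated impl with the tuned `N_R0_*`/`N_SG_*` constants from ggml-metal-impl.h (iq1_s uses no threadgroup scratch, hence the `nullptr` shmem argument). A plain C++ analogue of the pattern, with hypothetical names and a trivial serial body standing in for the kernel:

    #include <cstdio>

    // templated impl: row blocking is fixed at compile time, like the MSL impls
    template <int nr0, int nsg>
    void mul_mv_impl(const float * x, const float * y, float * dst,
                     int ne00, int ne01) {
        for (int first = 0; first < ne01; first += nr0) {
            for (int row = first; row < first + nr0 && row < ne01; ++row) {
                float sum = 0.0f;
                for (int i = 0; i < ne00; ++i) {
                    sum += x[row*ne00 + i] * y[i];
                }
                dst[row] = sum; // nsg would pick a simdgroup here; serialized on CPU
            }
        }
    }

    // thin named entry point, analogous to the [[host_name(...)]] wrapper
    void mul_mv_iq1_s(const float * x, const float * y, float * dst,
                      int ne00, int ne01) {
        mul_mv_impl<4, 2>(x, y, dst, ne00, ne01); // N_R0_IQ1_S, N_SG_IQ1_S
    }

    int main() {
        const float x[6] = {1, 2, 3, 4, 5, 6}; // 3 rows x 2 columns
        const float y[2] = {1, 1};
        float dst[3];
        mul_mv_iq1_s(x, y, dst, 2, 3);
        for (int r = 0; r < 3; ++r) {
            printf("dst[%d] = %g\n", r, dst[r]);
        }
        return 0;
    }

Fixing `nr0`/`nsg` at compile time lets the compiler fully unroll the per-row loops, which is presumably where the register-pressure tuning comes in.
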
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq1_m_f32_impl(
        args_t args,
        device const char * src0,
...
        ushort sgitg) {

    const int nb = args.ne00/QK_K;
+
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

+   const int first_row = (r0 * nsg + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
...
    device const float * y = (device const float *) (src1 + offset1);

    float yl[32];
+   float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

+   const short ix = tiisg;

    device const float * y4 = y + 32 * ix;

    iq1m_scale_t scale;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        float4 sumy = {0.f};
+       for (short i = 0; i < 8; ++i) {
            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
            yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
            yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
...
        device const uint8_t  * qh = xr->qh + 2 * ib;
        device const uint16_t * sc = (device const uint16_t *)xr->scales;

+       for (short row = 0; row < nr0; row++) {
            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
...
            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));

            float2 sum = {0.f};
+           for (short j = 0; j < 4; ++j) {
                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
...
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+   for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+       float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
+           dst_f32[first_row + row] = sum_all;
        }
    }
}

+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
        args_t args,
        device const char * src0,
...

    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
    const int nb = args.ne00/QK4_NL;
+
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;
+
+   const int first_row = (r0 * nsg + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
...
    device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
    device const float        * y = (device const float        *) (src1 + offset1);

+   const short ix = tiisg/2;  // 0...15
+   const short it = tiisg%2;  // 0 or 1

    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float4 yl[4];
+   float sumf[nr0]={0.f};

    device const float * yb = y + ix * QK4_NL + it * 8;

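
In iq4_nl the 32 lanes of a simdgroup split so that each lane owns half a block: `ix = tiisg/2` selects the starting block and `it = tiisg%2` selects which group of 8 floats within it, and each lane then strides over blocks 16 at a time; the `kvalues_iq4nl_f` lookup table is staged once into threadgroup memory behind a barrier. A quick sketch of the lane layout (the block count is made up):

    #include <cstdio>

    int main() {
        const int nb = 40; // hypothetical number of QK4_NL blocks per row

        for (int lane = 0; lane < 32; ++lane) {
            const int ix = lane / 2; // 0...15: starting block index
            const int it = lane % 2; // 0 or 1: which 8-float half of the block

            int visited = 0;
            for (int ib = ix; ib < nb; ib += 16) {
                ++visited; // blocks this lane processes
            }
            if (lane < 4 || lane > 29) { // print a few lanes to keep output short
                printf("lane %2d -> ix=%2d it=%d, visits %d blocks\n",
                       lane, ix, it, visited);
            }
        }
        return 0;
    }
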
    float4 qf1, qf2;

    for (int ib = ix; ib < nb; ib += 16) {
        device const float4 * y4 = (device const float4 *)yb;
+       yl[0] = y4[0];
+       yl[1] = y4[4];
+       yl[2] = y4[1];
+       yl[3] = y4[5];

+       for (short row = 0; row < nr0; row++) {
            device const block_iq4_nl & xb = x[row*nb + ib];
            device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);

...
            acc1 += acc2;

            sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
        }

        yb += 16 * QK4_NL;
...
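
The interleaved loads above (`y4[0]`, `y4[4]`, `y4[1]`, `y4[5]`) pair activations with the nibble layout of a 4-bit block: assuming byte `j` of `qs` holds value `j` in its low nibble and value `j+16` in its high nibble, each lane needs `y` values spaced 16 apart. A CPU-side sketch of that pairing under this assumed layout (the elided `qf1`/`qf2` shuffles are not reproduced):

    #include <cstdio>

    int main() {
        // hypothetical 4-bit block: 16 bytes hold 32 values; low nibble of byte j
        // decodes to value j, high nibble to value j+16 (assumed layout)
        unsigned char qs[16];
        for (int j = 0; j < 16; ++j) {
            qs[j] = (unsigned char)(j | ((15 - j) << 4));
        }

        float y[32];
        for (int i = 0; i < 32; ++i) {
            y[i] = 1.0f;
        }

        const int it = 0; // lane's half-block: bytes 0..7 -> y[0..7] and y[16..23]
        float acc = 0.0f;
        for (int j = 0; j < 8; ++j) {
            const unsigned char q = qs[8*it + j];
            acc += y[8*it + j +  0] * (float)(q & 0xf); // pairs with yl[0]/yl[2]
            acc += y[8*it + j + 16] * (float)(q >> 4);  // pairs with yl[1]/yl[3]
        }
        printf("acc = %g\n", acc); // low nibbles 0..7 plus high nibbles 15..8 = 120
        return 0;
    }
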

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+   for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+       float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
+           dst_f32[first_row + row] = sum_all;
        }
    }
}

+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
        args_t args,
        device const char * src0,
...
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;
+   const int first_row = (r0 * nsg + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
...
    device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
    device const float        * y = (device const float        *) (src1 + offset1);

+   const short ix = tiisg/16;  // 0 or 1
+   const short it = tiisg%16;  // 0...15
+   const short ib = it/2;
+   const short il = it%2;

    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float4 yl[4];
+   float sumf[nr0]={0.f};

    device const float * yb = y + ix * QK_K + ib * 32 + il * 8;

...

    for (int ibl = ix; ibl < nb; ibl += 2) {
        device const float4 * y4 = (device const float4 *)yb;
+       yl[0] = y4[0];
+       yl[1] = y4[4];
+       yl[2] = y4[1];
+       yl[3] = y4[5];

+       for (short row = 0; row < nr0; ++row) {
            device const block_iq4_xs & xb = x[row*nb + ibl];
            device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);

...
            const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
            sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
        }

        yb += 2 * QK_K;
...

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+   for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+       float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
+           dst_f32[first_row + row] = sum_all;
        }
    }
}
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32(
        constant ggml_metal_kargs_mul_mv & args,
...
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {

+   kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
...
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
#endif
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl   <N_R0_Q2_K,    N_SG_Q2_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl   <N_R0_Q3_K,    N_SG_Q3_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl   <N_R0_Q4_K,    N_SG_Q4_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl   <N_R0_Q5_K,    N_SG_Q5_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl   <N_R0_Q6_K,    N_SG_Q6_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl  <N_R0_IQ1_S,   N_SG_IQ1_S,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl  <N_R0_IQ1_M,   N_SG_IQ1_M,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS,  N_SG_IQ2_XS,  N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl  <N_R0_IQ3_S,   N_SG_IQ3_S,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl  <N_R0_IQ2_S,   N_SG_IQ2_S,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL,  N_SG_IQ4_NL,  N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS,  N_SG_IQ4_XS,  N_SIMDWIDTH>>>;

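These explicit instantiations bind each indirect (`_id`, expert-routed) mat-vec variant to a Metal function name while reusing the same impls and the same tuned constants as the direct kernels. A small C++ analogue of this name-to-instantiation registration; the names and the dot-product body are illustrative stand-ins, not the library's API:

    #include <cstdio>
    #include <map>
    #include <string>

    using mmv_id_fn = void (*)(const float *, const float *, float *, int);

    // one templated body, stamped out per quant type with its tuned constants
    template <int nr0, int nsg>
    void impl(const float * x, const float * y, float * dst, int ne00) {
        // nr0/nsg steer row blocking in the real kernels; here we just dot once
        float s = 0.0f;
        for (int i = 0; i < ne00; ++i) {
            s += x[i] * y[i];
        }
        *dst = s;
    }

    int main() {
        // name -> instantiation table, analogous to the [[host_name(...)]] bindings
        std::map<std::string, mmv_id_fn> kernels = {
            {"kernel_mul_mv_id_q4_0_f32", impl<4, 2>}, // N_R0_Q4_0, N_SG_Q4_0
            {"kernel_mul_mv_id_q6_K_f32", impl<1, 2>}, // N_R0_Q6_K, N_SG_Q6_K
        };

        const float x[3] = {1, 2, 3};
        const float y[3] = {4, 5, 6};
        float dst = 0.0f;
        kernels["kernel_mul_mv_id_q4_0_f32"](x, y, &dst, 3);
        printf("dst = %g\n", dst); // 32
        return 0;
    }
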
kernel void kernel_pool_2d_max_f32(
        device const float * src0,