Commit 71d72f9 by ggerganov · 1 parent: a13f78c

metal : refactor mat-vec code (llama/12569)


* metal : refactor mat-vec code

ggml-ci

* metal : rename all_sum -> sum_all

ggml-ci

* metal : fix comments [no ci]

* metal : fix nr constant [no ci]

* metal : mv q6_K support nr0 > 1

ggml-ci

* metal : reduce register pressure

ggml-ci

* metal : fix typo [no ci]

* metal : reduce register pressure

ggml-ci

ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -1,6 +1,70 @@
 #ifndef GGML_METAL_IMPL
 #define GGML_METAL_IMPL
 
+// kernel parameters for mat-vec threadgroups
+//
+// N_R0: number of src0 rows to process per simdgroup
+// N_SG: number of simdgroups per threadgroup
+//
+// TODO: for optimal performance, become function of the device and work size
+
+#define N_R0_Q4_0 4
+#define N_SG_Q4_0 2
+
+#define N_R0_Q4_1 4
+#define N_SG_Q4_1 2
+
+#define N_R0_Q5_0 4
+#define N_SG_Q5_0 2
+
+#define N_R0_Q5_1 4
+#define N_SG_Q5_1 2
+
+#define N_R0_Q8_0 4
+#define N_SG_Q8_0 2
+
+#define N_R0_Q2_K 4
+#define N_SG_Q2_K 2
+
+#define N_R0_Q3_K 2
+#define N_SG_Q3_K 2
+
+#define N_R0_Q4_K 4
+#define N_SG_Q4_K 2
+
+#define N_R0_Q5_K 2
+#define N_SG_Q5_K 2
+
+#define N_R0_Q6_K 1
+#define N_SG_Q6_K 2
+
+#define N_R0_IQ1_S 4
+#define N_SG_IQ1_S 2
+
+#define N_R0_IQ1_M 4
+#define N_SG_IQ1_M 2
+
+#define N_R0_IQ2_XXS 4
+#define N_SG_IQ2_XXS 2
+
+#define N_R0_IQ2_XS 4
+#define N_SG_IQ2_XS 2
+
+#define N_R0_IQ2_S 4
+#define N_SG_IQ2_S 2
+
+#define N_R0_IQ3_XXS 4
+#define N_SG_IQ3_XXS 2
+
+#define N_R0_IQ3_S 4
+#define N_SG_IQ3_S 2
+
+#define N_R0_IQ4_NL 2
+#define N_SG_IQ4_NL 2
+
+#define N_R0_IQ4_XS 2
+#define N_SG_IQ4_XS 2
+
 // kernel argument structs
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
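The N_R0_*/N_SG_* pairs fix the work shape of every quantized mat-vec dispatch: a threadgroup holds N_SG simdgroups of 32 threads each, and each simdgroup reduces N_R0 rows of src0. A minimal plain-C sketch of the resulting grid arithmetic (the ceil_div helper and the main() harness are illustrative, not part of the header):

#include <stdio.h>

static int ceil_div(int n, int d) { return (n + d - 1)/d; }

int main(void) {
    const int ne01 = 4096; // rows of src0
    const int nr0  = 4;    // N_R0_Q4_0: src0 rows per simdgroup
    const int nsg  = 2;    // N_SG_Q4_0: simdgroups per threadgroup

    const int tg_x           = ceil_div(ne01, nr0*nsg); // threadgroups along x
    const int threads_per_tg = 32*nsg;                  // 32 threads per simdgroup

    printf("%d threadgroups of %d threads\n", tg_x, threads_per_tg); // 512 of 64
    return 0;
}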
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -2561,171 +2561,180 @@ static void ggml_metal_encode_node(
                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
             } else {
-                int nth0 = 32;
-                int nth1 = 1;
-                int nrows = 1;
-                //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
                 id<MTLComputePipelineState> pipeline = nil;
 
+                int nsg = 0; // number of simdgroups
+                int nr0 = 0; // number of src0 rows per simdgroup
+                int nr1 = 1; // number of src1 rows per threadgroup
+
+                size_t smem = 0; // shared memory
+
                 // use custom matrix x vector kernel
                 switch (src0t) {
                     case GGML_TYPE_F32:
                         {
                             GGML_ASSERT(src1t == GGML_TYPE_F32);
+                            nsg = 1;
+                            nr0 = 1;
+                            nr1 = 4;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
-                            nrows = 4;
                         } break;
                     case GGML_TYPE_F16:
                         {
-                            nth0 = 32;
-                            nth1 = 1;
+                            nsg = 1;
+                            nr0 = 1;
                             if (src1t == GGML_TYPE_F32) {
                                 if (ne11 * ne12 < 4) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
                                 } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
-                                    nrows = ne11;
+                                    nr1 = ne11;
                                 } else {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
-                                    nrows = 4;
+                                    nr1 = 4;
                                 }
                             } else {
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
-                                nrows = 4;
+                                nr1 = 4;
                             }
                         } break;
                     case GGML_TYPE_BF16:
                         {
-                            nth0 = 32;
-                            nth1 = 1;
+                            nsg = 1;
+                            nr0 = 1;
                             if (src1t == GGML_TYPE_F32) {
                                 if (ne11 * ne12 < 4) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
                                 } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
-                                    nrows = ne11;
+                                    nr1 = ne11;
                                 } else {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
-                                    nrows = 4;
+                                    nr1 = 4;
                                 }
                             } else {
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
-                                nrows = 4;
+                                nr1 = 4;
                             }
                         } break;
                     case GGML_TYPE_Q4_0:
                         {
-                            nth0 = 8;
-                            nth1 = 8;
+                            nsg = N_SG_Q4_0;
+                            nr0 = N_R0_Q4_0;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q4_1:
                         {
-                            nth0 = 8;
-                            nth1 = 8;
+                            nsg = N_SG_Q4_1;
+                            nr0 = N_R0_Q4_1;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q5_0:
                         {
-                            nth0 = 8;
-                            nth1 = 8;
+                            nsg = N_SG_Q5_0;
+                            nr0 = N_R0_Q5_0;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q5_1:
                         {
-                            nth0 = 8;
-                            nth1 = 8;
+                            nsg = N_SG_Q5_1;
+                            nr0 = N_R0_Q5_1;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q8_0:
                         {
-                            nth0 = 8;
-                            nth1 = 8;
+                            nsg = N_SG_Q8_0;
+                            nr0 = N_R0_Q8_0;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q2_K:
                         {
-                            nth0 = 2;
-                            nth1 = 32;
+                            nsg = N_SG_Q2_K;
+                            nr0 = N_R0_Q2_K;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q3_K:
                         {
-                            nth0 = 2;
-                            nth1 = 32;
+                            nsg = N_SG_Q3_K;
+                            nr0 = N_R0_Q3_K;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q4_K:
                         {
-                            nth0 = 4; //1;
-                            nth1 = 8; //32;
+                            nsg = N_SG_Q4_K;
+                            nr0 = N_R0_Q4_K;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q5_K:
                         {
-                            nth0 = 2;
-                            nth1 = 32;
+                            nsg = N_SG_Q5_K;
+                            nr0 = N_R0_Q5_K;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q6_K:
                         {
-                            nth0 = 2;
-                            nth1 = 32;
+                            nsg = N_SG_Q6_K;
+                            nr0 = N_R0_Q6_K;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ2_XXS:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ2_XXS;
+                            nr0 = N_R0_IQ2_XXS;
+                            smem = 256*8+128;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ2_XS:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ2_XS;
+                            nr0 = N_R0_IQ2_XS;
+                            smem = 512*8+128;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ3_XXS:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ3_XXS;
+                            nr0 = N_R0_IQ3_XXS;
+                            smem = 256*4+128;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ3_S:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ3_S;
+                            nr0 = N_R0_IQ3_S;
+                            smem = 512*4;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ2_S:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ2_S;
+                            nr0 = N_R0_IQ2_S;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ1_S:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ1_S;
+                            nr0 = N_R0_IQ1_S;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ1_M:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ1_M;
+                            nr0 = N_R0_IQ1_M;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ4_NL:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ4_NL;
+                            nr0 = N_R0_IQ4_NL;
+                            smem = 32*sizeof(float);
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ4_XS:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ4_XS;
+                            nr0 = N_R0_IQ4_XS;
+                            smem = 32*sizeof(float);
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
                         } break;
                     default:
@@ -2762,41 +2771,10 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
-                if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                    src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                    src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                    const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                    [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                    const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                    [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                    const int mem_size = 32*sizeof(float);
-                    [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q4_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q3_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q5_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q6_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                } else {
-                    const int64_t ny = (ne11 + nrows - 1)/nrows;
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                if (smem > 0) {
+                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
                 }
+                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
             }
         } break;
     case GGML_OP_MUL_MAT_ID:
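A note before the MUL_MAT_ID path: for several of the quantized types, the single ceil-div dispatch above reproduces the grid that the deleted if/else chain hard-coded, because nr0*nsg equals the old per-threadgroup row grouping; for the others the grouping into threadgroups changes while the rows handled per simdgroup stay the same. A self-checking C sketch of the matching cases (constants taken from ggml-metal-impl.h, old formulas from the removed branches; the program itself is illustrative):

#include <assert.h>

int main(void) {
    const int ne01 = 1000; // arbitrary src0 row count

    // Q4_0: nr0 = 4, nsg = 2  ->  matches the old (ne01 + 7)/8 grid
    assert((ne01 + 4*2 - 1)/(4*2) == (ne01 + 7)/8);

    // Q6_K: nr0 = 1, nsg = 2  ->  matches the old (ne01 + 1)/2 grid
    assert((ne01 + 1*2 - 1)/(1*2) == (ne01 + 1)/2);

    // IQ4_NL: nr0 = 2, nsg = 2  ->  matches the old (ne01 + 3)/4 grid
    assert((ne01 + 2*2 - 1)/(2*2) == (ne01 + 3)/4);

    return 0;
}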
@@ -2902,146 +2880,155 @@ static void ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
             } else {
-                int nth0 = 32;
-                int nth1 = 1;
-                int nrows = 1;
-                //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
                 id<MTLComputePipelineState> pipeline = nil;
 
+                int nsg = 0; // number of simdgroups
+                int nr0 = 0; // number of src0 rows per simdgroup
+                int nr1 = 1; // number of src1 rows per threadgroup
+
+                size_t smem = 0; // shared memory
+
                 // use custom matrix x vector kernel
                 switch (src0t) {
                     case GGML_TYPE_F32:
                         {
                             GGML_ASSERT(src1t == GGML_TYPE_F32);
+                            nsg = 1;
+                            nr0 = 1;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
                         } break;
                     case GGML_TYPE_F16:
                         {
                             GGML_ASSERT(src1t == GGML_TYPE_F32);
-                            nth0 = 32;
-                            nth1 = 1;
+                            nsg = 1;
+                            nr0 = 1;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
                         } break;
                     case GGML_TYPE_BF16:
                         {
                             GGML_ASSERT(src1t == GGML_TYPE_F32);
-                            nth0 = 32;
-                            nth1 = 1;
+                            nsg = 1;
+                            nr0 = 1;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q4_0:
                         {
-                            nth0 = 8;
-                            nth1 = 8;
+                            nsg = N_SG_Q4_0;
+                            nr0 = N_R0_Q4_0;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q4_1:
                         {
-                            nth0 = 8;
-                            nth1 = 8;
+                            nsg = N_SG_Q4_1;
+                            nr0 = N_R0_Q4_1;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q5_0:
                         {
-                            nth0 = 8;
-                            nth1 = 8;
+                            nsg = N_SG_Q5_0;
+                            nr0 = N_R0_Q5_0;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q5_1:
                         {
-                            nth0 = 8;
-                            nth1 = 8;
+                            nsg = N_SG_Q5_1;
+                            nr0 = N_R0_Q5_1;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q8_0:
                         {
-                            nth0 = 8;
-                            nth1 = 8;
+                            nsg = N_SG_Q8_0;
+                            nr0 = N_R0_Q8_0;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q2_K:
                         {
-                            nth0 = 2;
-                            nth1 = 32;
+                            nsg = N_SG_Q2_K;
+                            nr0 = N_R0_Q2_K;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q3_K:
                         {
-                            nth0 = 2;
-                            nth1 = 32;
+                            nsg = N_SG_Q3_K;
+                            nr0 = N_R0_Q3_K;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q4_K:
                         {
-                            nth0 = 4; //1;
-                            nth1 = 8; //32;
+                            nsg = N_SG_Q4_K;
+                            nr0 = N_R0_Q4_K;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q5_K:
                         {
-                            nth0 = 2;
-                            nth1 = 32;
+                            nsg = N_SG_Q5_K;
+                            nr0 = N_R0_Q5_K;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
                         } break;
                     case GGML_TYPE_Q6_K:
                         {
-                            nth0 = 2;
-                            nth1 = 32;
+                            nsg = N_SG_Q6_K;
+                            nr0 = N_R0_Q6_K;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ2_XXS:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ2_XXS;
+                            nr0 = N_R0_IQ2_XXS;
+                            smem = 256*8+128;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ2_XS:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ2_XS;
+                            nr0 = N_R0_IQ2_XS;
+                            smem = 512*8+128;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ3_XXS:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ3_XXS;
+                            nr0 = N_R0_IQ3_XXS;
+                            smem = 256*4+128;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ3_S:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ3_S;
+                            nr0 = N_R0_IQ3_S;
+                            smem = 512*4;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ2_S:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ2_S;
+                            nr0 = N_R0_IQ2_S;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ1_S:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ1_S;
+                            nr0 = N_R0_IQ1_S;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ1_M:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ1_M;
+                            nr0 = N_R0_IQ1_M;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ4_NL:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ4_NL;
+                            nr0 = N_R0_IQ4_NL;
+                            smem = 32*sizeof(float);
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
                         } break;
                     case GGML_TYPE_IQ4_XS:
                         {
-                            nth0 = 4;
-                            nth1 = 16;
+                            nsg = N_SG_IQ4_XS;
+                            nr0 = N_R0_IQ4_XS;
+                            smem = 32*sizeof(float);
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
                         } break;
                     default:
@@ -3052,7 +3039,7 @@ static void ggml_metal_encode_node(
                 };
 
                 if (ggml_is_quantized(src0t)) {
-                    GGML_ASSERT(ne00 >= nth0*nth1);
+                    GGML_ASSERT(ne00 >= nsg*nr0);
                 }
 
                 ggml_metal_kargs_mul_mv_id args = {
@@ -3085,43 +3072,12 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
 
                 const int64_t _ne1 = 1;
-                const int tgz = dst_rows;
+                const int64_t ne123 = dst_rows;
 
-                if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                    src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                    src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                    const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                    [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                    const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                    [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                    const int mem_size = 32*sizeof(float);
-                    [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q4_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q3_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q5_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q6_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                } else {
-                    const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                if (smem > 0) {
+                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
                 }
+                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
             }
         } break;
     case GGML_OP_GET_ROWS:
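In both dispatch paths the threadgroup-memory requirement now travels through one smem variable set alongside the pipeline, instead of being recomputed inside the dispatch chain. A plain-C sketch that just tabulates the byte counts appearing above (values copied from the switch; the program itself is illustrative):

#include <stdio.h>

int main(void) {
    // threadgroup memory per quant type, as assigned in the switch above
    printf("IQ2_XXS: %d bytes\n", 256*8 + 128);             // 2176: lookup tables
    printf("IQ2_XS : %d bytes\n", 512*8 + 128);             // 4224
    printf("IQ3_XXS: %d bytes\n", 256*4 + 128);             // 1152
    printf("IQ3_S  : %d bytes\n", 512*4);                   // 2048
    printf("IQ4_NL : %d bytes\n", (int)(32*sizeof(float))); // 128: shared float values
    return 0;
}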
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -1439,7 +1439,7 @@ kernel void kernel_rwkv_wkv7_f32(
 
             float4 sa_vec(0.0);
 
-            for (int j = 0; j < head_size; j += 4) {
                 float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
                 float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
                 sa_vec += a_vec * s_vec;
@@ -1853,14 +1853,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
 }
 
-// putting them in the kernel cause a significant performance penalty
-#define N_DST 4 // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-//Note: This is a template, but strictly speaking it only applies to
-//      quantizations where the block size is 32. It also does not
-//      guard against the number of rows not being divisible by
-//      N_DST, so this is another explicit assumption of the implementation.
-template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
+template<typename block_q_type, int nr0, int nsg, int nw, typename args_t>
 void mul_vec_q_n_f32_impl(
         args_t args,
         device const char * src0,
@@ -1876,7 +1869,7 @@ void mul_vec_q_n_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -1888,15 +1881,15 @@ void mul_vec_q_n_f32_impl(
     device const float * y = (device const float *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q_type * ax[nr];
-    for (int row = 0; row < nr; ++row) {
+    device const block_q_type * ax[nr0];
+    for (int row = 0; row < nr0; ++row) {
         const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
     }
 
     float yl[16]; // src1 vector cache
-    float sumf[nr] = {0.f};
+    float sumf[nr0] = {0.f};
 
     const short ix = (tiisg/2);
     const short il = (tiisg%2)*8;
@@ -1908,7 +1901,7 @@ void mul_vec_q_n_f32_impl(
         float sumy[2] = { 0.f, 0.f };
 
 #pragma unroll
-        for (int i = 0; i < 8; i += 2) {
             sumy[0]  += yb[i + 0] + yb[i + 1];
             yl[i + 0] = yb[i + 0];
             yl[i + 1] = yb[i + 1]/256.f;
@@ -1919,7 +1912,7 @@ void mul_vec_q_n_f32_impl(
         }
 
 #pragma unroll
-        for (int row = 0; row < nr; row++) {
+        for (int row = 0; row < nr0; row++) {
             sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
         }
@@ -1928,7 +1921,7 @@ void mul_vec_q_n_f32_impl(
 
     device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
 
-    for (int row = 0; row < nr; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
 
         if (tiisg == 0 && first_row + row < args.ne01) {
@@ -1945,7 +1938,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -1956,7 +1949,7 @@ kernel void kernel_mul_mv_q4_1_f32(
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -1967,7 +1960,7 @@ kernel void kernel_mul_mv_q5_0_f32(
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -1978,12 +1971,12 @@ kernel void kernel_mul_mv_q5_1_f32(
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 #define NB_Q8_0 8
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q8_0_f32_impl(
         args_t args,
         device const char * src0,
@@ -1993,16 +1986,13 @@ void kernel_mul_mv_q8_0_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
-    const int nr  = N_DST;
-    const int nsg = N_SIMDGROUP;
-    const int nw  = N_SIMDWIDTH;
-
     const int nb = args.ne00/QK8_0;
 
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0*nsg + sgitg)*nr;
+    const int first_row = (r0*nsg + sgitg)*nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -2014,15 +2004,15 @@ void kernel_mul_mv_q8_0_f32_impl(
     device const float * y = (device const float *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q8_0 * ax[nr];
-    for (int row = 0; row < nr; ++row) {
+    device const block_q8_0 * ax[nr0];
+    for (int row = 0; row < nr0; ++row) {
         const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
     }
 
     float yl[NB_Q8_0];
-    float sumf[nr] = { 0.f };
+    float sumf[nr0] = { 0.f };
 
     const short ix = tiisg/4;
     const short il = tiisg%4;
@@ -2035,7 +2025,7 @@ void kernel_mul_mv_q8_0_f32_impl(
             yl[i] = yb[i];
         }
 
-        for (int row = 0; row < nr; row++) {
+        for (int row = 0; row < nr0; row++) {
             device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
             float sumq = 0.f;
             for (short iq = 0; iq < NB_Q8_0; ++iq) {
@@ -2049,7 +2039,7 @@ void kernel_mul_mv_q8_0_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < nr; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
 
         if (tiisg == 0 && first_row + row < args.ne01) {
@@ -2067,7 +2057,7 @@ kernel void kernel_mul_mv_q8_0_f32(
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // mat-vec kernel processing in chunks of float4
@@ -2404,9 +2394,9 @@ void kernel_mul_mv_impl(
                 sumf += (T0) x[i] * (T1) y[i];
             }
 
-            float all_sum = simd_sum(sumf);
+            float sum_all = simd_sum(sumf);
             if (tiisg == 0) {
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
             }
         }
     } else {
@@ -2427,10 +2417,10 @@ void kernel_mul_mv_impl(
                 sumf += dot((float4) x4[i], (float4) y4[i]);
             }
 
-            float all_sum = simd_sum(sumf);
+            float sum_all = simd_sum(sumf);
             if (tiisg == 0) {
-                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
             }
         }
     }
@@ -2492,9 +2482,9 @@ kernel void kernel_mul_mv_1row(
        for (int i = tiisg; i < args.ne00; i += 32) {
            sumf += (float) x[i] * (float) y[i];
        }
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
        if (tiisg == 0) {
-            dst_f32[r0] = all_sum;
+            dst_f32[r0] = sum_all;
        }
    } else {
        device const T4 * x4 = (device const T4 *) x;
@@ -2504,11 +2494,11 @@ kernel void kernel_mul_mv_1row(
            sumf += dot((float4) x4[i], y4[i]);
        }
 
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
 
        if (tiisg == 0) {
-            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
-            dst_f32[r0] = all_sum;
+            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+            dst_f32[r0] = sum_all;
        }
    }
 }
@@ -2553,9 +2543,9 @@ kernel void kernel_mul_mv_l4(
            sumf += dot((float4) x4[i], y4[i]);
        }
 
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
        if (tiisg == 0) {
-            dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+            dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
        }
    }
 }
@@ -4321,7 +4311,7 @@ kernel void kernel_cpy_f32_iq4_nl(
        float amax = 0.0f; // absolute max
        float max  = 0.0f;
 
-        for (int j = 0; j < QK4_0; j++) {
            const float v = src[j];
            if (amax < fabs(v)) {
                amax = fabs(v);
@@ -4429,7 +4419,7 @@ kernel void kernel_concat(
    }
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q2_K_f32_impl(
        args_t args,
        device const char * src0,
@@ -4445,7 +4435,7 @@ void kernel_mul_mv_q2_K_f32_impl(
    const int r1 = tgpig.y;
    const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
@@ -4457,20 +4447,19 @@ void kernel_mul_mv_q2_K_f32_impl(
    device const float * y = (device const float *) (src1 + offset1);
 
    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0] = {0.f};
 
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
-    const int is = (8*ir)/16;// 0 or 1
 
    device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
 
    for (int ib = ix; ib < nb; ib += 4) {
-
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
@@ -4481,7 +4470,7 @@ void kernel_mul_mv_q2_K_f32_impl(
        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
        device const half     * dh = &x[ib].d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (int row = 0; row < nr0; row++) {
            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
@@ -4512,10 +4501,10 @@ void kernel_mul_mv_q2_K_f32_impl(
 
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
        }
    }
 }
@@ -4530,10 +4519,10 @@ kernel void kernel_mul_mv_q2_K_f32(
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q3_K_f32_impl(
        args_t args,
        device const char * src0,
@@ -4550,7 +4539,7 @@ void kernel_mul_mv_q3_K_f32_impl(
    const int r1 = tgpig.y;
    const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
@@ -4566,13 +4555,12 @@ void kernel_mul_mv_q3_K_f32_impl(
    //const uint16_t kmask1 = 0x3030;
    //const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int ip  = tid/4;          // 0 or 1
-    const int il  = 2*((tid%4)/2);  // 0 or 2
-    const int ir  = tid%2;
-    const int n   = 8;
-    const int l0  = n*ir;
 
    // One would think that the Metal compiler would figure out that ip and il can only have
    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
@@ -4597,8 +4585,8 @@ void kernel_mul_mv_q3_K_f32_impl(
    const uint16_t s_shift1 = 4*ip;
    const uint16_t s_shift2 = s_shift1 + il;
 
-    const int q_offset = 32*ip + l0;
-    const int y_offset = 128*ip + 32*il + l0;
 
    device const float * y1 = yy + ix*QK_K + y_offset;
 
@@ -4606,10 +4594,11 @@ void kernel_mul_mv_q3_K_f32_impl(
    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
    thread const int8_t * scales = (thread const int8_t *)&scales32;
 
-    float sumf1[2] = {0.f};
-    float sumf2[2] = {0.f};
+    float sumf1[nr0] = {0.f};
+    float sumf2[nr0] = {0.f};
 
    for (int i = ix; i < nb; i += 4) {
-        for (int l = 0; l < 8; ++l) {
            yl[l+ 0] = y1[l+ 0];
            yl[l+ 8] = y1[l+16];
            yl[l+16] = y1[l+32];
@@ -4621,7 +4610,7 @@ void kernel_mul_mv_q3_K_f32_impl(
        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
        device const half * dh = &x[i].d;
 
-        for (int row = 0; row < 2; ++row) {
+        for (int row = 0; row < nr0; ++row) {
            const float d_all = (float)dh[0];
 
            scales16[0] = a[4];
@@ -4632,7 +4621,7 @@ void kernel_mul_mv_q3_K_f32_impl(
            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
 
            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
-            for (int l = 0; l < n; l += 2) {
                const int32_t qs = q[l/2];
                s1 += yl[l+0] * (qs & qm[il/2][0]);
                s2 += yl[l+1] * (qs & qm[il/2][1]);
@@ -4647,7 +4636,7 @@ void kernel_mul_mv_q3_K_f32_impl(
            sumf2[row] += d2 * (scales[2] - 32);
 
            s1 = s2 = s3 = s4 = s5 = s6 = 0;
-            for (int l = 0; l < n; l += 2) {
                const int32_t qs = q[l/2+8];
                s1 += yl[l+8] * (qs & qm[il/2][0]);
                s2 += yl[l+9] * (qs & qm[il/2][1]);
@@ -4670,7 +4659,7 @@ void kernel_mul_mv_q3_K_f32_impl(
        y1 += 4 * QK_K;
    }
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < nr0; ++row) {
        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
        sumf1[row] = simd_sum(sumf);
    }
@@ -4678,7 +4667,7 @@ void kernel_mul_mv_q3_K_f32_impl(
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
    if (tiisg == 0) {
-        for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+        for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
            dst_f32[first_row + row] = sumf1[row];
        }
    }
@@ -4694,10 +4683,10 @@ kernel void kernel_mul_mv_q3_K_f32(
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q4_K_f32_impl(
        args_t args,
        device const char * src0,
@@ -4707,22 +4696,22 @@ void kernel_mul_mv_q4_K_f32_impl(
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
-
    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;
 
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
 
    const int nb = args.ne00/QK_K;
 
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;
-    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int first_row = r0 * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
@@ -4735,7 +4724,8 @@ void kernel_mul_mv_q4_K_f32_impl(
 
    float yl[16];
    float yh[16];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0] = {0.f};
 
    device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
 
@@ -4744,7 +4734,8 @@ void kernel_mul_mv_q4_K_f32_impl(
 
    for (int ib = ix; ib < nb; ib += 4) {
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
@@ -4755,7 +4746,7 @@ void kernel_mul_mv_q4_K_f32_impl(
        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
        device const half     * dh = &x[ib].d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (int row = 0; row < nr0; row++) {
            sc16[0] = sc[0] & kmask1;
            sc16[1] = sc[2] & kmask1;
            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
@@ -4765,19 +4756,21 @@ void kernel_mul_mv_q4_K_f32_impl(
 
            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
-            for (int i = 0; i < 8; i += 2) {
-                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
-                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
-                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
-                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
-                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
-                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
-                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
-                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
-            }
 
            float dall = dh[0];
            float dmin = dh[1];
 
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
@@ -4794,10 +4787,10 @@ void kernel_mul_mv_q4_K_f32_impl(
 
    device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
        }
    }
 }
@@ -4812,10 +4805,10 @@ kernel void kernel_mul_mv_q4_K_f32(
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q5_K_f32_impl(
        args_t args,
        device const char * src0,
@@ -4832,7 +4825,7 @@ void kernel_mul_mv_q5_K_f32_impl(
    const int r1 = tgpig.y;
    const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
@@ -4843,7 +4836,7 @@ void kernel_mul_mv_q5_K_f32_impl(
    device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
    device const float     * yy = (device const float     *) (src1 + offset1);
 
-    float sumf[2]={0.f};
+    float sumf[nr0] = {0.f};
 
    float yl[16], yh[16];
 
@@ -4851,15 +4844,14 @@ void kernel_mul_mv_q5_K_f32_impl(
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int iq  = tid/4;
-    const int ir  = tid%4;
-    const int n   = 8;
 
-    const int l0 = n*ir;
-    const int q_offset = 32*iq + l0;
-    const int y_offset = 64*iq + l0;
 
    const uint8_t hm1 = 1u << (2*iq);
    const uint8_t hm2 = hm1 << 1;
@@ -4879,14 +4871,14 @@ void kernel_mul_mv_q5_K_f32_impl(
 
        device const float * y2 = y1 + 128;
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < 8; ++l) {
            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
        }
 
-        for (int row = 0; row < 2; ++row) {
+        for (int row = 0; row < nr0; ++row) {
            device const uint8_t * q2 = q1 + 64;
 
            sc16[0] = a[0] & kmask1;
@@ -4896,7 +4888,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
            float4 acc1 = {0.f};
            float4 acc2 = {0.f};
-            for (int l = 0; l < n; ++l) {
                uint8_t h = qh[l];
                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
@@ -4926,7 +4918,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        const float tot = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = tot;
@@ -4944,10 +4936,10 @@ kernel void kernel_mul_mv_q5_K_f32(
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template <typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q6_K_f32_impl(
        args_t args,
        device const char * src0,
@@ -4969,62 +4961,77 @@ void kernel_mul_mv_q6_K_f32_impl(
    const int r1 = tgpig.y;
    const int im = tgpig.z;
 
-    const int row = 2*r0 + sgitg;
-
-    if (row >= args.ne0) {
-        return;
-    }
 
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
 
-    const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-    const uint64_t offset1 =  r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
    device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
    device const float     * yy = (device const float     *) (src1 + offset1);
 
-    float sumf = 0;
 
-    const int tid  = tiisg/2;
-    const int ix   = tiisg%2;
-    const int ip   = tid/8;         // 0 or 1
-    const int il   = tid%8;
-    const int n    = 4;
-    const int l0   = n*il;
-    const int is   = 8*ip + l0/16;
 
-    const int y_offset   = 128*ip + l0;
-    const int q_offset_l =  64*ip + l0;
-    const int q_offset_h =  32*ip + l0;
 
    for (int i = ix; i < nb; i += 2) {
        device const uint8_t * q1 = x[i].ql + q_offset_l;
        device const uint8_t * q2 = q1 + 32;
        device const uint8_t * qh = x[i].qh + q_offset_h;
        device const int8_t  * sc = x[i].scales + is;
 
        device const float * y = yy + i * QK_K + y_offset;
 
-        const float dall = x[i].d;
-
-        float4 sums = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < n; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
-            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
        }
 
-        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
    }
 
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    const float tot = simd_sum(sumf);
-    if (tiisg == 0) {
-        dst_f32[row] = tot;
    }
 }
 
@@ -5038,12 +5045,12 @@ kernel void kernel_mul_mv_q6_K_f32(
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // ======================= "True" 2-bit
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq2_xxs_f32_impl(
        args_t args,
        device const char * src0,
@@ -5059,7 +5066,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
    const int r1 = tgpig.y;
    const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
@@ -5071,7 +5078,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
    device const float * y = (device const float *) (src1 + offset1);
 
    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0] = {0.f};
 
    const int nb32 = nb * (QK_K / 32);
 
@@ -5092,8 +5099,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
    device const float * y4 = y + 32 * ix;
 
    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }
 
@@ -5104,18 +5110,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
        device const uint16_t * q2 = xr->qs + 4 * ib;
        device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (int row = 0; row < nr0; row++) {
            const float db = dh[0];
            device const uint8_t * aux8 = (device const uint8_t *)q2;
            const uint32_t aux32 = q2[2] | (q2[3] << 16);
            const float d = db * (0.5f + (aux32 >> 28));
 
            float sum = 0;
-            for (int l = 0; l < 4; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 8; ++j) {
                    sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
@@ -5130,10 +5135,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
        }
    }
 }
@@ -5148,10 +5153,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq2_xs_f32_impl(
        args_t args,
        device const char * src0,
@@ -5167,7 +5172,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
    const int r1 = tgpig.y;
    const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
@@ -5179,7 +5184,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
    device const float * y = (device const float *) (src1 + offset1);
 
    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0] = {0.f};
 
    const int nb32 = nb * (QK_K / 32);
 
@@ -5200,8 +5205,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
    device const float * y4 = y + 32 * ix;
 
    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }
 
@@ -5213,8 +5217,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
        device const uint8_t  * sc = xr->scales + ib;
        device const half     * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (int row = 0; row < nr0; row++) {
            const float db = dh[0];
            const uint8_t ls1 = sc[0] & 0xf;
            const uint8_t ls2 = sc[0] >>  4;
@@ -5222,17 +5225,17 @@ void kernel_mul_mv_iq2_xs_f32_impl(
            const float d2 = db * (0.5f + ls2);
 
            float sum1 = 0, sum2 = 0;
-            for (int l = 0; l < 2; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                const uint8_t signs = ssigns[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
                    sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
-            for (int l = 2; l < 4; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                const uint8_t signs = ssigns[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
                    sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
@@ -5248,10 +5251,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
 
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
        }
    }
 }
@@ -5267,10 +5270,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template <typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq3_xxs_f32_impl(
        args_t args,
        device const char * src0,
@@ -5286,7 +5289,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
    const int r1 = tgpig.y;
    const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
@@ -5298,7 +5301,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
    device const float * y = (device const float *) (src1 + offset1);
 
    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0] = {0.f};
 
    const int nb32 = nb * (QK_K / 32);
 
@@ -5319,7 +5322,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
    device const float * y4 = y + 32 * ix;
 
    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-        for (int i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }
 
@@ -5331,17 +5334,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
        device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (int row = 0; row < nr0; row++) {
            const float db = dh[0];
            const uint32_t aux32 = gas[0] | (gas[1] << 16);
            const float d = db * (0.5f + (aux32 >> 28));
 
            float2 sum = {0};
-            for (int l = 0; l < 4; ++l) {
                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 4; ++j) {
                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
@@ -5358,10 +5361,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
 
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.5f;
+            dst_f32[first_row + row] = sum_all * 0.5f;
        }
    }
 }
@@ -5377,10 +5380,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5381
  }
5382
 
5383
- template<typename args_t>
5384
  void kernel_mul_mv_iq3_s_f32_impl(
5385
  args_t args,
5386
  device const char * src0,
@@ -5396,7 +5399,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
5396
  const int r1 = tgpig.y;
5397
  const int im = tgpig.z;
5398
 
5399
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5400
 
5401
  const uint i12 = im%args.ne12;
5402
  const uint i13 = im/args.ne12;
@@ -5408,7 +5411,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
5408
  device const float * y = (device const float *) (src1 + offset1);
5409
 
5410
  float yl[32];
5411
- float sumf[N_DST]={0.f}, all_sum;
5412
 
5413
  const int nb32 = nb * (QK_K / 32);
5414
 
@@ -5425,8 +5428,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
5425
  device const float * y4 = y + 32 * ix;
5426
 
5427
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5428
-
5429
- for (int i = 0; i < 32; ++i) {
5430
  yl[i] = y4[i];
5431
  }
5432
 
@@ -5440,18 +5442,17 @@ void kernel_mul_mv_iq3_s_f32_impl(
5440
  device const uint8_t * signs = xr->signs + 4 * ib;
5441
  device const half * dh = &xr->d;
5442
 
5443
- for (int row = 0; row < N_DST; row++) {
5444
-
5445
  const float db = dh[0];
5446
  const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
5447
 
5448
  float2 sum = {0};
5449
- for (int l = 0; l < 4; ++l) {
5450
  const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
5451
  const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
5452
  const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
5453
  const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
5454
- for (int j = 0; j < 4; ++j) {
5455
  sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
5456
  sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
5457
  }
@@ -5470,10 +5471,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
5470
 
5471
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5472
 
5473
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
5474
- all_sum = simd_sum(sumf[row]);
5475
  if (tiisg == 0) {
5476
- dst_f32[first_row + row] = all_sum;
5477
  }
5478
  }
5479
  }
@@ -5489,10 +5490,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
5489
  ushort tiisg[[thread_index_in_simdgroup]],
5490
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5491
 
5492
- kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5493
  }
5494
 
5495
- template <typename args_t>
5496
  void kernel_mul_mv_iq2_s_f32_impl(
5497
  args_t args,
5498
  device const char * src0,
@@ -5508,7 +5509,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
5508
  const int r1 = tgpig.y;
5509
  const int im = tgpig.z;
5510
 
5511
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5512
 
5513
  const uint i12 = im%args.ne12;
5514
  const uint i13 = im/args.ne12;
@@ -5520,7 +5521,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
5520
  device const float * y = (device const float *) (src1 + offset1);
5521
 
5522
  float yl[32];
5523
- float sumf[N_DST]={0.f}, all_sum;
5524
 
5525
  const int nb32 = nb * (QK_K / 32);
5526
 
@@ -5532,13 +5533,12 @@ void kernel_mul_mv_iq2_s_f32_impl(
5532
  // threadgroup_barrier(mem_flags::mem_threadgroup);
5533
  //}
5534
 
5535
- const int ix = tiisg;
5536
 
5537
  device const float * y4 = y + 32 * ix;
5538
 
5539
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5540
-
5541
- for (int i = 0; i < 32; ++i) {
5542
  yl[i] = y4[i];
5543
  }
5544
 
@@ -5552,19 +5552,18 @@ void kernel_mul_mv_iq2_s_f32_impl(
5552
  device const uint8_t * signs = qs + QK_K/8;
5553
  device const half * dh = &xr->d;
5554
 
5555
- for (int row = 0; row < N_DST; row++) {
5556
-
5557
  const float db = dh[0];
5558
  const float d1 = db * (0.5f + (sc[0] & 0xf));
5559
  const float d2 = db * (0.5f + (sc[0] >> 4));
5560
 
5561
  float2 sum = {0};
5562
- for (int l = 0; l < 2; ++l) {
5563
  //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
5564
  //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
5565
  constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
5566
  constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
5567
- for (int j = 0; j < 8; ++j) {
5568
  sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
5569
  sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
5570
  }
@@ -5583,10 +5582,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
5583
 
5584
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5585
 
5586
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
5587
- all_sum = simd_sum(sumf[row]);
5588
  if (tiisg == 0) {
5589
- dst_f32[first_row + row] = all_sum * 0.25f;
5590
  }
5591
  }
5592
  }
@@ -5602,10 +5601,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
5602
  ushort tiisg[[thread_index_in_simdgroup]],
5603
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5604
 
5605
- kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5606
  }
5607
 
5608
- template<typename args_t>
5609
  void kernel_mul_mv_iq1_s_f32_impl(
5610
  args_t args,
5611
  device const char * src0,
@@ -5621,7 +5620,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
5621
  const int r1 = tgpig.y;
5622
  const int im = tgpig.z;
5623
 
5624
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5625
 
5626
  const uint i12 = im%args.ne12;
5627
  const uint i13 = im/args.ne12;
@@ -5633,18 +5632,17 @@ void kernel_mul_mv_iq1_s_f32_impl(
5633
  device const float * y = (device const float *) (src1 + offset1);
5634
 
5635
  float yl[32];
5636
- float sumf[N_DST]={0.f}, all_sum;
5637
 
5638
  const int nb32 = nb * (QK_K / 32);
5639
 
5640
- const int ix = tiisg;
5641
 
5642
  device const float * y4 = y + 32 * ix;
5643
 
5644
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5645
-
5646
  float sumy = 0;
5647
- for (int i = 0; i < 32; ++i) {
5648
  yl[i] = y4[i];
5649
  sumy += yl[i];
5650
  }
@@ -5657,15 +5655,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
5657
  device const uint16_t * qh = xr->qh + ib;
5658
  device const half * dh = &xr->d;
5659
 
5660
- for (int row = 0; row < N_DST; row++) {
5661
-
5662
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5663
  constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
5664
  constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
5665
  constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
5666
 
5667
  float sum = 0;
5668
- for (int j = 0; j < 4; ++j) {
5669
  sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
5670
  + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
5671
  + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5683,15 +5680,28 @@ void kernel_mul_mv_iq1_s_f32_impl(
5683
 
5684
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5685
 
5686
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
5687
- all_sum = simd_sum(sumf[row]);
5688
  if (tiisg == 0) {
5689
- dst_f32[first_row + row] = all_sum;
5690
  }
5691
  }
5692
  }
5693
 
5694
- template <typename args_t>
5695
  void kernel_mul_mv_iq1_m_f32_impl(
5696
  args_t args,
5697
  device const char * src0,
@@ -5703,11 +5713,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
5703
  ushort sgitg) {
5704
 
5705
  const int nb = args.ne00/QK_K;
 
5706
  const int r0 = tgpig.x;
5707
  const int r1 = tgpig.y;
5708
  const int im = tgpig.z;
5709
 
5710
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5711
 
5712
  const uint i12 = im%args.ne12;
5713
  const uint i13 = im/args.ne12;
@@ -5719,20 +5730,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
5719
  device const float * y = (device const float *) (src1 + offset1);
5720
 
5721
  float yl[32];
5722
- float sumf[N_DST]={0.f}, all_sum;
5723
 
5724
  const int nb32 = nb * (QK_K / 32);
5725
 
5726
- const int ix = tiisg;
5727
 
5728
  device const float * y4 = y + 32 * ix;
5729
 
5730
  iq1m_scale_t scale;
5731
 
5732
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5733
-
5734
  float4 sumy = {0.f};
5735
- for (int i = 0; i < 8; ++i) {
5736
  yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
5737
  yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
5738
  yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
@@ -5747,7 +5757,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5747
  device const uint8_t * qh = xr->qh + 2 * ib;
5748
  device const uint16_t * sc = (device const uint16_t *)xr->scales;
5749
 
5750
- for (int row = 0; row < N_DST; row++) {
5751
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5752
 
5753
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -5756,7 +5766,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5756
  constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
5757
 
5758
  float2 sum = {0.f};
5759
- for (int j = 0; j < 4; ++j) {
5760
  sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
5761
  + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
5762
  sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5778,15 +5788,28 @@ void kernel_mul_mv_iq1_m_f32_impl(
5778
 
5779
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5780
 
5781
- for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
5782
- all_sum = simd_sum(sumf[row]);
5783
  if (tiisg == 0) {
5784
- dst_f32[first_row + row] = all_sum;
5785
  }
5786
  }
5787
  }
5788
 
5789
- template<typename args_t>
5790
  void kernel_mul_mv_iq4_nl_f32_impl(
5791
  args_t args,
5792
  device const char * src0,
@@ -5799,10 +5822,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5799
 
5800
  threadgroup float * shmem_f32 = (threadgroup float *) shmem;
5801
  const int nb = args.ne00/QK4_NL;
 
5802
  const int r0 = tgpig.x;
5803
  const int r1 = tgpig.y;
5804
  const int im = tgpig.z;
5805
- const int first_row = (r0 * 2 + sgitg) * 2;
 
5806
 
5807
  const uint i12 = im%args.ne12;
5808
  const uint i13 = im/args.ne12;
@@ -5813,14 +5838,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5813
  device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
5814
  device const float * y = (device const float *) (src1 + offset1);
5815
 
5816
- const int ix = tiisg/2; // 0...15
5817
- const int it = tiisg%2; // 0 or 1
5818
 
5819
  shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
5820
  threadgroup_barrier(mem_flags::mem_threadgroup);
5821
 
5822
  float4 yl[4];
5823
- float sumf[2]={0.f}, all_sum;
5824
 
5825
  device const float * yb = y + ix * QK4_NL + it * 8;
5826
 
@@ -5830,12 +5855,13 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5830
  float4 qf1, qf2;
5831
 
5832
  for (int ib = ix; ib < nb; ib += 16) {
5833
-
5834
  device const float4 * y4 = (device const float4 *)yb;
5835
- yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
5836
-
5837
- for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
5839
  device const block_iq4_nl & xb = x[row*nb + ib];
5840
  device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
5841
 
@@ -5860,7 +5886,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5860
  acc1 += acc2;
5861
 
5862
  sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
5863
-
5864
  }
5865
 
5866
  yb += 16 * QK4_NL;
@@ -5868,15 +5893,29 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5868
 
5869
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5870
 
5871
- for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
5872
- all_sum = simd_sum(sumf[row]);
5873
  if (tiisg == 0) {
5874
- dst_f32[first_row + row] = all_sum;
5875
  }
5876
  }
5877
  }
5878
 
5879
- template<typename args_t>
5880
  void kernel_mul_mv_iq4_xs_f32_impl(
5881
  args_t args,
5882
  device const char * src0,
@@ -5892,7 +5931,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5892
  const int r0 = tgpig.x;
5893
  const int r1 = tgpig.y;
5894
  const int im = tgpig.z;
5895
- const int first_row = (r0 * 2 + sgitg) * 2;
5896
 
5897
  const uint i12 = im%args.ne12;
5898
  const uint i13 = im/args.ne12;
@@ -5903,16 +5942,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5903
  device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
5904
  device const float * y = (device const float *) (src1 + offset1);
5905
 
5906
- const int ix = tiisg/16; // 0 or 1
5907
- const int it = tiisg%16; // 0...15
5908
- const int ib = it/2;
5909
- const int il = it%2;
5910
 
5911
  shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
5912
  threadgroup_barrier(mem_flags::mem_threadgroup);
5913
 
5914
  float4 yl[4];
5915
- float sumf[2]={0.f}, all_sum;
5916
 
5917
  device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
5918
 
@@ -5923,9 +5962,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5923
 
5924
  for (int ibl = ix; ibl < nb; ibl += 2) {
5925
  device const float4 * y4 = (device const float4 *)yb;
5926
- yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
5927

5928
- for (int row = 0; row < 2; ++row) {
5929
  device const block_iq4_xs & xb = x[row*nb + ibl];
5930
  device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
5931
 
@@ -5949,7 +5991,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5949
 
5950
  const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
5951
  sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
5952
-
5953
  }
5954
 
5955
  yb += 2 * QK_K;
@@ -5957,54 +5998,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5957
 
5958
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5959
 
5960
- for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
5961
- all_sum = simd_sum(sumf[row]);
5962
  if (tiisg == 0) {
5963
- dst_f32[first_row + row] = all_sum;
5964
  }
5965
  }
5966
  }
5967
 
5968
- [[host_name("kernel_mul_mv_iq1_s_f32")]]
5969
- kernel void kernel_mul_mv_iq1_s_f32(
5970
- constant ggml_metal_kargs_mul_mv & args,
5971
- device const char * src0,
5972
- device const char * src1,
5973
- device char * dst,
5974
- uint3 tgpig[[threadgroup_position_in_grid]],
5975
- ushort tiisg[[thread_index_in_simdgroup]],
5976
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5977
-
5978
- kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
5979
- }
5980
-
5981
- [[host_name("kernel_mul_mv_iq1_m_f32")]]
5982
- kernel void kernel_mul_mv_iq1_m_f32(
5983
- constant ggml_metal_kargs_mul_mv & args,
5984
- device const char * src0,
5985
- device const char * src1,
5986
- device char * dst,
5987
- uint3 tgpig[[threadgroup_position_in_grid]],
5988
- ushort tiisg[[thread_index_in_simdgroup]],
5989
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5990
-
5991
- kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
5992
- }
5993
-
5994
- [[host_name("kernel_mul_mv_iq4_nl_f32")]]
5995
- kernel void kernel_mul_mv_iq4_nl_f32(
5996
- constant ggml_metal_kargs_mul_mv & args,
5997
- device const char * src0,
5998
- device const char * src1,
5999
- device char * dst,
6000
- threadgroup char * shmem [[threadgroup(0)]],
6001
- uint3 tgpig[[threadgroup_position_in_grid]],
6002
- ushort tiisg[[thread_index_in_simdgroup]],
6003
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6004
-
6005
- kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
6006
- }
6007
-
6008
  [[host_name("kernel_mul_mv_iq4_xs_f32")]]
6009
  kernel void kernel_mul_mv_iq4_xs_f32(
6010
  constant ggml_metal_kargs_mul_mv & args,
@@ -6016,7 +6017,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
6016
  ushort tiisg[[thread_index_in_simdgroup]],
6017
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6018
 
6019
- kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
6020
  }
6021
 
6022
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6660,25 +6661,27 @@ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t
6660
  #if defined(GGML_METAL_USE_BF16)
6661
  template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
6662
  #endif
6663
- template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
6664
- template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6665
- template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6666
- template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6667
- template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6668
- template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
6669
- template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
6670
- template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
6671
- template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
6672
- template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
6673
- template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
6674
- template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
6675
- template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
6676
- template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
6677
- template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
6678
- template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
6679
- template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
6680
- template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
6681
- template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
6682
 
6683
  kernel void kernel_pool_2d_max_f32(
6684
  device const float * src0,
 
1439
 
1440
  float4 sa_vec(0.0);
1441
 
1442
+ for (uint j = 0; j < head_size; j += 4) {
1443
  float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
1444
  float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
1445
  sa_vec += a_vec * s_vec;
 
1853
  return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
1854
  }
1855
 
1856
+ template<typename block_q_type, int nr0, int nsg, int nw, typename args_t>
1857
  void mul_vec_q_n_f32_impl(
1858
  args_t args,
1859
  device const char * src0,
 
1869
  const int r1 = tgpig.y;
1870
  const int im = tgpig.z;
1871
 
1872
+ const int first_row = (r0 * nsg + sgitg) * nr0;
1873
 
1874
  const uint i12 = im%args.ne12;
1875
  const uint i13 = im/args.ne12;
 
1881
  device const float * y = (device const float *) (src1 + offset1);
1882
 
1883
  // pointers to src0 rows
1884
+ device const block_q_type * ax[nr0];
1885
+ for (int row = 0; row < nr0; ++row) {
1886
  const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
1887
 
1888
  ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
1889
  }
1890
 
1891
  float yl[16]; // src1 vector cache
1892
+ float sumf[nr0] = {0.f};
1893
 
1894
  const short ix = (tiisg/2);
1895
  const short il = (tiisg%2)*8;
 
1901
  float sumy[2] = { 0.f, 0.f };
1902
 
1903
  #pragma unroll
1904
+ for (short i = 0; i < 8; i += 2) {
1905
  sumy[0] += yb[i + 0] + yb[i + 1];
1906
  yl[i + 0] = yb[i + 0];
1907
  yl[i + 1] = yb[i + 1]/256.f;
 
1912
  }
1913
 
1914
  #pragma unroll
1915
+ for (short row = 0; row < nr0; row++) {
1916
  sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
1917
  }
1918
 
 
1921
 
1922
  device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
1923
 
1924
+ for (int row = 0; row < nr0; ++row) {
1925
  const float tot = simd_sum(sumf[row]);
1926
 
1927
  if (tiisg == 0 && first_row + row < args.ne01) {
 
1938
  uint3 tgpig[[threadgroup_position_in_grid]],
1939
  ushort tiisg[[thread_index_in_simdgroup]],
1940
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1941
+ mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
1942
  }
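
The template parameters make the row mapping explicit: a simdgroup owns nr0 consecutive rows, a threadgroup holds nsg simdgroups, hence first_row = (r0*nsg + sgitg)*nr0. A worked example, plus the ceil-division a host needs to cover all rows (illustrative sketch with assumed names; the real dispatch lives in ggml-metal.m):

    // With N_R0_Q4_0 = 4 and N_SG_Q4_0 = 2, one threadgroup covers 4*2 = 8 rows:
    //   tgpig.x = 3, sgitg = 0  ->  first_row = (3*2 + 0)*4 = 24   (rows 24..27)
    //   tgpig.x = 3, sgitg = 1  ->  first_row = (3*2 + 1)*4 = 28   (rows 28..31)
    static inline int mv_ntg0(int ne01, int nr0, int nsg) {
        return (ne01 + nr0*nsg - 1) / (nr0*nsg); // ceil(ne01 / rows-per-threadgroup)
    }

The bounds check `first_row + row < args.ne01` in the write-out loop handles the last, partially filled threadgroup.
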
1943
 
1944
  kernel void kernel_mul_mv_q4_1_f32(
 
1949
  uint3 tgpig[[threadgroup_position_in_grid]],
1950
  ushort tiisg[[thread_index_in_simdgroup]],
1951
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1952
+ mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
1953
  }
1954
 
1955
  kernel void kernel_mul_mv_q5_0_f32(
 
1960
  uint3 tgpig[[threadgroup_position_in_grid]],
1961
  ushort tiisg[[thread_index_in_simdgroup]],
1962
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1963
+ mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
1964
  }
1965
 
1966
  kernel void kernel_mul_mv_q5_1_f32(
 
1971
  uint3 tgpig[[threadgroup_position_in_grid]],
1972
  ushort tiisg[[thread_index_in_simdgroup]],
1973
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1974
+ mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
1975
  }
1976
 
1977
  #define NB_Q8_0 8
1978
 
1979
+ template<int nr0, int nsg, int nw, typename args_t>
1980
  void kernel_mul_mv_q8_0_f32_impl(
1981
  args_t args,
1982
  device const char * src0,
 
1986
  uint3 tgpig,
1987
  ushort tiisg,
1988
  ushort sgitg) {
1989
  const int nb = args.ne00/QK8_0;
1990
+
1991
  const int r0 = tgpig.x;
1992
  const int r1 = tgpig.y;
1993
  const int im = tgpig.z;
1994
 
1995
+ const int first_row = (r0 * nsg + sgitg) * nr0;
1996
 
1997
  const uint i12 = im%args.ne12;
1998
  const uint i13 = im/args.ne12;
 
2004
  device const float * y = (device const float *) (src1 + offset1);
2005
 
2006
  // pointers to src0 rows
2007
+ device const block_q8_0 * ax[nr0];
2008
+ for (int row = 0; row < nr0; ++row) {
2009
  const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
2010
 
2011
  ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
2012
  }
2013
 
2014
  float yl[NB_Q8_0];
2015
+ float sumf[nr0] = { 0.f };
2016
 
2017
  const short ix = tiisg/4;
2018
  const short il = tiisg%4;
 
2025
  yl[i] = yb[i];
2026
  }
2027
 
2028
+ for (short row = 0; row < nr0; row++) {
2029
  device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
2030
  float sumq = 0.f;
2031
  for (short iq = 0; iq < NB_Q8_0; ++iq) {
 
2039
 
2040
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
2041
 
2042
+ for (int row = 0; row < nr0; ++row) {
2043
  const float tot = simd_sum(sumf[row]);
2044
 
2045
  if (tiisg == 0 && first_row + row < args.ne01) {
 
2057
  uint3 tgpig[[threadgroup_position_in_grid]],
2058
  ushort tiisg[[thread_index_in_simdgroup]],
2059
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
2060
+ kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
2061
  }
2062
 
2063
  // mat-vec kernel processing in chunks of float4
 
2394
  sumf += (T0) x[i] * (T1) y[i];
2395
  }
2396
 
2397
+ float sum_all = simd_sum(sumf);
2398
  if (tiisg == 0) {
2399
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
2400
  }
2401
  }
2402
  } else {
 
2417
  sumf += dot((float4) x4[i], (float4) y4[i]);
2418
  }
2419
 
2420
+ float sum_all = simd_sum(sumf);
2421
  if (tiisg == 0) {
2422
+ for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
2423
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
2424
  }
2425
  }
2426
  }
 
2482
  for (int i = tiisg; i < args.ne00; i += 32) {
2483
  sumf += (float) x[i] * (float) y[i];
2484
  }
2485
+ float sum_all = simd_sum(sumf);
2486
  if (tiisg == 0) {
2487
+ dst_f32[r0] = sum_all;
2488
  }
2489
  } else {
2490
  device const T4 * x4 = (device const T4 *) x;
 
2494
  sumf += dot((float4) x4[i], y4[i]);
2495
  }
2496
 
2497
+ float sum_all = simd_sum(sumf);
2498
 
2499
  if (tiisg == 0) {
2500
+ for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
2501
+ dst_f32[r0] = sum_all;
2502
  }
2503
  }
2504
  }
 
2543
  sumf += dot((float4) x4[i], y4[i]);
2544
  }
2545
 
2546
+ float sum_all = simd_sum(sumf);
2547
  if (tiisg == 0) {
2548
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
2549
  }
2550
  }
2551
  }
 
4311
  float amax = 0.0f; // absolute max
4312
  float max = 0.0f;
4313
 
4314
+ for (int j = 0; j < QK4_NL; j++) {
4315
  const float v = src[j];
4316
  if (amax < fabs(v)) {
4317
  amax = fabs(v);
 
4419
  }
4420
  }
4421
 
4422
+ template<int nr0, int nsg, int nw, typename args_t>
4423
  void kernel_mul_mv_q2_K_f32_impl(
4424
  args_t args,
4425
  device const char * src0,
 
4435
  const int r1 = tgpig.y;
4436
  const int im = tgpig.z;
4437
 
4438
+ const int first_row = (r0 * nsg + sgitg) * nr0;
4439
 
4440
  const uint i12 = im%args.ne12;
4441
  const uint i13 = im/args.ne12;
 
4447
  device const float * y = (device const float *) (src1 + offset1);
4448
 
4449
  float yl[32];
4450
+ float sumf[nr0]={0.f};
4451
 
4452
+ const short ix = tiisg/8; // 0...3
4453
+ const short it = tiisg%8; // 0...7
4454
+ const short iq = it/4; // 0 or 1
4455
+ const short ir = it%4; // 0...3
4456
+ const short is = (8*ir)/16;// 0 or 1
4457
 
4458
  device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
4459
 
4460
  for (int ib = ix; ib < nb; ib += 4) {
 
4461
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
4462
+ for (short i = 0; i < 8; ++i) {
4463
  yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
4464
  yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
4465
  yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
 
4470
  device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
4471
  device const half * dh = &x[ib].d;
4472
 
4473
+ for (short row = 0; row < nr0; row++) {
4474
  float4 acc1 = {0.f, 0.f, 0.f, 0.f};
4475
  float4 acc2 = {0.f, 0.f, 0.f, 0.f};
4476
  for (int i = 0; i < 8; i += 2) {
 
4501
 
4502
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
4503
 
4504
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
4505
+ float sum_all = simd_sum(sumf[row]);
4506
  if (tiisg == 0) {
4507
+ dst_f32[first_row + row] = sum_all;
4508
  }
4509
  }
4510
  }
 
4519
  ushort tiisg[[thread_index_in_simdgroup]],
4520
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
4521
 
4522
+ kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
4523
  }
4524
 
4525
+ template<int nr0, int nsg, int nw, typename args_t>
4526
  void kernel_mul_mv_q3_K_f32_impl(
4527
  args_t args,
4528
  device const char * src0,
 
4539
  const int r1 = tgpig.y;
4540
  const int im = tgpig.z;
4541
 
4542
+ const int first_row = (r0 * nsg + sgitg) * nr0;
4543
 
4544
  const uint i12 = im%args.ne12;
4545
  const uint i13 = im/args.ne12;
 
4555
  //const uint16_t kmask1 = 0x3030;
4556
  //const uint16_t kmask2 = 0x0f0f;
4557
 
4558
+ const short tid = tiisg/4;
4559
+ const short ix = tiisg%4;
4560
+ const short ip = tid/4; // 0 or 1
4561
+ const short il = 2*((tid%4)/2); // 0 or 2
4562
+ const short ir = tid%2;
4563
+ const short l0 = 8*ir;
 
4564
 
4565
  // One would think that the Metal compiler would figure out that ip and il can only have
4566
  // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
 
4585
  const uint16_t s_shift1 = 4*ip;
4586
  const uint16_t s_shift2 = s_shift1 + il;
4587
 
4588
+ const short q_offset = 32*ip + l0;
4589
+ const short y_offset = 128*ip + 32*il + l0;
4590
 
4591
  device const float * y1 = yy + ix*QK_K + y_offset;
4592
 
 
4594
  thread uint16_t * scales16 = (thread uint16_t *)&scales32;
4595
  thread const int8_t * scales = (thread const int8_t *)&scales32;
4596
 
4597
+ float sumf1[nr0] = {0.f};
4598
+ float sumf2[nr0] = {0.f};
4599
+
4600
  for (int i = ix; i < nb; i += 4) {
4601
+ for (short l = 0; l < 8; ++l) {
4602
  yl[l+ 0] = y1[l+ 0];
4603
  yl[l+ 8] = y1[l+16];
4604
  yl[l+16] = y1[l+32];
 
4610
  device const uint16_t * a = (device const uint16_t *)(x[i].scales);
4611
  device const half * dh = &x[i].d;
4612
 
4613
+ for (short row = 0; row < nr0; ++row) {
4614
  const float d_all = (float)dh[0];
4615
 
4616
  scales16[0] = a[4];
 
4621
  scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
4622
 
4623
  float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
4624
+ for (short l = 0; l < 8; l += 2) {
4625
  const int32_t qs = q[l/2];
4626
  s1 += yl[l+0] * (qs & qm[il/2][0]);
4627
  s2 += yl[l+1] * (qs & qm[il/2][1]);
 
4636
  sumf2[row] += d2 * (scales[2] - 32);
4637
 
4638
  s1 = s2 = s3 = s4 = s5 = s6 = 0;
4639
+ for (short l = 0; l < 8; l += 2) {
4640
  const int32_t qs = q[l/2+8];
4641
  s1 += yl[l+8] * (qs & qm[il/2][0]);
4642
  s2 += yl[l+9] * (qs & qm[il/2][1]);
 
4659
  y1 += 4 * QK_K;
4660
  }
4661
 
4662
+ for (int row = 0; row < nr0; ++row) {
4663
  const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
4664
  sumf1[row] = simd_sum(sumf);
4665
  }
 
4667
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
4668
 
4669
  if (tiisg == 0) {
4670
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
4671
  dst_f32[first_row + row] = sumf1[row];
4672
  }
4673
  }
 
4683
  ushort tiisg[[thread_index_in_simdgroup]],
4684
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
4685
 
4686
+ kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
4687
  }
4688
 
4689
+ template<int nr0, int nsg, int nw, typename args_t>
4690
  void kernel_mul_mv_q4_K_f32_impl(
4691
  args_t args,
4692
  device const char * src0,
 
4696
  uint3 tgpig,
4697
  ushort tiisg,
4698
  ushort sgitg) {
 
4699
  const uint16_t kmask1 = 0x3f3f;
4700
  const uint16_t kmask2 = 0x0f0f;
4701
  const uint16_t kmask3 = 0xc0c0;
4702
 
4703
+ const short ix = tiisg/8; // 0...3
4704
+ const short it = tiisg%8; // 0...7
4705
+ const short iq = it/4; // 0 or 1
4706
+ const short ir = it%4; // 0...3
4707
 
4708
  const int nb = args.ne00/QK_K;
4709
+
4710
  const int r0 = tgpig.x;
4711
  const int r1 = tgpig.y;
4712
  const int im = tgpig.z;
4713
+
4714
+ const int first_row = (r0 * nsg + sgitg) * nr0;
4715
 
4716
  const uint i12 = im%args.ne12;
4717
  const uint i13 = im/args.ne12;
 
4724
 
4725
  float yl[16];
4726
  float yh[16];
4727
+
4728
+ float sumf[nr0]={0.f};
4729
 
4730
  device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
4731
 
 
4734
 
4735
  for (int ib = ix; ib < nb; ib += 4) {
4736
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
4737
+
4738
+ for (short i = 0; i < 8; ++i) {
4739
  yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
4740
  yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
4741
  yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
 
4746
  device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
4747
  device const half * dh = &x[ib].d;
4748
 
4749
+ for (short row = 0; row < nr0; row++) {
4750
  sc16[0] = sc[0] & kmask1;
4751
  sc16[1] = sc[2] & kmask1;
4752
  sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
 
4756
 
4757
  float4 acc1 = {0.f, 0.f, 0.f, 0.f};
4758
  float4 acc2 = {0.f, 0.f, 0.f, 0.f};
4759
+
4760
+ for (short i = 0; i < 4; ++i) {
4761
+ acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
4762
+ acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
4763
+ acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
4764
+ acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
4765
+ acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
4766
+ acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
4767
+ acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
4768
+ acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
4769
  }
4770
 
4771
  float dall = dh[0];
4772
  float dmin = dh[1];
4773
+
4774
  sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
4775
  (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
4776
  (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
 
4787
 
4788
  device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
4789
 
4790
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
4791
+ float sum_all = simd_sum(sumf[row]);
4792
  if (tiisg == 0) {
4793
+ dst_f32[first_row + row] = sum_all;
4794
  }
4795
  }
4796
  }
 
4805
  ushort tiisg[[thread_index_in_simdgroup]],
4806
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
4807
 
4808
+ kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
4809
  }
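
The q4_K accumulation above avoids per-element shifts: the masks 0x000F/0x0F00 keep a nibble of the low or high byte in place, and the shift is paid once by scaling the accumulated sum (1/256 for the high byte, 1/16 for the 0x00F0 nibble). A scalar sketch of the idea (hypothetical helper, not in the patch):

    // Dot two packed 4-bit values against y without shifting in the hot loop;
    // the deferred 1/256 reproduces `(acc1[0] + 1.f/256.f * acc1[1]) * sc8[0]`.
    static inline float nibble_pair_dot(uint16_t q, float y_lo, float y_hi) {
        const float lo = y_lo * (q & 0x000F);  // bits 0..3, already in place
        const float hi = y_hi * (q & 0x0F00);  // bits 8..11, still shifted by 8
        return lo + hi * (1.0f/256.0f);        // one multiply replaces ">> 8"
    }
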
4810
 
4811
+ template<int nr0, int nsg, int nw, typename args_t>
4812
  void kernel_mul_mv_q5_K_f32_impl(
4813
  args_t args,
4814
  device const char * src0,
 
4825
  const int r1 = tgpig.y;
4826
  const int im = tgpig.z;
4827
 
4828
+ const int first_row = (r0 * nsg + sgitg) * nr0;
4829
 
4830
  const uint i12 = im%args.ne12;
4831
  const uint i13 = im/args.ne12;
 
4836
  device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
4837
  device const float * yy = (device const float *) (src1 + offset1);
4838
 
4839
+ float sumf[nr0]={0.f};
4840
 
4841
  float yl[16], yh[16];
4842
 
 
4844
  const uint16_t kmask2 = 0x0f0f;
4845
  const uint16_t kmask3 = 0xc0c0;
4846
 
4847
+ const short tid = tiisg/4;
4848
+ const short ix = tiisg%4;
4849
+ const short iq = tid/4;
4850
+ const short ir = tid%4;
 
4851
 
4852
+ const short l0 = 8*ir;
4853
+ const short q_offset = 32*iq + l0;
4854
+ const short y_offset = 64*iq + l0;
4855
 
4856
  const uint8_t hm1 = 1u << (2*iq);
4857
  const uint8_t hm2 = hm1 << 1;
 
4871
 
4872
  device const float * y2 = y1 + 128;
4873
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
4874
+ for (short l = 0; l < 8; ++l) {
4875
  yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
4876
  yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
4877
  yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
4878
  yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
4879
  }
4880
 
4881
+ for (short row = 0; row < nr0; ++row) {
4882
  device const uint8_t * q2 = q1 + 64;
4883
 
4884
  sc16[0] = a[0] & kmask1;
 
4888
 
4889
  float4 acc1 = {0.f};
4890
  float4 acc2 = {0.f};
4891
+ for (short l = 0; l < 8; ++l) {
4892
  uint8_t h = qh[l];
4893
  acc1[0] += yl[l+0] * (q1[l] & 0x0F);
4894
  acc1[1] += yl[l+8] * (q1[l] & 0xF0);
 
4918
 
4919
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
4920
 
4921
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
4922
  const float tot = simd_sum(sumf[row]);
4923
  if (tiisg == 0) {
4924
  dst_f32[first_row + row] = tot;
 
4936
  ushort tiisg[[thread_index_in_simdgroup]],
4937
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
4938
 
4939
+ kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
4940
  }
4941
 
4942
+ template<int nr0, int nsg, int nw, typename args_t>
4943
  void kernel_mul_mv_q6_K_f32_impl(
4944
  args_t args,
4945
  device const char * src0,
 
4961
  const int r1 = tgpig.y;
4962
  const int im = tgpig.z;
4963
 
4964
+ const int first_row = (r0 * nsg + sgitg) * nr0;
4965
 
4966
  const uint i12 = im%args.ne12;
4967
  const uint i13 = im/args.ne12;
4968
 
4969
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4970
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
4971
 
4972
  device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
4973
  device const float * yy = (device const float *) (src1 + offset1);
4974
 
4975
+ float sumf[nr0] = { 0.f };
4976
+
4977
+ float yl[16];
4978
 
4979
+ const short tid = tiisg/2;
4980
+ const short ix = tiisg%2;
4981
+ const short ip = tid/8; // 0 or 1
4982
+ const short il = tid%8;
4983
+ const short l0 = 4*il;
4984
+ const short is = 8*ip + l0/16;
 
4985
 
4986
+ const short y_offset = 128*ip + l0;
4987
+ const short q_offset_l = 64*ip + l0;
4988
+ const short q_offset_h = 32*ip + l0;
4989
 
4990
  for (int i = ix; i < nb; i += 2) {
4991
  device const uint8_t * q1 = x[i].ql + q_offset_l;
4992
  device const uint8_t * q2 = q1 + 32;
4993
  device const uint8_t * qh = x[i].qh + q_offset_h;
4994
  device const int8_t * sc = x[i].scales + is;
4995
+ device const half * dh = &x[i].d;
4996
 
4997
  device const float * y = yy + i * QK_K + y_offset;
4998
 
4999
+ for (short l = 0; l < 4; ++l) {
5000
+ yl[4*l + 0] = y[l + 0];
5001
+ yl[4*l + 1] = y[l + 32];
5002
+ yl[4*l + 2] = y[l + 64];
5003
+ yl[4*l + 3] = y[l + 96];
5004
  }
5005
 
5006
+ for (short row = 0; row < nr0; ++row) {
5007
+ const float dall = dh[0];
5008
+
5009
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
5010
 
5011
+ for (short l = 0; l < 4; ++l) {
5012
+ sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
5013
+ sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
5014
+ sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
5015
+ sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
5016
+ }
5017
+
5018
+ sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
5019
+
5020
+ q1 += args.nb01;
5021
+ q2 += args.nb01;
5022
+ qh += args.nb01;
5023
+ sc += args.nb01;
5024
+ dh += args.nb01/2;
5025
+ }
5026
  }
5027
 
5028
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5029
 
5030
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5031
+ float sum_all = simd_sum(sumf[row]);
5032
+ if (tiisg == 0) {
5033
+ dst_f32[first_row + row] = sum_all;
5034
+ }
5035
  }
5036
  }
5037
 
 
5045
  ushort tiisg[[thread_index_in_simdgroup]],
5046
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5047
 
5048
+ kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
5049
  }
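
Note how the q6_K row loop above advances to the next src0 row purely by pointer arithmetic: q1/q2/qh/sc are byte-wide pointers and step by args.nb01 (the row stride in bytes), while dh points at half, so it steps nb01/2 elements to cover the same nb01 bytes. A generic form of that walk (illustrative helper, not in the patch):

    // Advance a typed device pointer by a row stride given in bytes: for
    // uint8_t data this is `p += nb01`, for half it degenerates to `p += nb01/2`.
    template<typename T>
    static inline device const T * next_row(device const T * p, ulong nb01) {
        return (device const T *) ((device const char *) p + nb01);
    }
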
5050
 
5051
  // ======================= "True" 2-bit
5052
 
5053
+ template<int nr0, int nsg, int nw, typename args_t>
5054
  void kernel_mul_mv_iq2_xxs_f32_impl(
5055
  args_t args,
5056
  device const char * src0,
 
5066
  const int r1 = tgpig.y;
5067
  const int im = tgpig.z;
5068
 
5069
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5070
 
5071
  const uint i12 = im%args.ne12;
5072
  const uint i13 = im/args.ne12;
 
5078
  device const float * y = (device const float *) (src1 + offset1);
5079
 
5080
  float yl[32];
5081
+ float sumf[nr0]={0.f};
5082
 
5083
  const int nb32 = nb * (QK_K / 32);
5084
 
 
5099
  device const float * y4 = y + 32 * ix;
5100
 
5101
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5102
+ for (short i = 0; i < 32; ++i) {
5103
  yl[i] = y4[i];
5104
  }
5105
 
 
5110
  device const uint16_t * q2 = xr->qs + 4 * ib;
5111
  device const half * dh = &xr->d;
5112
 
5113
+ for (short row = 0; row < nr0; row++) {
5114
  const float db = dh[0];
5115
  device const uint8_t * aux8 = (device const uint8_t *)q2;
5116
  const uint32_t aux32 = q2[2] | (q2[3] << 16);
5117
  const float d = db * (0.5f + (aux32 >> 28));
5118
 
5119
  float sum = 0;
5120
+ for (short l = 0; l < 4; ++l) {
5121
  const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
5122
  const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
5123
+ for (short j = 0; j < 8; ++j) {
5124
  sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
5125
  }
5126
  }
 
5135
 
5136
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5137
 
5138
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5139
+ float sum_all = simd_sum(sumf[row]);
5140
  if (tiisg == 0) {
5141
+ dst_f32[first_row + row] = sum_all * 0.25f;
5142
  }
5143
  }
5144
  }
 
5153
  uint3 tgpig[[threadgroup_position_in_grid]],
5154
  ushort tiisg[[thread_index_in_simdgroup]],
5155
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5156
+ kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5157
  }
5158
 
5159
+ template<int nr0, int nsg, int nw, typename args_t>
5160
  void kernel_mul_mv_iq2_xs_f32_impl(
5161
  args_t args,
5162
  device const char * src0,
 
5172
  const int r1 = tgpig.y;
5173
  const int im = tgpig.z;
5174
 
5175
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5176
 
5177
  const uint i12 = im%args.ne12;
5178
  const uint i13 = im/args.ne12;
 
5184
  device const float * y = (device const float *) (src1 + offset1);
5185
 
5186
  float yl[32];
5187
+ float sumf[nr0]={0.f};
5188
 
5189
  const int nb32 = nb * (QK_K / 32);
5190
 
 
5205
  device const float * y4 = y + 32 * ix;
5206
 
5207
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5208
+ for (short i = 0; i < 32; ++i) {
 
5209
  yl[i] = y4[i];
5210
  }
5211
 
 
5217
  device const uint8_t * sc = xr->scales + ib;
5218
  device const half * dh = &xr->d;
5219
 
5220
+ for (short row = 0; row < nr0; row++) {
 
5221
  const float db = dh[0];
5222
  const uint8_t ls1 = sc[0] & 0xf;
5223
  const uint8_t ls2 = sc[0] >> 4;
 
5225
  const float d2 = db * (0.5f + ls2);
5226
 
5227
  float sum1 = 0, sum2 = 0;
5228
+ for (short l = 0; l < 2; ++l) {
5229
  const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
5230
  const uint8_t signs = ssigns[(q2[l] >> 9)];
5231
+ for (short j = 0; j < 8; ++j) {
5232
  sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
5233
  }
5234
  }
5235
+ for (short l = 2; l < 4; ++l) {
5236
  const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
5237
  const uint8_t signs = ssigns[(q2[l] >> 9)];
5238
+ for (short j = 0; j < 8; ++j) {
5239
  sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
5240
  }
5241
  }
 
5251
 
5252
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5253
 
5254
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5255
+ float sum_all = simd_sum(sumf[row]);
5256
  if (tiisg == 0) {
5257
+ dst_f32[first_row + row] = sum_all * 0.25f;
5258
  }
5259
  }
5260
  }
 
5270
  ushort tiisg[[thread_index_in_simdgroup]],
5271
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5272
 
5273
+ kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5274
  }
5275
 
5276
+ template<int nr0, int nsg, int nw, typename args_t>
5277
  void kernel_mul_mv_iq3_xxs_f32_impl(
5278
  args_t args,
5279
  device const char * src0,
 
5289
  const int r1 = tgpig.y;
5290
  const int im = tgpig.z;
5291
 
5292
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5293
 
5294
  const uint i12 = im%args.ne12;
5295
  const uint i13 = im/args.ne12;
 
5301
  device const float * y = (device const float *) (src1 + offset1);
5302
 
5303
  float yl[32];
5304
+ float sumf[nr0]={0.f};
5305
 
5306
  const int nb32 = nb * (QK_K / 32);
5307
 
 
5322
  device const float * y4 = y + 32 * ix;
5323
 
5324
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5325
+ for (short i = 0; i < 32; ++i) {
5326
  yl[i] = y4[i];
5327
  }
5328
 
 
5334
  device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
5335
  device const half * dh = &xr->d;
5336
 
5337
+ for (short row = 0; row < nr0; row++) {
5338
  const float db = dh[0];
5339
  const uint32_t aux32 = gas[0] | (gas[1] << 16);
5340
  const float d = db * (0.5f + (aux32 >> 28));
5341
 
5342
  float2 sum = {0};
5343
+ for (short l = 0; l < 4; ++l) {
5344
  const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
5345
  const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
5346
  const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
5347
+ for (short j = 0; j < 4; ++j) {
5348
  sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
5349
  sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
5350
  }
 
5361
 
5362
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5363
 
5364
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5365
+ float sum_all = simd_sum(sumf[row]);
5366
  if (tiisg == 0) {
5367
+ dst_f32[first_row + row] = sum_all * 0.5f;
5368
  }
5369
  }
5370
  }
 
5380
  ushort tiisg[[thread_index_in_simdgroup]],
5381
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5382
 
5383
+ kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5384
  }
5385
 
5386
+ template<int nr0, int nsg, int nw, typename args_t>
5387
  void kernel_mul_mv_iq3_s_f32_impl(
5388
  args_t args,
5389
  device const char * src0,
 
5399
  const int r1 = tgpig.y;
5400
  const int im = tgpig.z;
5401
 
5402
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5403
 
5404
  const uint i12 = im%args.ne12;
5405
  const uint i13 = im/args.ne12;
 
5411
  device const float * y = (device const float *) (src1 + offset1);
5412
 
5413
  float yl[32];
5414
+ float sumf[nr0]={0.f};
5415
 
5416
  const int nb32 = nb * (QK_K / 32);
5417
 
 
5428
  device const float * y4 = y + 32 * ix;
5429
 
5430
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5431
+ for (short i = 0; i < 32; ++i) {
 
5432
  yl[i] = y4[i];
5433
  }
5434
 
 
5442
  device const uint8_t * signs = xr->signs + 4 * ib;
5443
  device const half * dh = &xr->d;
5444
 
5445
+ for (short row = 0; row < nr0; row++) {
 
5446
  const float db = dh[0];
5447
  const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
5448
 
5449
  float2 sum = {0};
5450
+ for (short l = 0; l < 4; ++l) {
5451
  const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
5452
  const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
5453
  const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
5454
  const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
5455
+ for (short j = 0; j < 4; ++j) {
5456
  sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
5457
  sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
5458
  }
 
5471
 
5472
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5473
 
5474
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5475
+ float sum_all = simd_sum(sumf[row]);
5476
  if (tiisg == 0) {
5477
+ dst_f32[first_row + row] = sum_all;
5478
  }
5479
  }
5480
  }
 
5490
  ushort tiisg[[thread_index_in_simdgroup]],
5491
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5492
 
5493
+ kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5494
  }
5495
 
5496
+ template<int nr0, int nsg, int nw, typename args_t>
5497
  void kernel_mul_mv_iq2_s_f32_impl(
5498
  args_t args,
5499
  device const char * src0,
 
5509
  const int r1 = tgpig.y;
5510
  const int im = tgpig.z;
5511
 
5512
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5513
 
5514
  const uint i12 = im%args.ne12;
5515
  const uint i13 = im/args.ne12;
 
5521
  device const float * y = (device const float *) (src1 + offset1);
5522
 
5523
  float yl[32];
5524
+ float sumf[nr0]={0.f};
5525
 
5526
  const int nb32 = nb * (QK_K / 32);
5527
 
 
5533
  // threadgroup_barrier(mem_flags::mem_threadgroup);
5534
  //}
5535
 
5536
+ const short ix = tiisg;
5537
5538
  device const float * y4 = y + 32 * ix;
5539
 
5540
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5541
+ for (short i = 0; i < 32; ++i) {
5542
  yl[i] = y4[i];
5543
  }
5544
 
 
5552
  device const uint8_t * signs = qs + QK_K/8;
5553
  device const half * dh = &xr->d;
5554
 
5555
+ for (short row = 0; row < nr0; row++) {
5556
  const float db = dh[0];
5557
  const float d1 = db * (0.5f + (sc[0] & 0xf));
5558
  const float d2 = db * (0.5f + (sc[0] >> 4));
5559
 
5560
  float2 sum = {0};
5561
+ for (short l = 0; l < 2; ++l) {
5562
  //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
5563
  //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
5564
  constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
5565
  constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
5566
+ for (short j = 0; j < 8; ++j) {
5567
  sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
5568
  sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
5569
  }
 
5582
 
5583
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5584
 
5585
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5586
+ float sum_all = simd_sum(sumf[row]);
5587
  if (tiisg == 0) {
5588
+ dst_f32[first_row + row] = sum_all * 0.25f;
5589
  }
5590
  }
5591
  }
 
5601
  ushort tiisg[[thread_index_in_simdgroup]],
5602
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5603
 
5604
+ kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5605
  }
5606
 
5607
+ template<int nr0, int nsg, int nw, typename args_t>
5608
  void kernel_mul_mv_iq1_s_f32_impl(
5609
  args_t args,
5610
  device const char * src0,
 
5620
  const int r1 = tgpig.y;
5621
  const int im = tgpig.z;
5622
 
5623
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5624
 
5625
  const uint i12 = im%args.ne12;
5626
  const uint i13 = im/args.ne12;
 
5632
  device const float * y = (device const float *) (src1 + offset1);
5633
 
5634
  float yl[32];
5635
+ float sumf[nr0]={0.f};
5636
 
5637
  const int nb32 = nb * (QK_K / 32);
5638
 
5639
+ const short ix = tiisg;
5640
5641
  device const float * y4 = y + 32 * ix;
5642
 
5643
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5644
  float sumy = 0;
5645
+ for (short i = 0; i < 32; ++i) {
5646
  yl[i] = y4[i];
5647
  sumy += yl[i];
5648
  }
 
5655
  device const uint16_t * qh = xr->qh + ib;
5656
  device const half * dh = &xr->d;
5657
 
5658
+ for (short row = 0; row < nr0; row++) {
 
5659
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5660
  constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
5661
  constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
5662
  constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
5663
 
5664
  float sum = 0;
5665
+ for (short j = 0; j < 4; ++j) {
5666
  sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
5667
  + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
5668
  + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
 
5680
 
5681
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5682
 
5683
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5684
+ float sum_all = simd_sum(sumf[row]);
5685
  if (tiisg == 0) {
5686
+ dst_f32[first_row + row] = sum_all;
5687
  }
5688
  }
5689
  }
5690
 
5691
+ [[host_name("kernel_mul_mv_iq1_s_f32")]]
5692
+ kernel void kernel_mul_mv_iq1_s_f32(
5693
+ constant ggml_metal_kargs_mul_mv & args,
5694
+ device const char * src0,
5695
+ device const char * src1,
5696
+ device char * dst,
5697
+ uint3 tgpig[[threadgroup_position_in_grid]],
5698
+ ushort tiisg[[thread_index_in_simdgroup]],
5699
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5700
+
5701
+ kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
5702
+ }
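For reference, the grid lookups in this kernel combine one byte of `qs` with a 3-bit field of the shared `qh` word to form an 11-bit index into `iq1s_grid_gpu`. The same bit arithmetic in isolation (plain C++; the input values are illustrative, the real table lives in the Metal source):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the iq1_s grid indexing: each of the four 8-value groups of a
// 32-element sub-block takes its low 8 index bits from qs and its high
// 3 bits (bits 8..10) from successive fields of the 16-bit qh word.
int main() {
    const uint8_t  qs[4] = {0x12, 0x34, 0x56, 0x78}; // example packed indices
    const uint16_t qh    = 0x0ABC;                   // example high bits

    // same shifts as the kernel: 8, 5, 2, -1 place consecutive 3-bit
    // fields of qh at bit positions 8..10 of the final index
    const int idx[4] = {
        qs[0] | ((qh << 8) & 0x700),
        qs[1] | ((qh << 5) & 0x700),
        qs[2] | ((qh << 2) & 0x700),
        qs[3] | ((qh >> 1) & 0x700),
    };

    for (int k = 0; k < 4; ++k) {
        printf("grid index %d: %d (range 0..2047)\n", k, idx[k]);
    }
    return 0;
}
```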
+
+ template<int nr0, int nsg, int nw, typename args_t>
  void kernel_mul_mv_iq1_m_f32_impl(
          args_t args,
          device const char * src0,

          ushort sgitg) {

      const int nb = args.ne00/QK_K;
+
      const int r0 = tgpig.x;
      const int r1 = tgpig.y;
      const int im = tgpig.z;

+     const int first_row = (r0 * nsg + sgitg) * nr0;

      const uint i12 = im%args.ne12;
      const uint i13 = im/args.ne12;

      device const float * y = (device const float *) (src1 + offset1);

      float yl[32];
+     float sumf[nr0]={0.f};

      const int nb32 = nb * (QK_K / 32);

+     const short ix = tiisg;

      device const float * y4 = y + 32 * ix;

      iq1m_scale_t scale;

      for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
          float4 sumy = {0.f};
+         for (short i = 0; i < 8; ++i) {
              yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
              yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
              yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];

          device const uint8_t * qh = xr->qh + 2 * ib;
          device const uint16_t * sc = (device const uint16_t *)xr->scales;

+         for (short row = 0; row < nr0; row++) {
              scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

              constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));

              constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));

              float2 sum = {0.f};
+             for (short j = 0; j < 4; ++j) {
                  sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                          + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
                  sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)

      device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+     for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+         float sum_all = simd_sum(sumf[row]);
          if (tiisg == 0) {
+             dst_f32[first_row + row] = sum_all;
          }
      }
  }

+ [[host_name("kernel_mul_mv_iq1_m_f32")]]
+ kernel void kernel_mul_mv_iq1_m_f32(
+         constant ggml_metal_kargs_mul_mv & args,
+         device const char * src0,
+         device const char * src1,
+         device char * dst,
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         ushort tiisg[[thread_index_in_simdgroup]],
+         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+     kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ }
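The `scale.u16` line above gathers the top nibble of each of the four 16-bit scale words into one 16-bit value, which the kernel then reads back through the `iq1m_scale_t` union as the block scale. The same packing in isolation (plain C++; the input words are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the iq1_m block-scale reassembly: the top nibble of each of
// the four 16-bit words in xr->scales is moved into one 16-bit value.
int main() {
    const uint16_t sc[4] = {0x1ABC, 0x2DEF, 0x3123, 0x4456}; // example inputs

    const uint16_t u16 = (sc[0] >> 12)            // nibble of sc[0] -> bits  0..3
                       | ((sc[1] >>  8) & 0x00f0) // nibble of sc[1] -> bits  4..7
                       | ((sc[2] >>  4) & 0x0f00) // nibble of sc[2] -> bits  8..11
                       |  (sc[3]        & 0xf000);// nibble of sc[3] -> bits 12..15

    printf("reassembled scale bits: 0x%04x\n", u16); // prints 0x4321 for these inputs
    return 0;
}
```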
+
+ template<int nr0, int nsg, int nw, typename args_t>
  void kernel_mul_mv_iq4_nl_f32_impl(
          args_t args,
          device const char * src0,

      threadgroup float * shmem_f32 = (threadgroup float *) shmem;
      const int nb = args.ne00/QK4_NL;
+
      const int r0 = tgpig.x;
      const int r1 = tgpig.y;
      const int im = tgpig.z;
+
+     const int first_row = (r0 * nsg + sgitg) * nr0;

      const uint i12 = im%args.ne12;
      const uint i13 = im/args.ne12;

      device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
      device const float * y = (device const float *) (src1 + offset1);

+     const short ix = tiisg/2; // 0...15
+     const short it = tiisg%2; // 0 or 1

      shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
      threadgroup_barrier(mem_flags::mem_threadgroup);

      float4 yl[4];
+     float sumf[nr0]={0.f};

      device const float * yb = y + ix * QK4_NL + it * 8;

      float4 qf1, qf2;

      for (int ib = ix; ib < nb; ib += 16) {
          device const float4 * y4 = (device const float4 *)yb;
+         yl[0] = y4[0];
+         yl[1] = y4[4];
+         yl[2] = y4[1];
+         yl[3] = y4[5];

+         for (short row = 0; row < nr0; row++) {
              device const block_iq4_nl & xb = x[row*nb + ib];
              device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);

              acc1 += acc2;

              sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
          }

          yb += 16 * QK4_NL;

      device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+     for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+         float sum_all = simd_sum(sumf[row]);
          if (tiisg == 0) {
+             dst_f32[first_row + row] = sum_all;
          }
      }
  }

+ [[host_name("kernel_mul_mv_iq4_nl_f32")]]
+ kernel void kernel_mul_mv_iq4_nl_f32(
+         constant ggml_metal_kargs_mul_mv & args,
+         device const char * src0,
+         device const char * src1,
+         device char * dst,
+         threadgroup char * shmem [[threadgroup(0)]],
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         ushort tiisg[[thread_index_in_simdgroup]],
+         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+     kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+ }
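Both iq4 kernels stage the 16-entry non-linear codebook into threadgroup memory (`shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]`) so that each nibble dequantizes with a single shared-memory lookup. A standalone sketch of that lookup (plain C++; the codebook values below are ggml's `kvalues_iq4nl` table, the scale and packed byte are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of iq4_nl dequantization: a 4-bit code indexes the non-linear
// codebook and the result is scaled by the per-block scale d (stored as
// half in the block; plain float here).
static const int8_t kvalues_iq4nl[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

int main() {
    const float   d  = 0.0625f; // example block scale
    const uint8_t qs = 0x9F;    // one packed byte = two independent 4-bit codes

    const float v0 = d * kvalues_iq4nl[qs & 0x0f]; // low nibble
    const float v1 = d * kvalues_iq4nl[qs >> 4];   // high nibble

    printf("codes %d,%d -> %.4f, %.4f\n", qs & 0x0f, qs >> 4, v0, v1);
    return 0;
}
```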
+
+ template<int nr0, int nsg, int nw, typename args_t>
  void kernel_mul_mv_iq4_xs_f32_impl(
          args_t args,
          device const char * src0,

      const int r0 = tgpig.x;
      const int r1 = tgpig.y;
      const int im = tgpig.z;
+     const int first_row = (r0 * nsg + sgitg) * nr0;

      const uint i12 = im%args.ne12;
      const uint i13 = im/args.ne12;

      device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
      device const float * y = (device const float *) (src1 + offset1);

+     const short ix = tiisg/16; // 0 or 1
+     const short it = tiisg%16; // 0...15
+     const short ib = it/2;
+     const short il = it%2;

      shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
      threadgroup_barrier(mem_flags::mem_threadgroup);

      float4 yl[4];
+     float sumf[nr0]={0.f};

      device const float * yb = y + ix * QK_K + ib * 32 + il * 8;

      for (int ibl = ix; ibl < nb; ibl += 2) {
          device const float4 * y4 = (device const float4 *)yb;
+         yl[0] = y4[0];
+         yl[1] = y4[4];
+         yl[2] = y4[1];
+         yl[3] = y4[5];

+         for (short row = 0; row < nr0; ++row) {
              device const block_iq4_xs & xb = x[row*nb + ibl];
              device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);

              const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
              sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
          }

          yb += 2 * QK_K;

      device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

+     for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+         float sum_all = simd_sum(sumf[row]);
          if (tiisg == 0) {
+             dst_f32[first_row + row] = sum_all;
          }
      }
  }

  [[host_name("kernel_mul_mv_iq4_xs_f32")]]
  kernel void kernel_mul_mv_iq4_xs_f32(
          constant ggml_metal_kargs_mul_mv & args,

          ushort tiisg[[thread_index_in_simdgroup]],
          ushort sgitg[[simdgroup_index_in_threadgroup]]) {

+     kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
  }
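The `ls` line in the inner loop above decodes the 6-bit sub-block scale of iq4_xs: four bits from the packed `scales_l` nibbles, two bits from `scales_h`, minus a bias of 32. The same decode in isolation (plain C++; the packed values are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the iq4_xs sub-block scale decode: each of the eight
// 32-element sub-blocks of a QK_K = 256 super-block has a 6-bit scale,
// split across scales_l (one nibble each) and scales_h (two bits each).
int main() {
    const uint8_t  scales_l[4] = {0xA3, 0x5C, 0x07, 0xF1}; // two nibbles per byte
    const uint16_t scales_h    = 0x8D31;                   // two bits per sub-block

    for (int ib = 0; ib < 8; ++ib) {
        const int lo = (scales_l[ib/2] >> 4*(ib%2)) & 0xf; // bits 0..3
        const int hi = (scales_h >> 2*ib) & 3;             // bits 4..5
        const int ls = (lo | (hi << 4)) - 32;              // signed 6-bit scale
        printf("sub-block %d: ls = %d\n", ib, ls);
    }
    return 0;
}
```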
 
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>

  #if defined(GGML_METAL_USE_BF16)
  template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
  #endif
+ template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>;
+
+ template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
+
+ template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
+ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
 
  kernel void kernel_pool_2d_max_f32(
          device const float * src0,