metal : small-batch mat-mul kernels (llama/10581)

* metal : small-batch mat-mul kernels

ggml-ci

* metal : add rest of types

ggml-ci

* metal : final adjustments

ggml-ci

* metal : add comments

ggml-ci
ggml/src/ggml-metal/ggml-metal-impl.h CHANGED

@@ -192,6 +192,30 @@ typedef struct {
     int16_t r3;
 } ggml_metal_kargs_mul_mv;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+    int16_t  nsg;
+    int16_t  nxpsg;
+    int16_t  r1ptg;
+} ggml_metal_kargs_mul_mv_ext;
+
 typedef struct {
     int32_t nei0;
     int32_t nei1;
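The three trailing fields are launch-geometry parameters rather than tensor shapes: nsg is the number of simdgroups per threadgroup, nxpsg is the number of threads that cooperate along a src0 row inside one simdgroup, and r1ptg is the number of src1 rows (batch columns) handled per threadgroup. As a reading aid, here is a minimal plain-C sketch of how these values relate; the concrete numbers are illustrative assumptions, and the relationships simply mirror the host-side selection logic added to ggml-metal.m further down.

// Sketch only: how nsg/nxpsg/r1ptg describe a threadgroup (illustrative values).
#include <stdio.h>

int main(void) {
    const int nsg   = 2;            // simdgroups per threadgroup (example value)
    const int nxpsg = 16;           // threads along a src0 row, per simdgroup (example value)
    const int r1ptg = 4;            // src1 rows (batch columns) per threadgroup (example value)

    const int nypsg = 32 / nxpsg;   // src0 rows processed in parallel by one simdgroup
    const int r0ptg = nypsg * nsg;  // src0 rows covered by one threadgroup

    printf("threads/threadgroup = %d, src0 rows/threadgroup = %d, src1 rows/threadgroup = %d\n",
           32 * nsg, r0ptg, r1ptg);
    return 0;
}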
ggml/src/ggml-metal/ggml-metal.m CHANGED

@@ -175,6 +175,46 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,

@@ -702,6 +742,46 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);

@@ -1936,30 +2016,180 @@ static void ggml_metal_encode_node(
 
     // find the break-even point where the matrix-matrix kernel becomes more efficient compared
     // to the matrix-vector kernel
-    int ne11_mm_min = 4;
-    ... (remaining removed lines of an old switch block; their content is truncated in this page view)
+    const int ne11_mm_min = 4;
+
+    // first try to use small-batch mat-mv kernels
+    // these should be efficient for BS [2, ~8]
+    if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+        (
+         (
+          (
+           src0t == GGML_TYPE_F16 ||  // TODO: helper function
+           src0t == GGML_TYPE_Q4_0 ||
+           src0t == GGML_TYPE_Q4_1 ||
+           src0t == GGML_TYPE_Q5_0 ||
+           src0t == GGML_TYPE_Q5_1 ||
+           src0t == GGML_TYPE_Q8_0 ||
+           src0t == GGML_TYPE_IQ4_NL ||
+           false) && (ne11 >= 2 && ne11 <= 8)
+         ) ||
+         (
+          (
+           src0t == GGML_TYPE_Q4_K ||
+           src0t == GGML_TYPE_Q5_K ||
+           src0t == GGML_TYPE_Q6_K ||
+           false) && (ne11 >= 4 && ne11 <= 8)
+         )
+        )
+       ) {
+        // TODO: determine the optimal parameters based on grid utilization
+        //       I still don't know why we should not always use the maximum available threads:
+        //
+        //       nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+        //
+        //       my current hypothesis is that the work grid is not evenly divisible for different nsg
+        //       values and there can be some tail effects when nsg is high. need to confirm this
+        //
+        const int nsg   = 2;                 // num simdgroups per threadgroup
+        const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+        const int nypsg = 32/nxpsg;          // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+        const int r0ptg = nypsg*nsg;         // num src0 rows per threadgroup
+        int       r1ptg = 4;                 // num src1 rows per threadgroup
+
+        // note: not sure how optimal are those across all different hardware. there might be someting cleverer
+        switch (ne11) {
+            case 2:
+                r1ptg = 2; break;
+            case 3:
+            case 6:
+                r1ptg = 3; break;
+            case 4:
+            case 7:
+            case 8:
+                r1ptg = 4; break;
+            case 5:
+                r1ptg = 5; break;
+        };
+
+        id<MTLComputePipelineState> pipeline = nil;
+
+        switch (src0->type) {
+            case GGML_TYPE_F16:
+                switch (r1ptg) {
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
+                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
+                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                } break;
+            case GGML_TYPE_Q4_0:
+                switch (r1ptg) {
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
+                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
+                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                } break;
+            case GGML_TYPE_Q4_1:
+                switch (r1ptg) {
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
+                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
+                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                } break;
+            case GGML_TYPE_Q5_0:
+                switch (r1ptg) {
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
+                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
+                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                } break;
+            case GGML_TYPE_Q5_1:
+                switch (r1ptg) {
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
+                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
+                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                } break;
+            case GGML_TYPE_Q8_0:
+                switch (r1ptg) {
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
+                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
+                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                } break;
+            case GGML_TYPE_Q4_K:
+                switch (r1ptg) {
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break;
+                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break;
+                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                } break;
+            case GGML_TYPE_Q5_K:
+                switch (r1ptg) {
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break;
+                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break;
+                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                } break;
+            case GGML_TYPE_Q6_K:
+                switch (r1ptg) {
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break;
+                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break;
+                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                } break;
+            case GGML_TYPE_IQ4_NL:
+                switch (r1ptg) {
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break;
+                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break;
+                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                } break;
+            default: GGML_ABORT("not implemented");
+        }
+
+        ggml_metal_kargs_mul_mv_ext args = {
+            /*.ne00  =*/ ne00,
+            /*.ne01  =*/ ne01,
+            /*.ne02  =*/ ne02,
+            /*.nb00  =*/ nb00,
+            /*.nb01  =*/ nb01,
+            /*.nb02  =*/ nb02,
+            /*.nb03  =*/ nb03,
+            /*.ne10  =*/ ne10,
+            /*.ne11  =*/ ne11,
+            /*.ne12  =*/ ne12,
+            /*.nb10  =*/ nb10,
+            /*.nb11  =*/ nb11,
+            /*.nb12  =*/ nb12,
+            /*.nb13  =*/ nb13,
+            /*.ne0   =*/ ne0,
+            /*.ne1   =*/ ne1,
+            /*.r2    =*/ r2,
+            /*.r3    =*/ r3,
+            /*.nsg   =*/ nsg,
+            /*.nxpsg =*/ nxpsg,
+            /*.r1ptg =*/ r1ptg,
+        };
+
+        [encoder setComputePipelineState:pipeline];
+        [encoder setBytes:&args length:sizeof(args) atIndex:0];
+        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+
+        //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
+        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+    } else
     // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
     // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
     if ([device supportsFamily:MTLGPUFamilyApple7] &&
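To make the dispatch arithmetic above concrete, here is a small, self-contained C sketch of the same parameter selection and threadgroup-count computation. The tensor sizes (ne01, ne11, ne12, ne13) are made-up illustrative values, not values taken from the patch; the formulas mirror the host code just shown.

// Sketch of the small-batch dispatch geometry (illustrative sizes only).
#include <stdio.h>

int main(void) {
    const int ne01 = 4096; // src0 rows (example)
    const int ne11 = 5;    // src1 rows, i.e. batch size (example)
    const int ne12 = 1, ne13 = 1;

    const int nsg   = 2;
    const int nxpsg = ne11 < 3 ? 16 : 8;
    const int nypsg = 32 / nxpsg;
    const int r0ptg = nypsg * nsg;

    int r1ptg = 4;
    switch (ne11) {
        case 2:          r1ptg = 2; break;
        case 3: case 6:  r1ptg = 3; break;
        case 4: case 7:
        case 8:          r1ptg = 4; break;
        case 5:          r1ptg = 5; break;
    }

    // same rounding-up as the dispatchThreadgroups call above
    const int tg_x = (ne01 + r0ptg - 1) / r0ptg;
    const int tg_y = (ne11 + r1ptg - 1) / r1ptg;
    const int tg_z = ne12 * ne13;

    printf("grid = (%d, %d, %d), threads per threadgroup = (32, %d, 1)\n", tg_x, tg_y, tg_z, nsg);
    // for these example sizes: grid = (512, 1, 1), threads per threadgroup = (32, 2, 1)
    return 0;
}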
ggml/src/ggml-metal/ggml-metal.metal CHANGED

@@ -47,6 +47,11 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
     reg = (type4x4)(*src);
 }
 
+template <typename type4>
+void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
+    reg = (type4)(*(src + il));
+}
+
 #if defined(GGML_METAL_USE_BF16)
 template <typename type4x4>
 void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {

@@ -55,7 +60,7 @@ void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & re
 #endif
 
 template <typename type4x4>
-void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
     const float d1 = il ? (xb->d / 16.h) : xb->d;
     const float d2 = d1 / 256.f;

@@ -73,8 +78,23 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
+    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i = 0; i < 2; i++) {
+        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
+        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
+    }
+}
+
 template <typename type4x4>
-void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
     const float d1 = il ? (xb->d / 16.h) : xb->d;
     const float d2 = d1 / 256.f;

@@ -92,8 +112,23 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb->m;
+    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i = 0; i < 2; i++) {
+        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
+        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
+    }
+}
+
 template <typename type4x4>
-void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 3);
     const float d = xb->d;
     const float md = -16.h * xb->d;

@@ -124,8 +159,38 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = (il/4) ? 4 : 0;
+
+    const int gh_mv = (il/4) ? 12 : 0;
+    const int gh_bk = (il/4) ?  0 : 4;
+
+    for (int ii = 0; ii < 2; ii++) {
+        int i = 2*(il%4) + ii;
+
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[2*ii + 0] = d * x0 + md;
+        reg[2*ii + 1] = d * x1 + md;
+    }
+}
+
 template <typename type4x4>
-void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 4);
     const float d = xb->d;
     const float m = xb->m;

@@ -156,10 +221,40 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = (il/4) ? 4 : 0;
+
+    const int gh_mv = (il/4) ? 12 : 0;
+    const int gh_bk = (il/4) ?  0 : 4;
+
+    for (int ii = 0; ii < 2; ii++) {
+        int i = 2*(il%4) + ii;
+
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[2*ii + 0] = d * x0 + m;
+        reg[2*ii + 1] = d * x1 + m;
+    }
+}
+
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
     device const int8_t * qs = ((device const int8_t *)xb->qs);
-    const ... (removed line truncated in this page view)
+    const float d = xb->d;
 
     float4x4 reg_f;
 

@@ -170,6 +265,16 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const float d = xb->d;
+
+    for (int i = 0; i < 4; i++) {
+        reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const float d = xb->d;

@@ -224,7 +329,7 @@ static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q
 }
 
 template <typename type4x4>
-void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
     device const uchar * q = xb->qs;
 
     short is = (il/4) * 2;

@@ -236,7 +341,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
     const float dl = d * sc[0];
     const float ml = min * sc[1];
 
-    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const ushort mask = il < 2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - ml;
     }

@@ -469,6 +574,19 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
     }
 }
 
+template <typename type4>
+void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
+    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+    const float d = xb->d;
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
+    reg[0] = d * kvalues_iq4nl_f[q8[0]];
+    reg[1] = d * kvalues_iq4nl_f[q8[1]];
+    reg[2] = d * kvalues_iq4nl_f[q8[2]];
+    reg[3] = d * kvalues_iq4nl_f[q8[3]];
+}
+
 template <typename type4x4>
 void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
     // il is 0...15 for QK_K = 256 => index of block of 32 is il/2

@@ -1809,6 +1927,301 @@ kernel void kernel_mul_mv_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+// mat-vec kernel processing in chunks of float4
+// chpb - chunks per quantization block
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
+void kernel_mul_mv_ext_q4_f32_impl(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short chpt = 4; // chunks per thread
+
+    //const short nxpsg = (32);
+    const short nypsg = (32/nxpsg);
+
+    const short tx = tiisg%nxpsg;
+    const short ty = tiisg/nxpsg;
+
+    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i11 = tgpig.y*r1ptg;
+    const int i1m = tgpig.z;
+
+    const int i12 = i1m%args.ne12;
+    const int i13 = i1m/args.ne12;
+
+    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+    device const float4 * y4[r1ptg];
+
+    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+        y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
+    }
+
+    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+    short cch = tx%chpb; // current chunk index
+
+    for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
+        float4 lx[chpt];
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+            deq_t4(xq, cch, lx[ch]);
+
+            cch += nxpsg;
+            if (cch >= chpb) {
+                xq += cch/chpb;
+                cch %= chpb;
+            }
+        }
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+                sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
+            }
+        }
+
+#pragma unroll(r1ptg)
+        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+            y4[ir1] += chpt*nxpsg;
+        }
+    }
+
+    // reduce only the threads in each row
+    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+        if (nxpsg >= 32) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+        }
+        if (nxpsg >= 16) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
+        }
+        if (nxpsg >= 8) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
+        }
+        if (nxpsg >= 4) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
+        }
+        if (nxpsg >= 2) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
+        }
+
+        //sumf[ir1] = simd_sum(sumf[ir1]);
+    }
+
+    if (tx == 0) {
+        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+            if (i01 < args.ne01) {
+                dst_f32[i01] = sumf[ir1];
+            }
+        }
+    }
+}
+
+// mat-vec kernel processing in chunks of float4x4
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
+void kernel_mul_mv_ext_q4x4_f32_impl(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short chpt = 1;
+
+    //const short nxpsg = (32);
+    const short nypsg = (32/nxpsg);
+
+    const short tx = tiisg%nxpsg;
+    const short ty = tiisg/nxpsg;
+
+    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i11 = tgpig.y*r1ptg;
+    const int i1m = tgpig.z;
+
+    const int i12 = i1m%args.ne12;
+    const int i13 = i1m/args.ne12;
+
+    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+    device const float4x4 * y4x4[r1ptg];
+
+    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+        y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
+    }
+
+    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+    short cch = tx%chpb;
+
+    for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
+        float4x4 lx[chpt];
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+            deq_t4x4(xq, cch, lx[ch]);
+
+            cch += nxpsg;
+            if (cch >= chpb) {
+                xq += cch/chpb;
+                cch %= chpb;
+            }
+        }
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+                sumf[ir1] +=
+                    dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
+                    dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
+                    dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
+                    dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);
+            }
+        }
+
+#pragma unroll(r1ptg)
+        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+            y4x4[ir1] += chpt*nxpsg;
+        }
+    }
+
+    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+        if (nxpsg >= 32) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+        }
+        if (nxpsg >= 16) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
+        }
+        if (nxpsg >= 8) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
+        }
+        if (nxpsg >= 4) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
+        }
+        if (nxpsg >= 2) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
+        }
+
+        //sumf[ir1] = simd_sum(sumf[ir1]);
+    }
+
+    if (tx == 0) {
+        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+            if (i01 < args.ne01) {
+                dst_f32[i01] = sumf[ir1];
+            }
+        }
+    }
+}
+
+// dispatchers needed for compile-time nxpsg
+// epb - elements per quantization block
+template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
+kernel void kernel_mul_mv_ext_q4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    switch (args.nxpsg) {
+        case 4:  kernel_mul_mv_ext_q4_f32_impl<4,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 8:  kernel_mul_mv_ext_q4_f32_impl<8,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+    }
+}
+
+template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
+kernel void kernel_mul_mv_ext_q4x4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    switch (args.nxpsg) {
+        case 4:  kernel_mul_mv_ext_q4x4_f32_impl<4,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 8:  kernel_mul_mv_ext_q4x4_f32_impl<8,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+    }
+}
+
+typedef decltype(kernel_mul_mv_ext_q4_f32_disp  <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
+typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>)   mul_mv_ext_q4x4_f32_t;
+
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;
+
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
+
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
+
 #define N_MV_T_T 4
 
 template<typename T0, typename T04, typename T1, typename T14, typename args_t>