Spaces:
Running
Running
metal : add special-case mat-vec mul for ne00 == 4 (llama/14385)
Browse files
ggml/src/ggml-metal/ggml-metal.m
CHANGED
|
@@ -211,11 +211,14 @@ enum ggml_metal_kernel_type {
|
|
| 211 |
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
| 212 |
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
| 213 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
|
|
| 214 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
|
|
|
| 215 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
| 216 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
| 217 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
| 218 |
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
|
|
|
| 219 |
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
|
| 220 |
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
|
| 221 |
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
|
|
@@ -1175,11 +1178,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1175 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
| 1176 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
| 1177 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
|
|
|
| 1178 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
|
|
|
| 1179 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
| 1180 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
|
| 1181 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
|
| 1182 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
|
|
|
| 1183 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
| 1184 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
| 1185 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
|
|
@@ -3111,14 +3117,23 @@ static bool ggml_metal_encode_node(
|
|
| 3111 |
nsg = 1;
|
| 3112 |
nr0 = 1;
|
| 3113 |
nr1 = 4;
|
| 3114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3115 |
} break;
|
| 3116 |
case GGML_TYPE_F16:
|
| 3117 |
{
|
| 3118 |
nsg = 1;
|
| 3119 |
nr0 = 1;
|
| 3120 |
if (src1t == GGML_TYPE_F32) {
|
| 3121 |
-
if (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3122 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
| 3123 |
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
| 3124 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
|
@@ -3137,7 +3152,11 @@ static bool ggml_metal_encode_node(
|
|
| 3137 |
nsg = 1;
|
| 3138 |
nr0 = 1;
|
| 3139 |
if (src1t == GGML_TYPE_F32) {
|
| 3140 |
-
if (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3141 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
| 3142 |
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
| 3143 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
|
|
|
| 211 |
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
| 212 |
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
| 213 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
| 214 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
|
| 215 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
| 216 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
|
| 217 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
| 218 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
| 219 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
| 220 |
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
| 221 |
+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
|
| 222 |
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
|
| 223 |
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
|
| 224 |
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
|
|
|
|
| 1178 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
| 1179 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
| 1180 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
| 1181 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
|
| 1182 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
| 1183 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
|
| 1184 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
| 1185 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
|
| 1186 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
|
| 1187 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
| 1188 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
|
| 1189 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
| 1190 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
| 1191 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
|
|
|
|
| 3117 |
nsg = 1;
|
| 3118 |
nr0 = 1;
|
| 3119 |
nr1 = 4;
|
| 3120 |
+
if (ne00 == 4) {
|
| 3121 |
+
nr0 = 32;
|
| 3122 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
|
| 3123 |
+
} else {
|
| 3124 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
| 3125 |
+
}
|
| 3126 |
} break;
|
| 3127 |
case GGML_TYPE_F16:
|
| 3128 |
{
|
| 3129 |
nsg = 1;
|
| 3130 |
nr0 = 1;
|
| 3131 |
if (src1t == GGML_TYPE_F32) {
|
| 3132 |
+
if (ne00 == 4) {
|
| 3133 |
+
nr0 = 32;
|
| 3134 |
+
nr1 = 4;
|
| 3135 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
|
| 3136 |
+
} else if (ne11 * ne12 < 4) {
|
| 3137 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
| 3138 |
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
| 3139 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
|
|
|
| 3152 |
nsg = 1;
|
| 3153 |
nr0 = 1;
|
| 3154 |
if (src1t == GGML_TYPE_F32) {
|
| 3155 |
+
if (ne00 == 4) {
|
| 3156 |
+
nr0 = 32;
|
| 3157 |
+
nr1 = 4;
|
| 3158 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
|
| 3159 |
+
} else if (ne11 * ne12 < 4) {
|
| 3160 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
| 3161 |
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
| 3162 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -2532,6 +2532,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
|
|
| 2532 |
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
|
| 2533 |
#endif
|
| 2534 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2535 |
template<typename T, typename T4>
|
| 2536 |
kernel void kernel_mul_mv_1row(
|
| 2537 |
constant ggml_metal_kargs_mul_mv & args,
|
|
|
|
| 2532 |
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
|
| 2533 |
#endif
|
| 2534 |
|
| 2535 |
+
template<typename T04, typename T14, typename args_t>
|
| 2536 |
+
void kernel_mul_mv_c4_impl(
|
| 2537 |
+
args_t args,
|
| 2538 |
+
device const char * src0,
|
| 2539 |
+
device const char * src1,
|
| 2540 |
+
device char * dst,
|
| 2541 |
+
uint3 tgpig,
|
| 2542 |
+
ushort tiisg) {
|
| 2543 |
+
const int r0 = tgpig.x*32 + tiisg;
|
| 2544 |
+
const int rb = tgpig.y*N_MV_T_T;
|
| 2545 |
+
const int im = tgpig.z;
|
| 2546 |
+
|
| 2547 |
+
if (r0 >= args.ne01) {
|
| 2548 |
+
return;
|
| 2549 |
+
}
|
| 2550 |
+
|
| 2551 |
+
const uint i12 = im%args.ne12;
|
| 2552 |
+
const uint i13 = im/args.ne12;
|
| 2553 |
+
|
| 2554 |
+
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
| 2555 |
+
|
| 2556 |
+
device const T04 * x = (device const T04 *) (src0 + offset0);
|
| 2557 |
+
|
| 2558 |
+
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
|
| 2559 |
+
|
| 2560 |
+
for (int row = 0; row < N_MV_T_T; ++row) {
|
| 2561 |
+
int r1 = rb + row;
|
| 2562 |
+
if (r1 >= args.ne11) {
|
| 2563 |
+
break;
|
| 2564 |
+
}
|
| 2565 |
+
|
| 2566 |
+
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
| 2567 |
+
|
| 2568 |
+
device const T14 * y = (device const T14 *) (src1 + offset1);
|
| 2569 |
+
|
| 2570 |
+
dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
|
| 2571 |
+
}
|
| 2572 |
+
}
|
| 2573 |
+
|
| 2574 |
+
template<typename T04, typename T14>
|
| 2575 |
+
kernel void kernel_mul_mv_c4(
|
| 2576 |
+
constant ggml_metal_kargs_mul_mv & args,
|
| 2577 |
+
device const char * src0,
|
| 2578 |
+
device const char * src1,
|
| 2579 |
+
device char * dst,
|
| 2580 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2581 |
+
ushort tiisg[[thread_index_in_simdgroup]]) {
|
| 2582 |
+
kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
|
| 2583 |
+
args,
|
| 2584 |
+
src0,
|
| 2585 |
+
src1,
|
| 2586 |
+
dst,
|
| 2587 |
+
tgpig,
|
| 2588 |
+
tiisg);
|
| 2589 |
+
}
|
| 2590 |
+
|
| 2591 |
+
typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
|
| 2592 |
+
|
| 2593 |
+
template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
|
| 2594 |
+
template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
|
| 2595 |
+
#if defined(GGML_METAL_USE_BF16)
|
| 2596 |
+
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
|
| 2597 |
+
#endif
|
| 2598 |
+
|
| 2599 |
template<typename T, typename T4>
|
| 2600 |
kernel void kernel_mul_mv_1row(
|
| 2601 |
constant ggml_metal_kargs_mul_mv & args,
|