ggerganov committed
Commit 724622d · 1 Parent(s): b4ff704

metal : add special-case mat-vec mul for ne00 == 4 (llama/14385)

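The change adds dedicated `_c4` matrix-vector kernels for the case where the shared dimension is exactly 4 elements (`ne00 == 4`), so every output element reduces to a single float4 dot product. As a rough CPU-side reference for what the new kernels compute (a minimal sketch for a single contiguous matrix; the broadcast dimensions `ne12`/`ne13` handled by `kernel_mul_mv_c4_impl` are omitted, and the function name and flat layout are assumptions of this sketch, not ggml's host API):

#include <stddef.h>

/* Minimal CPU reference for the ne00 == 4 special case: every dst element is a
 * single 4-element dot product. Contiguous row-major layout and the parameter
 * names are assumptions of this sketch, not part of ggml. */
static void mul_mv_c4_ref(const float * src0,  /* ne01 rows of 4 weights */
                          const float * src1,  /* ne11 rows of 4 inputs  */
                          float       * dst,   /* ne11 x ne01 outputs    */
                          int ne01, int ne11) {
    for (int r1 = 0; r1 < ne11; ++r1) {
        for (int r0 = 0; r0 < ne01; ++r0) {
            float sum = 0.0f;
            for (int k = 0; k < 4; ++k) {  /* ne00 == 4 */
                sum += src0[4*r0 + k] * src1[4*r1 + k];
            }
            dst[(size_t)r1*ne01 + r0] = sum;
        }
    }
}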
ggml/src/ggml-metal/ggml-metal.m CHANGED

@@ -211,11 +211,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
@@ -1175,11 +1178,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
@@ -3111,14 +3117,23 @@ static bool ggml_metal_encode_node(
                 nsg = 1;
                 nr0 = 1;
                 nr1 = 4;
-                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+                if (ne00 == 4) {
+                    nr0 = 32;
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+                }
             } break;
         case GGML_TYPE_F16:
             {
                 nsg = 1;
                 nr0 = 1;
                 if (src1t == GGML_TYPE_F32) {
-                    if (ne11 * ne12 < 4) {
+                    if (ne00 == 4) {
+                        nr0 = 32;
+                        nr1 = 4;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
+                    } else if (ne11 * ne12 < 4) {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
@@ -3137,7 +3152,11 @@ static bool ggml_metal_encode_node(
                 nsg = 1;
                 nr0 = 1;
                 if (src1t == GGML_TYPE_F32) {
-                    if (ne11 * ne12 < 4) {
+                    if (ne00 == 4) {
+                        nr0 = 32;
+                        nr1 = 4;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
+                    } else if (ne11 * ne12 < 4) {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
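On the host side, the `_c4` path bumps `nr0` from 1 to 32 so that each of a threadgroup's 32 SIMD lanes handles one row of src0 (matching `r0 = tgpig.x*32 + tiisg` in the kernel below), while `nr1 = 4` matches `N_MV_T_T`, the number of src1 rows a threadgroup walks. Assuming the usual mul_mv dispatch, which launches roughly ceil(ne01/nr0) x ceil(ne11/nr1) x (ne12*ne13) threadgroups, the grid for a small example shape works out as in this sketch (plain C, illustrative only):

#include <stdio.h>

int main(void) {
    /* example shape with ne00 == 4; the concrete values are illustrative */
    const int ne01 = 1000, ne11 = 5, ne12 = 1, ne13 = 1;
    const int nr0  = 32, nr1 = 4;          /* values set for the _c4 pipelines */

    const int tg_x = (ne01 + nr0 - 1)/nr0; /* 32 src0 rows per threadgroup, one per SIMD lane */
    const int tg_y = (ne11 + nr1 - 1)/nr1; /* N_MV_T_T = 4 src1 rows per threadgroup          */
    const int tg_z = ne12*ne13;            /* one threadgroup layer per broadcast matrix      */

    printf("threadgroups: %d x %d x %d\n", tg_x, tg_y, tg_z); /* 32 x 2 x 1 */
    return 0;
}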
ggml/src/ggml-metal/ggml-metal.metal CHANGED

@@ -2532,6 +2532,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
 template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
 #endif
 
+template<typename T04, typename T14, typename args_t>
+void kernel_mul_mv_c4_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig,
+        ushort tiisg) {
+    const int r0 = tgpig.x*32 + tiisg;
+    const int rb = tgpig.y*N_MV_T_T;
+    const int im = tgpig.z;
+
+    if (r0 >= args.ne01) {
+        return;
+    }
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+    device const T04 * x = (device const T04 *) (src0 + offset0);
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+    for (int row = 0; row < N_MV_T_T; ++row) {
+        int r1 = rb + row;
+        if (r1 >= args.ne11) {
+            break;
+        }
+
+        const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
+
+        device const T14 * y = (device const T14 *) (src1 + offset1);
+
+        dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
+    }
+}
+
+template<typename T04, typename T14>
+kernel void kernel_mul_mv_c4(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
+        args,
+        src0,
+        src1,
+        dst,
+        tgpig,
+        tiisg);
+}
+
+typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
+
+template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
+template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
+#endif
+
 template<typename T, typename T4>
 kernel void kernel_mul_mv_1row(
         constant ggml_metal_kargs_mul_mv & args,