ggerganov committed
Commit a726ecc · 1 Parent(s): 17c0dfa

metal : add mean kernel (llama/14267)


* metal : add mean kernel

ggml-ci

* cont : dedup implementation

ggml-ci
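The two ops share one reduction: GGML_OP_MEAN is GGML_OP_SUM_ROWS followed by a division by the row length ne00, which is why the second commit folds both into a single templated kernel instead of keeping two copies. A minimal CPU reference of the shared semantics (plain C++, function name chosen here for illustration):

#include <cstdint>

// Reference for kernel_sum_rows<false> (SUM_ROWS) and kernel_sum_rows<true>
// (MEAN): reduce one row of ne00 floats; the norm flag only adds the final
// division. The Metal kernel in the diff computes the same value in parallel.
template <bool norm>
float reduce_row_ref(const float * src_row, int64_t ne00) {
    float sumf = 0.0f;
    for (int64_t i0 = 0; i0 < ne00; i0++) {
        sumf += src_row[i0];
    }
    return norm ? sumf / ne00 : sumf;
}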

ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -498,6 +498,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
     GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_LOG:
             return false; // TODO: implement
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
             {
                 GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));

-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (dst->op) {
+                    case GGML_OP_SUM_ROWS:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                        break;
+                    case GGML_OP_MEAN:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }

+                nth = MIN(nth, ne00);

                 ggml_metal_kargs_sum_rows args = {
                     /*.ne00 =*/ ne00,
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
                 };

                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                [encoder setBytes:&args length:sizeof(args) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_SOFT_MAX:
             {
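The encoder no longer launches one thread per row: it grows the threadgroup from a single simdgroup by powers of two until it covers the row or hits the pipeline's thread limit, then clamps so no more threads than row elements are launched. The same logic as a standalone C++ helper (hypothetical name, mirroring the nth computation above):

#include <algorithm>
#include <cstdint>

// Mirrors the nth selection in ggml_metal_encode_node: start at the SIMD
// width (32), double while more threads would still help and the pipeline
// allows them, then clamp to the row length ne00.
static int pick_threadgroup_size(int64_t ne00, int max_threads_per_threadgroup) {
    int nth = 32; // SIMD width

    while (nth < ne00 && nth < max_threads_per_threadgroup) {
        nth *= 2;
    }

    return (int) std::min<int64_t>(nth, ne00);
}

For example, ne00 = 40 gives nth = 40 (32 doubles to 64, then clamps), and ne00 = 4096 with a 1024-thread limit gives nth = 1024, so each thread sums four strided elements.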
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -993,31 +993,61 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }

+template <bool norm>
 kernel void kernel_sum_rows(
+        constant ggml_metal_kargs_sum_rows & args,
         device const float * src0,
         device       float * dst,
-        constant ggml_metal_kargs_sum_rows & args,
-        uint3 tpig[[thread_position_in_grid]]) {
-    int64_t i3 = tpig.z;
-    int64_t i2 = tpig.y;
-    int64_t i1 = tpig.x;
+        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    int64_t i3 = tgpig.z;
+    int64_t i2 = tgpig.y;
+    int64_t i1 = tgpig.x;

     if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }

+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
     device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
     device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

-    float row_sum = 0;
+    float sumf = 0;

-    for (int64_t i0 = 0; i0 < args.ne00; i0++) {
-        row_sum += src_row[i0];
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        sumf += src_row[i0];
     }

-    dst_row[0] = row_sum;
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    if (tpitg.x == 0) {
+        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+    }
 }

+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
 template<typename T>
 kernel void kernel_soft_max(
         device const char * src0,
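The rewritten kernel replaces the per-row scalar loop with a two-stage reduction: each thread accumulates a strided partial sum, simd_sum folds the partials within each 32-lane simdgroup, lane 0 of every simdgroup parks its result in threadgroup memory (hence the 32*sizeof(float) of threadgroup memory set on the host side), and a second simd_sum combines the staged partials. The [[host_name(...)]] instantiations at the end let the host look up kernel_sum_rows and kernel_mean as separate pipelines while sharing one body. A sequential C++ model of the same two stages (names chosen here for illustration; floating-point summation order differs from the GPU):

#include <algorithm>
#include <cstdint>
#include <vector>

// CPU model of the reduction in kernel_sum_rows: stage 1 sums within each
// 32-lane simdgroup, stage 2 sums the per-simdgroup partials that the GPU
// kernel stages through threadgroup memory (shmem_f32, 32 floats).
static float reduce_row_model(const float * src_row, int64_t ne00, int nth) {
    const int n_sg = (nth + 31)/32;         // number of simdgroups
    std::vector<float> shmem_f32(32, 0.0f); // per-simdgroup partials

    for (int sg = 0; sg < n_sg; sg++) {
        float sg_sum = 0.0f;                // simd_sum, stage 1
        for (int t = sg*32; t < std::min((sg + 1)*32, nth); t++) {
            // strided per-thread accumulation, as in the Metal kernel
            for (int64_t i0 = t; i0 < ne00; i0 += nth) {
                sg_sum += src_row[i0];
            }
        }
        shmem_f32[sg] = sg_sum;             // lane 0 writes to shmem
    }

    float sumf = 0.0f;                      // simd_sum, stage 2
    for (int sg = 0; sg < 32; sg++) {
        sumf += shmem_f32[sg];
    }

    return sumf; // kernel_mean additionally divides by ne00
}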