Spaces:
Running
Running
metal : add mean kernel (llama/14267)
Browse files
* metal : add mean kernel
ggml-ci
* cont : dedup implementation
ggml-ci
ggml/src/ggml-metal/ggml-metal.m
CHANGED
|
@@ -498,6 +498,7 @@ enum ggml_metal_kernel_type {
|
|
| 498 |
GGML_METAL_KERNEL_TYPE_COS,
|
| 499 |
GGML_METAL_KERNEL_TYPE_NEG,
|
| 500 |
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
|
|
| 501 |
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
| 502 |
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
| 503 |
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
|
@@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1454 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
| 1455 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
| 1456 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
|
|
| 1457 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
| 1458 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
| 1459 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
|
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1653 |
case GGML_OP_LOG:
|
| 1654 |
return false; // TODO: implement
|
| 1655 |
case GGML_OP_SUM_ROWS:
|
|
|
|
| 1656 |
case GGML_OP_SOFT_MAX:
|
| 1657 |
case GGML_OP_GROUP_NORM:
|
| 1658 |
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
|
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
|
|
| 2400 |
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2401 |
} break;
|
| 2402 |
case GGML_OP_SUM_ROWS:
|
|
|
|
| 2403 |
{
|
| 2404 |
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
| 2405 |
|
| 2406 |
-
id<MTLComputePipelineState> pipeline =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2407 |
|
|
|
|
| 2408 |
|
| 2409 |
ggml_metal_kargs_sum_rows args = {
|
| 2410 |
/*.ne00 =*/ ne00,
|
|
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
|
|
| 2434 |
};
|
| 2435 |
|
| 2436 |
[encoder setComputePipelineState:pipeline];
|
| 2437 |
-
[encoder
|
| 2438 |
-
[encoder setBuffer:
|
| 2439 |
-
[encoder
|
|
|
|
| 2440 |
|
| 2441 |
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(
|
| 2442 |
} break;
|
| 2443 |
case GGML_OP_SOFT_MAX:
|
| 2444 |
{
|
|
|
|
| 498 |
GGML_METAL_KERNEL_TYPE_COS,
|
| 499 |
GGML_METAL_KERNEL_TYPE_NEG,
|
| 500 |
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
| 501 |
+
GGML_METAL_KERNEL_TYPE_MEAN,
|
| 502 |
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
| 503 |
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
| 504 |
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
|
|
|
| 1455 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
| 1456 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
| 1457 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
| 1458 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
| 1459 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
| 1460 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
| 1461 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
|
|
|
| 1655 |
case GGML_OP_LOG:
|
| 1656 |
return false; // TODO: implement
|
| 1657 |
case GGML_OP_SUM_ROWS:
|
| 1658 |
+
case GGML_OP_MEAN:
|
| 1659 |
case GGML_OP_SOFT_MAX:
|
| 1660 |
case GGML_OP_GROUP_NORM:
|
| 1661 |
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
|
|
|
| 2403 |
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2404 |
} break;
|
| 2405 |
case GGML_OP_SUM_ROWS:
|
| 2406 |
+
case GGML_OP_MEAN:
|
| 2407 |
{
|
| 2408 |
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
| 2409 |
|
| 2410 |
+
id<MTLComputePipelineState> pipeline = nil;
|
| 2411 |
+
|
| 2412 |
+
switch (dst->op) {
|
| 2413 |
+
case GGML_OP_SUM_ROWS:
|
| 2414 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
| 2415 |
+
break;
|
| 2416 |
+
case GGML_OP_MEAN:
|
| 2417 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
| 2418 |
+
break;
|
| 2419 |
+
default:
|
| 2420 |
+
GGML_ABORT("fatal error");
|
| 2421 |
+
}
|
| 2422 |
+
|
| 2423 |
+
int nth = 32; // SIMD width
|
| 2424 |
+
|
| 2425 |
+
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
| 2426 |
+
nth *= 2;
|
| 2427 |
+
}
|
| 2428 |
|
| 2429 |
+
nth = MIN(nth, ne00);
|
| 2430 |
|
| 2431 |
ggml_metal_kargs_sum_rows args = {
|
| 2432 |
/*.ne00 =*/ ne00,
|
|
|
|
| 2456 |
};
|
| 2457 |
|
| 2458 |
[encoder setComputePipelineState:pipeline];
|
| 2459 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
| 2460 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
| 2461 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 2462 |
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
| 2463 |
|
| 2464 |
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 2465 |
} break;
|
| 2466 |
case GGML_OP_SOFT_MAX:
|
| 2467 |
{
|
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -993,31 +993,61 @@ kernel void kernel_neg(
|
|
| 993 |
dst[tpig] = -src0[tpig];
|
| 994 |
}
|
| 995 |
|
|
|
|
| 996 |
kernel void kernel_sum_rows(
|
|
|
|
| 997 |
device const float * src0,
|
| 998 |
device float * dst,
|
| 999 |
-
|
| 1000 |
-
uint3
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1004 |
|
| 1005 |
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
| 1006 |
return;
|
| 1007 |
}
|
| 1008 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1009 |
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
| 1010 |
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
| 1011 |
|
| 1012 |
-
float
|
| 1013 |
|
| 1014 |
-
for (int64_t i0 =
|
| 1015 |
-
|
| 1016 |
}
|
| 1017 |
|
| 1018 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1019 |
}
|
| 1020 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1021 |
template<typename T>
|
| 1022 |
kernel void kernel_soft_max(
|
| 1023 |
device const char * src0,
|
|
|
|
| 993 |
dst[tpig] = -src0[tpig];
|
| 994 |
}
|
| 995 |
|
| 996 |
+
template <bool norm>
|
| 997 |
kernel void kernel_sum_rows(
|
| 998 |
+
constant ggml_metal_kargs_sum_rows & args,
|
| 999 |
device const float * src0,
|
| 1000 |
device float * dst,
|
| 1001 |
+
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
| 1002 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1003 |
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
| 1004 |
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
| 1005 |
+
ushort tiisg[[thread_index_in_simdgroup]],
|
| 1006 |
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
| 1007 |
+
int64_t i3 = tgpig.z;
|
| 1008 |
+
int64_t i2 = tgpig.y;
|
| 1009 |
+
int64_t i1 = tgpig.x;
|
| 1010 |
|
| 1011 |
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
| 1012 |
return;
|
| 1013 |
}
|
| 1014 |
|
| 1015 |
+
if (sgitg == 0) {
|
| 1016 |
+
shmem_f32[tiisg] = 0.0f;
|
| 1017 |
+
}
|
| 1018 |
+
|
| 1019 |
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
| 1020 |
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
| 1021 |
|
| 1022 |
+
float sumf = 0;
|
| 1023 |
|
| 1024 |
+
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
| 1025 |
+
sumf += src_row[i0];
|
| 1026 |
}
|
| 1027 |
|
| 1028 |
+
sumf = simd_sum(sumf);
|
| 1029 |
+
|
| 1030 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 1031 |
+
|
| 1032 |
+
if (tiisg == 0) {
|
| 1033 |
+
shmem_f32[sgitg] = sumf;
|
| 1034 |
+
}
|
| 1035 |
+
|
| 1036 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 1037 |
+
|
| 1038 |
+
sumf = shmem_f32[tiisg];
|
| 1039 |
+
sumf = simd_sum(sumf);
|
| 1040 |
+
|
| 1041 |
+
if (tpitg.x == 0) {
|
| 1042 |
+
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
| 1043 |
+
}
|
| 1044 |
}
|
| 1045 |
|
| 1046 |
+
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
| 1047 |
+
|
| 1048 |
+
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
| 1049 |
+
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
| 1050 |
+
|
| 1051 |
template<typename T>
|
| 1052 |
kernel void kernel_soft_max(
|
| 1053 |
device const char * src0,
|