PABannier Diego Devesa committed on
Commit
c7e59ef
·
1 Parent(s): 9c845f4

feat: add `GGML_OP_ARGMAX` Metal kernel (ggml/1019)

Browse files

* implemented argmax kernel

* tpig -> tgpig

* change to strides

* contiguous assertions

* kernel working and tested

* argmax simd parallel implementation

* added 2 new tests for argmax in test-backend-ops

* cosmetic

* added 3 test cases for perf eval

* add test_argmax in make_test_cases_perf

* Update test-backend-ops.cpp

Co-authored-by: Diego Devesa <[email protected]>

---------

Co-authored-by: Diego Devesa <[email protected]>

ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -352,6 +352,7 @@ enum ggml_metal_kernel_type {
352
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
353
  GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
354
  GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
 
355
 
356
  GGML_METAL_KERNEL_TYPE_COUNT
357
  };
@@ -876,6 +877,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
876
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
877
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
878
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
 
879
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
880
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
881
  }
@@ -1005,6 +1007,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1005
  case GGML_OP_RMS_NORM:
1006
  case GGML_OP_GROUP_NORM:
1007
  return has_simdgroup_reduction;
 
1008
  case GGML_OP_NORM:
1009
  case GGML_OP_ROPE:
1010
  return true;
@@ -3615,6 +3618,31 @@ static void ggml_metal_encode_node(
3615
 
3616
  [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
3617
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3618
  default:
3619
  {
3620
  GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
 
352
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
353
  GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
354
  GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
355
+ GGML_METAL_KERNEL_TYPE_ARGMAX,
356
 
357
  GGML_METAL_KERNEL_TYPE_COUNT
358
  };
 
877
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
878
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
879
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
880
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
881
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
882
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
883
  }
 
1007
  case GGML_OP_RMS_NORM:
1008
  case GGML_OP_GROUP_NORM:
1009
  return has_simdgroup_reduction;
1010
+ case GGML_OP_ARGMAX:
1011
  case GGML_OP_NORM:
1012
  case GGML_OP_ROPE:
1013
  return true;
 
3618
 
3619
  [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
3620
  } break;
3621
+ case GGML_OP_ARGMAX:
3622
+ {
3623
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
3624
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3625
+ GGML_ASSERT(nb00 == ggml_type_size(src0->type));
3626
+
3627
+ const int64_t nrows = ggml_nrows(src0);
3628
+
3629
+ int nth = 32; // SIMD width
3630
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
3631
+ nth *= 2;
3632
+ }
3633
+
3634
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;
3635
+
3636
+ [encoder setComputePipelineState:pipeline];
3637
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3638
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3639
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
3640
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
3641
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
3642
+ [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];
3643
+
3644
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3645
+ } break;
3646
  default:
3647
  {
3648
  GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -1248,6 +1248,63 @@ kernel void kernel_ssm_scan_f32(
1248
  }
1249
  }
1250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1251
  kernel void kernel_norm(
1252
  constant ggml_metal_kargs_norm & args,
1253
  device const char * src0,
 
1248
  }
1249
  }
1250
 
1251
+ kernel void kernel_argmax( // per-row argmax: threadgroup `tgpig` reduces row `tgpig` of x and stores the winning column index in dst[tgpig]
1252
+ device const void * x, // input rows; reinterpreted as float below
1253
+ device int32_t * dst, // one int32 result per threadgroup (indexed by tgpig)
1254
+ constant int64_t & ncols, // number of elements scanned in each row
1255
+ constant uint64_t & nb01, // byte stride between consecutive rows of x
1256
+ threadgroup float * shared_maxval [[threadgroup(0)]], // per-simdgroup partial maxima (host encoder in this diff sizes it to 32 floats)
1257
+ threadgroup int32_t * shared_argmax [[threadgroup(1)]], // per-simdgroup partial argmax (host encoder in this diff sizes it to 32 int32)
1258
+ uint tgpig[[threadgroup_position_in_grid]],
1259
+ uint tpitg[[thread_position_in_threadgroup]],
1260
+ uint sgitg[[simdgroup_index_in_threadgroup]],
1261
+ uint tiisg[[thread_index_in_simdgroup]],
1262
+ uint ntg[[threads_per_threadgroup]]) {
1263
+ device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01); // base of this threadgroup's row (byte offset, then viewed as float)
1264
+ // per-thread running maximum and its column; larg stays -1 if no element compares > -INFINITY (e.g. ncols == 0)
1265
+ float lmax = -INFINITY;
1266
+ int32_t larg = -1;
1267
+ // each thread visits columns tpitg, tpitg + ntg, ...; the strict '>' keeps the first maximum this thread sees
1268
+ for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
1269
+ if (x_row[i00] > lmax) {
1270
+ lmax = x_row[i00];
1271
+ larg = i00;
1272
+ }
1273
+ }
1274
+
1275
+ // find the argmax value in the block
1276
+ float max_val = simd_max(lmax); // maximum over the threads of this simdgroup
1277
+ int32_t arg_val = simd_max(select(-1, larg, lmax == max_val)); // threads not holding the max contribute -1; if several threads tie, the largest index wins
1278
+
1279
+ if (ntg > N_SIMDWIDTH) { // NOTE(review): N_SIMDWIDTH is defined outside this view; presumably the simdgroup width (host comment says 32), i.e. this branch handles more than one simdgroup - confirm
1280
+ if (sgitg == 0) {
1281
+ shared_maxval[tiisg] = -INFINITY; // neutral values, so slots no simdgroup writes cannot win the final reduction
1282
+ shared_argmax[tiisg] = -1;
1283
+ }
1284
+ // barrier: the neutral initialisation must land before the per-simdgroup results overwrite their slots
1285
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1286
+
1287
+ if (tiisg == 0) { // lane 0 of each simdgroup publishes that simdgroup's partial result in slot sgitg
1288
+ shared_maxval[sgitg] = max_val;
1289
+ shared_argmax[sgitg] = arg_val;
1290
+ }
1291
+ // barrier: all partial results must be visible before any thread reads them back
1292
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1293
+ // every thread loads slot tiisg, so each simdgroup repeats the same second-stage reduction
1294
+ max_val = shared_maxval[tiisg];
1295
+ arg_val = shared_argmax[tiisg];
1296
+
1297
+ float max_val_reduced = simd_max(max_val);
1298
+ int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced)); // same select/simd_max pattern as above, now over the per-simdgroup partials
1299
+
1300
+ dst[tgpig] = arg_val_reduced; // written by every thread of the threadgroup with the same value
1301
+
1302
+ return;
1303
+ }
1304
+ // ntg <= N_SIMDWIDTH: the first-stage reduction result is stored directly
1305
+ dst[tgpig] = arg_val;
1306
+ }
1307
+
1308
  kernel void kernel_norm(
1309
  constant ggml_metal_kargs_norm & args,
1310
  device const char * src0,