ggerganov committed on
Commit
459dd87
·
1 Parent(s): 8bc6274

metal : add kernel_get_rows_i32

Browse files
Files changed (2) hide show
  1. ggml-metal.m +4 -0
  2. ggml-metal.metal +29 -0
ggml-metal.m CHANGED
@@ -87,6 +87,7 @@ struct ggml_metal_context {
87
  GGML_METAL_DECL_KERNEL(get_rows_q4_K);
88
  GGML_METAL_DECL_KERNEL(get_rows_q5_K);
89
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
 
90
  GGML_METAL_DECL_KERNEL(rms_norm);
91
  GGML_METAL_DECL_KERNEL(group_norm);
92
  GGML_METAL_DECL_KERNEL(norm);
@@ -377,6 +378,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
377
  GGML_METAL_ADD_KERNEL(get_rows_q4_K);
378
  GGML_METAL_ADD_KERNEL(get_rows_q5_K);
379
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
 
380
  GGML_METAL_ADD_KERNEL(rms_norm);
381
  GGML_METAL_ADD_KERNEL(group_norm);
382
  GGML_METAL_ADD_KERNEL(norm);
@@ -499,6 +501,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
499
  GGML_METAL_DEL_KERNEL(get_rows_q4_K);
500
  GGML_METAL_DEL_KERNEL(get_rows_q5_K);
501
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
 
502
  GGML_METAL_DEL_KERNEL(rms_norm);
503
  GGML_METAL_DEL_KERNEL(group_norm);
504
  GGML_METAL_DEL_KERNEL(norm);
@@ -1978,6 +1981,7 @@ void ggml_metal_graph_compute(
1978
  case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
1979
  case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
1980
  case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
 
1981
  default: GGML_ASSERT(false && "not implemented");
1982
  }
1983
 
 
87
  GGML_METAL_DECL_KERNEL(get_rows_q4_K);
88
  GGML_METAL_DECL_KERNEL(get_rows_q5_K);
89
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
90
+ GGML_METAL_DECL_KERNEL(get_rows_i32);
91
  GGML_METAL_DECL_KERNEL(rms_norm);
92
  GGML_METAL_DECL_KERNEL(group_norm);
93
  GGML_METAL_DECL_KERNEL(norm);
 
378
  GGML_METAL_ADD_KERNEL(get_rows_q4_K);
379
  GGML_METAL_ADD_KERNEL(get_rows_q5_K);
380
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
381
+ GGML_METAL_ADD_KERNEL(get_rows_i32);
382
  GGML_METAL_ADD_KERNEL(rms_norm);
383
  GGML_METAL_ADD_KERNEL(group_norm);
384
  GGML_METAL_ADD_KERNEL(norm);
 
501
  GGML_METAL_DEL_KERNEL(get_rows_q4_K);
502
  GGML_METAL_DEL_KERNEL(get_rows_q5_K);
503
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
504
+ GGML_METAL_DEL_KERNEL(get_rows_i32);
505
  GGML_METAL_DEL_KERNEL(rms_norm);
506
  GGML_METAL_DEL_KERNEL(group_norm);
507
  GGML_METAL_DEL_KERNEL(norm);
 
1981
  case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
1982
  case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
1983
  case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
1984
+ case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
1985
  default: GGML_ASSERT(false && "not implemented");
1986
  }
1987
 
ggml-metal.metal CHANGED
@@ -3829,6 +3829,35 @@ kernel void kernel_get_rows_f16(
3829
  }
3830
  }
3831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3832
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
3833
  #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
3834
  #define BLOCK_SIZE_K 32
 
3829
  }
3830
  }
3831
 
3832
+ kernel void kernel_get_rows_i32( // copies int32 rows of src0, picked by the indices stored in src1, into dst
3833
+ device const void * src0, // base of the rows being read (byte offsets built from nb01, nb02)
3834
+ device const char * src1, // base of the int32 row indices (byte offsets built from nb10, nb11)
3835
+ device int32_t * dst, // base of the output (byte offsets built from nb1, nb2)
3836
+ constant int64_t & ne00, // number of int32 elements copied per row
3837
+ constant uint64_t & nb01, // byte stride multiplied by the row index r
3838
+ constant uint64_t & nb02, // byte stride multiplied by i02
3839
+ constant int64_t & ne10, // not referenced in the kernel body
3840
+ constant uint64_t & nb10, // byte stride in src1 per tgpig.x
3841
+ constant uint64_t & nb11, // byte stride in src1 per tgpig.y
3842
+ constant uint64_t & nb1, // byte stride in dst per tgpig.x
3843
+ constant uint64_t & nb2, // byte stride in dst per tgpig.y
3844
+ uint3 tgpig[[threadgroup_position_in_grid]], // .x becomes i10, .y becomes i11
3845
+ uint tiitg[[thread_index_in_threadgroup]], // starting element for this thread
3846
+ uint3 tptg [[threads_per_threadgroup]]) { // .x is the per-thread step through the row
3847
+ const int64_t i10 = tgpig.x;
3848
+ const int64_t i11 = tgpig.y;
3849
+
3850
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; // row index read from src1 for this threadgroup
3851
+
3852
+ const int64_t i02 = i11; // the nb02 offset into src0 follows tgpig.y
3853
+
3854
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) { // each thread copies elements tiitg, tiitg + tptg.x, ... below ne00
3855
+ ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3856
+ ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3857
+ }
3858
+ }
3859
+
3860
+
3861
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
3862
  #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
3863
  #define BLOCK_SIZE_K 32