Spaces:
Running
Running
metal : add kernel_get_rows_i32
Browse files- ggml-metal.m +4 -0
- ggml-metal.metal +29 -0
ggml-metal.m
CHANGED
|
@@ -87,6 +87,7 @@ struct ggml_metal_context {
|
|
| 87 |
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
| 88 |
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
| 89 |
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
|
|
|
| 90 |
GGML_METAL_DECL_KERNEL(rms_norm);
|
| 91 |
GGML_METAL_DECL_KERNEL(group_norm);
|
| 92 |
GGML_METAL_DECL_KERNEL(norm);
|
|
@@ -377,6 +378,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 377 |
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
|
| 378 |
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
|
| 379 |
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
|
|
|
| 380 |
GGML_METAL_ADD_KERNEL(rms_norm);
|
| 381 |
GGML_METAL_ADD_KERNEL(group_norm);
|
| 382 |
GGML_METAL_ADD_KERNEL(norm);
|
|
@@ -499,6 +501,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
| 499 |
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
|
| 500 |
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
| 501 |
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
|
|
|
| 502 |
GGML_METAL_DEL_KERNEL(rms_norm);
|
| 503 |
GGML_METAL_DEL_KERNEL(group_norm);
|
| 504 |
GGML_METAL_DEL_KERNEL(norm);
|
|
@@ -1978,6 +1981,7 @@ void ggml_metal_graph_compute(
|
|
| 1978 |
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
| 1979 |
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
|
| 1980 |
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
|
|
|
|
| 1981 |
default: GGML_ASSERT(false && "not implemented");
|
| 1982 |
}
|
| 1983 |
|
|
|
|
| 87 |
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
| 88 |
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
| 89 |
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
| 90 |
+
GGML_METAL_DECL_KERNEL(get_rows_i32);
|
| 91 |
GGML_METAL_DECL_KERNEL(rms_norm);
|
| 92 |
GGML_METAL_DECL_KERNEL(group_norm);
|
| 93 |
GGML_METAL_DECL_KERNEL(norm);
|
|
|
|
| 378 |
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
|
| 379 |
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
|
| 380 |
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
| 381 |
+
GGML_METAL_ADD_KERNEL(get_rows_i32);
|
| 382 |
GGML_METAL_ADD_KERNEL(rms_norm);
|
| 383 |
GGML_METAL_ADD_KERNEL(group_norm);
|
| 384 |
GGML_METAL_ADD_KERNEL(norm);
|
|
|
|
| 501 |
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
|
| 502 |
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
| 503 |
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
| 504 |
+
GGML_METAL_DEL_KERNEL(get_rows_i32);
|
| 505 |
GGML_METAL_DEL_KERNEL(rms_norm);
|
| 506 |
GGML_METAL_DEL_KERNEL(group_norm);
|
| 507 |
GGML_METAL_DEL_KERNEL(norm);
|
|
|
|
| 1981 |
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
| 1982 |
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
|
| 1983 |
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
|
| 1984 |
+
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
|
| 1985 |
default: GGML_ASSERT(false && "not implemented");
|
| 1986 |
}
|
| 1987 |
|
ggml-metal.metal
CHANGED
|
@@ -3829,6 +3829,35 @@ kernel void kernel_get_rows_f16(
|
|
| 3829 |
}
|
| 3830 |
}
|
| 3831 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3832 |
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
| 3833 |
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
| 3834 |
#define BLOCK_SIZE_K 32
|
|
|
|
| 3829 |
}
|
| 3830 |
}
|
| 3831 |
|
| 3832 |
+
kernel void kernel_get_rows_i32(
|
| 3833 |
+
device const void * src0,
|
| 3834 |
+
device const char * src1,
|
| 3835 |
+
device int32_t * dst,
|
| 3836 |
+
constant int64_t & ne00,
|
| 3837 |
+
constant uint64_t & nb01,
|
| 3838 |
+
constant uint64_t & nb02,
|
| 3839 |
+
constant int64_t & ne10,
|
| 3840 |
+
constant uint64_t & nb10,
|
| 3841 |
+
constant uint64_t & nb11,
|
| 3842 |
+
constant uint64_t & nb1,
|
| 3843 |
+
constant uint64_t & nb2,
|
| 3844 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 3845 |
+
uint tiitg[[thread_index_in_threadgroup]],
|
| 3846 |
+
uint3 tptg [[threads_per_threadgroup]]) {
|
| 3847 |
+
const int64_t i10 = tgpig.x;
|
| 3848 |
+
const int64_t i11 = tgpig.y;
|
| 3849 |
+
|
| 3850 |
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
| 3851 |
+
|
| 3852 |
+
const int64_t i02 = i11;
|
| 3853 |
+
|
| 3854 |
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
| 3855 |
+
((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
| 3856 |
+
((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
| 3857 |
+
}
|
| 3858 |
+
}
|
| 3859 |
+
|
| 3860 |
+
|
| 3861 |
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
| 3862 |
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
| 3863 |
#define BLOCK_SIZE_K 32
|