Spaces:
Running
Running
metal : add im2col F32 dst support (llama/5132)
Browse files
- ggml-metal.m +10 -3
- ggml-metal.metal +29 -4
ggml-metal.m
CHANGED
|
@@ -135,6 +135,7 @@ enum ggml_metal_kernel_type {
|
|
| 135 |
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
| 136 |
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
| 137 |
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
|
|
|
| 138 |
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
| 139 |
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
| 140 |
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
|
@@ -506,6 +507,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 506 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
| 507 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
| 508 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
|
|
|
| 509 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
| 510 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
| 511 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
|
@@ -630,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
| 630 |
case GGML_OP_ALIBI:
|
| 631 |
case GGML_OP_ROPE:
|
| 632 |
case GGML_OP_IM2COL:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 633 |
case GGML_OP_UPSCALE:
|
| 634 |
case GGML_OP_PAD:
|
| 635 |
case GGML_OP_ARGSORT:
|
|
@@ -2015,7 +2021,7 @@ static bool ggml_metal_graph_compute(
|
|
| 2015 |
{
|
| 2016 |
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
| 2017 |
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 2018 |
-
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
| 2019 |
|
| 2020 |
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
| 2021 |
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
@@ -2023,6 +2029,7 @@ static bool ggml_metal_graph_compute(
|
|
| 2023 |
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
| 2024 |
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
| 2025 |
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
|
|
|
| 2026 |
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
| 2027 |
|
| 2028 |
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
|
@@ -2043,8 +2050,8 @@ static bool ggml_metal_graph_compute(
|
|
| 2043 |
|
| 2044 |
id<MTLComputePipelineState> pipeline = nil;
|
| 2045 |
|
| 2046 |
-
switch (
|
| 2047 |
-
case GGML_TYPE_F32:
|
| 2048 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
| 2049 |
default: GGML_ASSERT(false);
|
| 2050 |
};
|
|
|
|
| 135 |
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
| 136 |
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
| 137 |
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
| 138 |
+
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
| 139 |
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
| 140 |
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
| 141 |
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
|
|
|
| 507 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
| 508 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
| 509 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
| 510 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
| 511 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
| 512 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
| 513 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
|
|
|
| 632 |
case GGML_OP_ALIBI:
|
| 633 |
case GGML_OP_ROPE:
|
| 634 |
case GGML_OP_IM2COL:
|
| 635 |
+
return true;
|
| 636 |
+
case GGML_OP_POOL_1D:
|
| 637 |
+
case GGML_OP_POOL_2D:
|
| 638 |
+
return false;
|
| 639 |
case GGML_OP_UPSCALE:
|
| 640 |
case GGML_OP_PAD:
|
| 641 |
case GGML_OP_ARGSORT:
|
|
|
|
| 2021 |
{
|
| 2022 |
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
| 2023 |
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 2024 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
| 2025 |
|
| 2026 |
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
| 2027 |
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
|
|
| 2029 |
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
| 2030 |
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
| 2031 |
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
| 2032 |
+
|
| 2033 |
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
| 2034 |
|
| 2035 |
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
|
|
|
| 2050 |
|
| 2051 |
id<MTLComputePipelineState> pipeline = nil;
|
| 2052 |
|
| 2053 |
+
switch (dst->type) {
|
| 2054 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
|
| 2055 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
| 2056 |
default: GGML_ASSERT(false);
|
| 2057 |
};
|
ggml-metal.metal
CHANGED
|
@@ -1775,9 +1775,29 @@ kernel void kernel_rope(
|
|
| 1775 |
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
| 1776 |
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
| 1777 |
|
| 1778 |
-
|
| 1779 |
device const float * x,
|
| 1780 |
-
device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1781 |
constant int32_t & ofs0,
|
| 1782 |
constant int32_t & ofs1,
|
| 1783 |
constant int32_t & IW,
|
|
@@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16(
|
|
| 1800 |
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
| 1801 |
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
|
| 1802 |
|
|
|
|
|
|
|
| 1803 |
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
| 1804 |
-
|
| 1805 |
} else {
|
| 1806 |
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
|
| 1807 |
-
|
| 1808 |
}
|
| 1809 |
}
|
| 1810 |
|
|
|
|
|
|
|
|
|
|
| 1811 |
kernel void kernel_upscale_f32(
|
| 1812 |
device const char * src0,
|
| 1813 |
device char * dst,
|
|
|
|
| 1775 |
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
| 1776 |
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
| 1777 |
|
| 1778 |
+
typedef void (im2col_t)(
|
| 1779 |
device const float * x,
|
| 1780 |
+
device char * dst,
|
| 1781 |
+
constant int32_t & ofs0,
|
| 1782 |
+
constant int32_t & ofs1,
|
| 1783 |
+
constant int32_t & IW,
|
| 1784 |
+
constant int32_t & IH,
|
| 1785 |
+
constant int32_t & CHW,
|
| 1786 |
+
constant int32_t & s0,
|
| 1787 |
+
constant int32_t & s1,
|
| 1788 |
+
constant int32_t & p0,
|
| 1789 |
+
constant int32_t & p1,
|
| 1790 |
+
constant int32_t & d0,
|
| 1791 |
+
constant int32_t & d1,
|
| 1792 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1793 |
+
uint3 tgpg[[threadgroups_per_grid]],
|
| 1794 |
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 1795 |
+
uint3 ntg[[threads_per_threadgroup]]);
|
| 1796 |
+
|
| 1797 |
+
template <typename T>
|
| 1798 |
+
kernel void kernel_im2col(
|
| 1799 |
+
device const float * x,
|
| 1800 |
+
device char * dst,
|
| 1801 |
constant int32_t & ofs0,
|
| 1802 |
constant int32_t & ofs1,
|
| 1803 |
constant int32_t & IW,
|
|
|
|
| 1820 |
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
| 1821 |
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
|
| 1822 |
|
| 1823 |
+
device T * pdst = (device T *) (dst);
|
| 1824 |
+
|
| 1825 |
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
| 1826 |
+
pdst[offset_dst] = 0.0f;
|
| 1827 |
} else {
|
| 1828 |
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
|
| 1829 |
+
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
|
| 1830 |
}
|
| 1831 |
}
|
| 1832 |
|
| 1833 |
+
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
| 1834 |
+
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
| 1835 |
+
|
| 1836 |
kernel void kernel_upscale_f32(
|
| 1837 |
device const char * src0,
|
| 1838 |
device char * dst,
|