ggerganov commited on
Commit
26aec77
·
unverified ·
1 Parent(s): f17a416

metal : add im2col F32 dst support (llama/5132)

Browse files
Files changed (2) hide show
  1. ggml-metal.m +10 -3
  2. ggml-metal.metal +29 -4
ggml-metal.m CHANGED
@@ -135,6 +135,7 @@ enum ggml_metal_kernel_type {
135
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
136
  GGML_METAL_KERNEL_TYPE_ALIBI_F32,
137
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
 
138
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
139
  GGML_METAL_KERNEL_TYPE_PAD_F32,
140
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -506,6 +507,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
506
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
507
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
508
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
 
509
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
510
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
511
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@@ -630,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
630
  case GGML_OP_ALIBI:
631
  case GGML_OP_ROPE:
632
  case GGML_OP_IM2COL:
 
 
 
 
633
  case GGML_OP_UPSCALE:
634
  case GGML_OP_PAD:
635
  case GGML_OP_ARGSORT:
@@ -2015,7 +2021,7 @@ static bool ggml_metal_graph_compute(
2015
  {
2016
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
2017
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
2018
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
2019
 
2020
  const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
2021
  const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
@@ -2023,6 +2029,7 @@ static bool ggml_metal_graph_compute(
2023
  const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
2024
  const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
2025
  const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
 
2026
  const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
2027
 
2028
  const int32_t N = src1->ne[is_2D ? 3 : 2];
@@ -2043,8 +2050,8 @@ static bool ggml_metal_graph_compute(
2043
 
2044
  id<MTLComputePipelineState> pipeline = nil;
2045
 
2046
- switch (src0->type) {
2047
- case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
2048
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
2049
  default: GGML_ASSERT(false);
2050
  };
 
135
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
136
  GGML_METAL_KERNEL_TYPE_ALIBI_F32,
137
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
138
+ GGML_METAL_KERNEL_TYPE_IM2COL_F32,
139
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
140
  GGML_METAL_KERNEL_TYPE_PAD_F32,
141
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
 
507
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
508
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
509
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
510
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
511
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
512
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
513
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
 
632
  case GGML_OP_ALIBI:
633
  case GGML_OP_ROPE:
634
  case GGML_OP_IM2COL:
635
+ return true;
636
+ case GGML_OP_POOL_1D:
637
+ case GGML_OP_POOL_2D:
638
+ return false;
639
  case GGML_OP_UPSCALE:
640
  case GGML_OP_PAD:
641
  case GGML_OP_ARGSORT:
 
2021
  {
2022
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
2023
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
2024
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
2025
 
2026
  const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
2027
  const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
 
2029
  const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
2030
  const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
2031
  const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
2032
+
2033
  const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
2034
 
2035
  const int32_t N = src1->ne[is_2D ? 3 : 2];
 
2050
 
2051
  id<MTLComputePipelineState> pipeline = nil;
2052
 
2053
+ switch (dst->type) {
2054
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
2055
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
2056
  default: GGML_ASSERT(false);
2057
  };
ggml-metal.metal CHANGED
@@ -1775,9 +1775,29 @@ kernel void kernel_rope(
1775
  template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1776
  template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1777
 
1778
- kernel void kernel_im2col_f16(
1779
  device const float * x,
1780
- device half * dst,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1781
  constant int32_t & ofs0,
1782
  constant int32_t & ofs1,
1783
  constant int32_t & IW,
@@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16(
1800
  (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
1801
  (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
1802
 
 
 
1803
  if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
1804
- dst[offset_dst] = 0.0f;
1805
  } else {
1806
  const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
1807
- dst[offset_dst] = x[offset_src + iih * IW + iiw];
1808
  }
1809
  }
1810
 
 
 
 
1811
  kernel void kernel_upscale_f32(
1812
  device const char * src0,
1813
  device char * dst,
 
1775
  template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1776
  template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1777
 
1778
+ typedef void (im2col_t)(
1779
  device const float * x,
1780
+ device char * dst,
1781
+ constant int32_t & ofs0,
1782
+ constant int32_t & ofs1,
1783
+ constant int32_t & IW,
1784
+ constant int32_t & IH,
1785
+ constant int32_t & CHW,
1786
+ constant int32_t & s0,
1787
+ constant int32_t & s1,
1788
+ constant int32_t & p0,
1789
+ constant int32_t & p1,
1790
+ constant int32_t & d0,
1791
+ constant int32_t & d1,
1792
+ uint3 tgpig[[threadgroup_position_in_grid]],
1793
+ uint3 tgpg[[threadgroups_per_grid]],
1794
+ uint3 tpitg[[thread_position_in_threadgroup]],
1795
+ uint3 ntg[[threads_per_threadgroup]]);
1796
+
1797
+ template <typename T>
1798
+ kernel void kernel_im2col(
1799
+ device const float * x,
1800
+ device char * dst,
1801
  constant int32_t & ofs0,
1802
  constant int32_t & ofs1,
1803
  constant int32_t & IW,
 
1820
  (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
1821
  (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
1822
 
1823
+ device T * pdst = (device T *) (dst);
1824
+
1825
  if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
1826
+ pdst[offset_dst] = 0.0f;
1827
  } else {
1828
  const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
1829
+ pdst[offset_dst] = x[offset_src + iih * IW + iiw];
1830
  }
1831
  }
1832
 
1833
+ template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
1834
+ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
1835
+
1836
  kernel void kernel_upscale_f32(
1837
  device const char * src0,
1838
  device char * dst,