ggerganov committed
Commit b6e7294 · Parent: e62fd15

ggml : add SSM Metal kernels (llama/8546)

* ggml : add ggml_ssm_conv metal impl

* ggml : add ssm_scan metal impl

ggml-ci
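
For orientation: the two new kernels port the Mamba-style state-space ops to Metal. ssm_conv is a depthwise 1D convolution over the conv state, and ssm_scan is the token-serial selective-scan recurrence h_t = exp(softplus(dt)*A) * h_{t-1} + softplus(dt)*x_t*B_t with output y_t = dot(h_t, C_t). A minimal CPU sketch of one scan step, mirroring kernel_ssm_scan_f32 below (illustrative names and layout, not the ggml API):

    // Minimal CPU sketch of one token step of the selective scan that
    // kernel_ssm_scan_f32 performs per (inner-dim row, sequence) pair.
    // Illustrative only: names/layout simplified, not part of this commit.
    #include <math.h>

    static float ssm_scan_step(
            float       * s,   // [d_state] recurrent state, updated in place
            const float * A,   // [d_state] state-decay coefficients
            const float * B,   // [d_state] input projection for this token
            const float * C,   // [d_state] output projection for this token
            float x, float dt, int d_state) {
        // softplus(dt); passed through unchanged for large dt to avoid overflow
        const float dt_sp = dt <= 20.0f ? logf(1.0f + expf(dt)) : dt;
        const float x_dt  = x * dt_sp;

        float y = 0.0f;
        for (int i = 0; i < d_state; ++i) {
            const float state = s[i]*expf(dt_sp*A[i]) + B[i]*x_dt; // h_t = exp(dt*A).h_{t-1} + dt*B*x_t
            y   += state * C[i];                                   // y_t = dot(h_t, C_t)
            s[i] = state;                                          // carry state to next token
        }
        return y;
    }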

Files changed (3)
  1. ggml/src/ggml-metal.m     +122 -0
  2. ggml/src/ggml-metal.metal +121 -0
  3. ggml/src/ggml.c           +2 -2
ggml/src/ggml-metal.m CHANGED
@@ -84,6 +84,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
+    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
+    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
@@ -549,6 +551,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
@@ -818,6 +822,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
                 return false;
             }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+        case GGML_OP_SSM_CONV:
+        case GGML_OP_SSM_SCAN:
+            return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&
@@ -1598,6 +1605,121 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 }
             } break;
+        case GGML_OP_SSM_CONV:
+            {
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:15];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:16];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:17];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:18];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+                struct ggml_tensor * src4 = gf->nodes[i]->src[4];
+                struct ggml_tensor * src5 = gf->nodes[i]->src[5];
+
+                GGML_ASSERT(src3);
+                GGML_ASSERT(src4);
+                GGML_ASSERT(src5);
+
+                size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
+                size_t offs_src5 = 0;
+
+                id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
+                id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+
+                const int64_t  ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+                const int64_t  ne31 = src3->ne[1]; GGML_UNUSED(ne31);
+
+                const uint64_t nb30 = src3->nb[0];
+                const uint64_t nb31 = src3->nb[1];
+
+                const int64_t  ne40 = src4->ne[0]; GGML_UNUSED(ne40);
+                const int64_t  ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+                const int64_t  ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+
+                const uint64_t nb40 = src4->nb[0];
+                const uint64_t nb41 = src4->nb[1];
+                const uint64_t nb42 = src4->nb[2];
+
+                const int64_t  ne50 = src5->ne[0]; GGML_UNUSED(ne50);
+                const int64_t  ne51 = src5->ne[1]; GGML_UNUSED(ne51);
+                const int64_t  ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+
+                const uint64_t nb50 = src5->nb[0];
+                const uint64_t nb51 = src5->nb[1];
+                const uint64_t nb52 = src5->nb[2];
+
+                const int64_t d_state      = ne00;
+                const int64_t d_inner      = ne01;
+                const int64_t n_seq_tokens = ne11;
+                const int64_t n_seqs       = ne02;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
+
+                [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7];
+                [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8];
+                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
+                [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10];
+
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
+                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
+                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
+                [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
+                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
+                [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
+                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
+                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
+                [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
+                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
+                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 GGML_ASSERT(ne00 == ne10);
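
A note on the dispatch shapes above: both ops are launched with threadsPerThreadgroup of (1, 1, 1), i.e. one thread per threadgroup, consistent with the "TODO: optimize" on the kernels. ssm_conv is dispatched over the full (rows x tokens x sequences) grid (ne01, ne1, ne02), since every output element is independent, while ssm_scan is dispatched only over (d_inner, n_seqs) and iterates over tokens inside the kernel, because the state update is a serial recurrence in time.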
ggml/src/ggml-metal.metal CHANGED
@@ -747,6 +747,127 @@ kernel void kernel_diag_mask_inf_8(
     }
 }
 
+// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
+// TODO: optimize
+kernel void kernel_ssm_conv_f32(
+        device const void * src0,
+        device const void * src1,
+        device       float * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i2 = tgpig.y;
+    const int64_t i3 = tgpig.z;
+
+    const int64_t nc  = ne10;
+    const int64_t ncs = ne00;
+    const int64_t nr  = ne01;
+    const int64_t n_t = ne1;
+    const int64_t n_s = ne2;
+
+    device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
+    device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
+    device       float * x = (device       float *) ((device       char *) dst  + ir*nb0  + i2*nb1  + i3*nb2);
+
+    float sumf = 0.0f;
+
+    for (int64_t i0 = 0; i0 < nc; ++i0) {
+        sumf += s[i0] * c[i0];
+    }
+
+    x[0] = sumf;
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// TODO: optimize
+kernel void kernel_ssm_scan_f32(
+        device const void * src0,
+        device const void * src1,
+        device const void * src2,
+        device const void * src3,
+        device const void * src4,
+        device const void * src5,
+        device       float * dst,
+        constant  int64_t & d_state,
+        constant  int64_t & d_inner,
+        constant  int64_t & n_seq_tokens,
+        constant  int64_t & n_seqs,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant uint64_t & nb20,
+        constant uint64_t & nb21,
+        constant uint64_t & nb22,
+        constant uint64_t & nb30,
+        constant uint64_t & nb31,
+        constant uint64_t & nb40,
+        constant uint64_t & nb41,
+        constant uint64_t & nb42,
+        constant uint64_t & nb50,
+        constant uint64_t & nb51,
+        constant uint64_t & nb52,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i3 = tgpig.y;
+
+    const int64_t nc  = d_state;
+    const int64_t nr  = d_inner;
+    const int64_t n_t = n_seq_tokens;
+    const int64_t n_s = n_seqs;
+
+    for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
+        device const float * x  = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31);
+        device const float * B  = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
+        device const float * C  = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
+        device       float * y  = (device       float *) ((device       char *) dst  + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
+        device       float * s  = (device       float *) ((device       char *) dst  + ir*nb01 + i3*nb02 + nb13);
+
+        if (i2 > 0) {
+            s0 = s;
+        }
+
+        // i1 == 0
+        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        float x_dt = x[0] * dt_soft_plus;
+        float sumf = 0.0f;
+
+        for (int64_t i0 = 0; i0 < nc; ++i0) {
+            int64_t i = i0;
+            float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            sumf += state * C[i0];
+            s[i] = state;
+        }
+
+        y[0] = sumf;
+    }
+}
+
 kernel void kernel_norm(
         device const void * src0,
         device       float * dst,
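
For reference, kernel_ssm_conv_f32 above reduces each output element to a dot product between an nc-wide window of the padded conv state (shifted by one element per token via i2*nb00) and a per-row filter. A minimal CPU equivalent for one row, assuming the padded layout used by the CPU reference (ncs = nc + n_t - 1 columns); illustrative names, not part of this commit:

    #include <stdint.h>

    // CPU sketch of kernel_ssm_conv_f32 for a single row: token i2 dots the
    // nc-wide window starting at column i2 of the padded state row with the
    // row's filter.
    static void ssm_conv_row(
            const float * s_row, // [ncs] padded conv state for this row
            const float * c_row, // [nc]  conv filter for this row
            float       * x_row, // [n_t] output, one value per token
            int64_t nc, int64_t n_t) {
        for (int64_t i2 = 0; i2 < n_t; ++i2) {
            float sumf = 0.0f;
            for (int64_t i0 = 0; i0 < nc; ++i0) {
                sumf += s_row[i2 + i0] * c_row[i0]; // the i2*nb00 shift in the kernel
            }
            x_row[i2] = sumf;
        }
    }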
ggml/src/ggml.c CHANGED
@@ -16372,8 +16372,8 @@ static void ggml_compute_forward_ssm_scan_f32(
             const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
             const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
             const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
 
             // use the output as the source for the next token-wise iterations
             if (i2 > 0) { s0 = s; }