gabegoodhart, ggerganov committed
Commit 5359e09 · 1 Parent(s): 5629961

metal: SSM_SCAN performance (llama/14743)


* feat: Add s_off as a parameter in the args struct

This may not be necessary, but it more closely mirrors the CUDA kernel.

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>

* perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state

This is a first attempt at optimizing the metal kernel. The changes here
are:

- Launch the kernel with a thread group of size d_state
- Use simd groups and shared memory to do the summation for the y
computation

When tested with the G4 tiny preview, this shows roughly a 3x speedup on
prefill and a 15% speedup on decode.

Signed-off-by: Gabe Goodhart <[email protected]>
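
The summation described above is a standard two-level reduction: each of the d_state threads computes one state * C[i] product, simd_sum reduces each 32-wide SIMD group, and the per-group partial sums are then combined into the final y. A minimal CPU sketch of the same arithmetic (plain C++ with illustrative names and sizes, not the Metal kernel itself):

    #include <cstdio>
    #include <vector>

    // CPU sketch of the reduction the kernel performs per token:
    // y = sum_i(state[i] * C[i]), computed as one partial sum per 32-wide
    // "SIMD group" which are then combined, mirroring simd_sum + shared memory.
    // Names and sizes are illustrative, not taken from the llama.cpp sources.
    float reduce_like_simd_groups(const std::vector<float> & state,
                                  const std::vector<float> & C) {
        const size_t simd_width = 32;   // Metal SIMD group size
        const size_t n = state.size();  // plays the role of d_state
        std::vector<float> partials((n + simd_width - 1) / simd_width, 0.0f);

        // first level: one partial sum per SIMD group (simd_sum in the kernel)
        for (size_t i = 0; i < n; ++i) {
            partials[i / simd_width] += state[i] * C[i];
        }

        // second level: combine the per-group partials (the kernel stores them
        // in threadgroup memory and reduces them with a second simd_sum)
        float y = 0.0f;
        for (float p : partials) {
            y += p;
        }
        return y;
    }

    int main() {
        std::vector<float> state(128, 0.5f), C(128, 2.0f);
        std::printf("y = %f\n", reduce_like_simd_groups(state, C)); // expect 128.0
        return 0;
    }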

* fix: Update logic to correctly do the multi-layer parallel sum

Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Correctly size the shared memory buffer and assert expected size relationships

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>

* refactor: Compute block offsets once rather than once per token

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>
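
The refactor above hoists the pointer arithmetic that does not depend on the token index out of the per-token loop; inside the loop only the token stride is applied. A rough C++ illustration of the pattern, with made-up strides and names:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Illustration of hoisting loop-invariant offset math out of the per-token
    // loop, as the refactor does for the x/dt/B/C/y pointers. The buffer layout
    // and byte strides below are made up for the example.
    int main() {
        const int64_t n_t  = 4;    // tokens
        const int64_t nb11 = 256;  // per-head byte stride (illustrative)
        const int64_t nb12 = 64;   // per-token byte stride (illustrative)
        const int64_t ir   = 1;    // current head
        std::vector<float> buf(1024, 1.0f);

        // computed once per threadgroup instead of once per token
        const char * x_block = reinterpret_cast<const char *>(buf.data()) + ir*nb11;

        for (int64_t i2 = 0; i2 < n_t; ++i2) {
            // inside the loop only the token stride is applied
            const float * x = reinterpret_cast<const float *>(x_block + i2*nb12);
            std::printf("token %lld: x[0] = %f\n", (long long) i2, x[0]);
        }
        return 0;
    }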

* feat: Use local variable for state recursion

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Use a secondary simd_sum instead of a for loop

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Add assertion and comment about relationship between simd size and num simd groups

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>
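
The relationship being asserted follows from the SIMD width: one simd_sum covers a 32-lane SIMD group, and the second simd_sum over the shared buckets is also a single 32-lane reduction, so the two levels can cover at most 32 * 32 = 1024 state elements. A small C++ sketch of the same check, with an illustrative d_state:

    #include <cassert>
    #include <cstdint>

    // Why the host asserts (d_state / 32) <= 32: one simd_sum covers a 32-lane
    // SIMD group, and the second simd_sum over the shared buckets is also a
    // single 32-lane reduction, so at most 32 * 32 = 1024 state elements fit.
    // The d_state here is an illustrative value.
    int main() {
        const int64_t simd_width    = 32;                    // Metal SIMD group size
        const int64_t d_state       = 128;                   // example state size
        const int64_t n_simd_groups = d_state / simd_width;  // shared-memory buckets
        assert(n_simd_groups <= simd_width);                 // second simd_sum must cover them all
        return 0;
    }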

* feat: Parallelize over d_state for mamba-1

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Parallel sum in SSM_CONV

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>

* Revert "feat: Parallel sum in SSM_CONV"

After discussion with @compilade, the size of the parallelism here is
not worth the cost in complexity or overhead of the parallel for.

https://github.com/ggml-org/llama.cpp/pull/14743#discussion_r2223395357

This reverts commit 16bc059660c1c59e566628201c0ca2c20c9f4bc3.

Signed-off-by: Gabe Goodhart <[email protected]>

* refactor: Simplify shared memory sizing

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>
Co-Authored-By: Georgi Gerganov <[email protected]>

---------

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>

ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -528,6 +528,7 @@ typedef struct {
     int64_t n_group;
     int64_t n_seq_tokens;
     int64_t n_seqs;
+    int64_t s_off;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
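
The new s_off field carries the byte offset into dst at which the updated states are stored, after the y outputs. Previously the kernels derived the same value as nr * nh * n_t * n_seqs * sizeof(float); the host now passes ggml_nelements(src1) * sizeof(float), which is equal because src1 (x) has shape {d_inner, n_head, n_seq_tokens, n_seqs}. A small C++ sketch of that equivalence, with illustrative sizes:

    #include <cstdint>
    #include <cstdio>

    // Sketch of how the host-side s_off relates to the value the kernels used
    // to compute themselves. dst holds the y outputs first and the updated
    // states start at byte offset s_off. The sizes below are illustrative.
    int main() {
        const int64_t d_inner = 64, n_head = 32, n_seq_tokens = 8, n_seqs = 2;

        // old in-kernel form: nr * nh * n_t * n_seqs * sizeof(float)
        const int64_t s_off_kernel = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float);

        // new host-side form: ggml_nelements(src1) * sizeof(float), where src1 (x)
        // has shape {d_inner, n_head, n_seq_tokens, n_seqs}
        const int64_t nelements_src1 = d_inner * n_head * n_seq_tokens * n_seqs;
        const int64_t s_off_host     = nelements_src1 * sizeof(float);

        std::printf("%lld == %lld\n", (long long) s_off_kernel, (long long) s_off_host);
        return 0;
    }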
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -3141,6 +3141,7 @@ static int ggml_metal_encode_node(
                 /*.n_group      =*/ n_group,
                 /*.n_seq_tokens =*/ n_seq_tokens,
                 /*.n_seqs       =*/ n_seqs,
+                /*.s_off        =*/ ggml_nelements(src1) * sizeof(float),
                 /*.nb01         =*/ nb01,
                 /*.nb02         =*/ nb02,
                 /*.nb03         =*/ nb03,
@@ -3169,12 +3170,22 @@ static int ggml_metal_encode_node(
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
                 [encoder setBytes:&args length:sizeof(args) atIndex:8];

+                // One shared memory bucket for each simd group in the threadgroup
+                // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
+                // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                if (d_state >= 32) {
+                    GGML_ASSERT((int64_t)(d_state / 32) <= 32);
+                    const int64_t shmem_size = 32;
+                    GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
+                    [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+                }
+
                 if (ne30 == 1) {
                     // Mamba-2
-                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 } else {
                     GGML_ASSERT(d_inner == 1);
-                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 }
             } break;
         case GGML_OP_RWKV_WKV6:
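
The sizing logic above reserves one float bucket per SIMD group. Since d_state/32 can be at most 32 buckets, and setThreadgroupMemoryLength requires a length that is a multiple of 16 bytes, a fixed 32-float (128-byte) buffer is used rather than the exact bucket count. A C++ sketch of the same arithmetic, with an illustrative d_state:

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Sketch of the host-side threadgroup-memory sizing: one float bucket per
    // SIMD group would be d_state/32 floats, but a fixed 32-float buffer keeps
    // the byte length a multiple of 16 as Metal requires. Values illustrative.
    int main() {
        const int64_t d_state = 128;                        // example state size
        if (d_state >= 32) {
            assert(d_state / 32 <= 32);                     // at most 32 shared buckets
            const int64_t shmem_size  = 32;                 // floats, not d_state/32
            const int64_t shmem_bytes = shmem_size * sizeof(float);
            assert(shmem_bytes % 16 == 0);                  // Metal alignment requirement
            std::printf("threadgroup memory: %lld bytes\n", (long long) shmem_bytes);
        }
        return 0;
    }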
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -1823,10 +1823,16 @@ kernel void kernel_ssm_scan_f32(
         device const void * src5,
         device const void * src6,
         device float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3  tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = 0;
     const int64_t ir = tgpig.x; // current head
     const int64_t i3 = tgpig.y; // current seq
@@ -1841,41 +1847,88 @@
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;

-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;

     device const int32_t * ids = (device const int32_t *) src6;

-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31);
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;

-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
+        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each SIMD
+        // group to compute the sum of each SIMD group, then place the result in
+        // the SIMD group's indexed bucket in the shared memory. We then sum
+        // over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];

-        y[0] = sumf;
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        if (sgptg > 1) {
+
+            // Once per simd group, place the group sum into the shared buffer
+            if (tiisg == 0) {
+                shared[sgitg] = sumf;
+            }
+
+            // Wait for all threads in the threadgroup to reach this point. This
+            // ensures that all elements of the shared buffer are populated with the
+            // sum of the individual simd groups.
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // For simd group 0 at indices < num simd groups, extract the shared
+            // simd sum
+            sumf = 0.0f;
+            if (sgitg == 0) {
+                if (tiisg < sgptg) {
+                    sumf = shared[tiisg];
+                }
+                sumf = simd_sum(sumf);
+                if (tiisg == 0) {
+                    y[0] = sumf;
+                }
+            }
+        } else if (tiisg == 0) {
+            y[0] = sumf;
+        }

         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }

 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-// TODO: optimize (e.g. by parallelizing over d_state)
 kernel void kernel_ssm_scan_f32_group(
         device const void * src0,
         device const void * src1,
@@ -1885,10 +1938,16 @@ kernel void kernel_ssm_scan_f32_group(
         device const void * src5,
         device const void * src6,
         device float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3  tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = tgpig.x;
     const int64_t ir = tgpig.y; // current head
     const int64_t i3 = tgpig.z; // current seq
@@ -1903,38 +1962,81 @@
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;

-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;

     device const int32_t * ids = (device const int32_t *) src6;

-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
         const float dA = exp(dt_soft_plus * A[0]);
-        float sumf = 0.0f;

-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * dA) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
+        const float state = (s0 * dA) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each SIMD
+        // group to compute the sum of each SIMD group, then place the result in
+        // the SIMD group's indexed bucket in the shared memory. We then sum
+        // over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        // Once per simd group, place the group sum into the shared buffer
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
         }

-        y[0] = sumf;
+        // Wait for all threads in the threadgroup to reach this point. This
+        // ensures that all elements of the shared buffer are populated with the
+        // sum of the individual simd groups.
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // For simd group 0 at indices < num simd groups, extract the shared
+        // simd sum
+        sumf = 0.0f;
+        if (sgitg == 0) {
+            if (tiisg < sgptg) {
+                sumf = shared[tiisg];
+            }
+            sumf = simd_sum(sumf);
+            if (tiisg == 0) {
+                y[0] = sumf;
+            }
+        }

         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }

 kernel void kernel_rwkv_wkv6_f32(
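
For reference, the recurrence that both kernels implement per head can be written as a few lines of scalar code: for each token, state[i] = state[i] * dA + B[i] * x_dt, followed by y = sum_i(state[i] * C[i]). The following standalone C++ sketch (illustrative shapes and values, independent of the ggml sources) mirrors the Mamba-2 branch, where dA = exp(softplus(dt) * A):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Independent CPU reference of the per-head recurrence the Mamba-2 kernel
    // parallelizes over d_state: for each token, state[i] = state[i]*dA + B[i]*x_dt
    // and y = sum_i(state[i]*C[i]). Shapes and values are illustrative only.
    int main() {
        const int d_state = 4, n_t = 3;
        std::vector<float> state(d_state, 0.0f);
        const float A = -1.0f;
        const float x[n_t]  = {0.5f, 1.0f, 0.25f};
        const float dt[n_t] = {0.1f, 0.2f, 0.3f};
        const std::vector<float> B(d_state, 1.0f), C(d_state, 0.5f);

        for (int t = 0; t < n_t; ++t) {
            // softplus with the same large-input cutoff as the kernel
            const float dt_sp = dt[t] <= 20.0f ? std::log(1.0f + std::exp(dt[t])) : dt[t];
            const float x_dt  = x[t] * dt_sp;
            const float dA    = std::exp(dt_sp * A);

            float y = 0.0f;
            for (int i = 0; i < d_state; ++i) {
                state[i] = state[i]*dA + B[i]*x_dt;   // state update
                y       += state[i]*C[i];             // output projection
            }
            std::printf("y[%d] = %f\n", t, y);
        }
        return 0;
    }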