metal: SSM_SCAN performance (llama/14743)
* feat: Add s_off as a parameter in the args struct
This may not be necessary, but it more closely mirrors the CUDA kernel
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
* perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state
This is a first attempt at optimizing the metal kernel. The changes here
are:
- Launch the kernel with a thread group of size d_state
- Use simd groups and shared memory to do the summation for the y
computation
When tested with the G4 tiny preview model, this shows roughly a 3x speedup on
prefill and a 15% speedup on decode (an illustrative sketch of this reduction
pattern follows the commit message).
Signed-off-by: Gabe Goodhart <[email protected]>
* fix: Update logic to correctly do the multi-layer parallel sum
Signed-off-by: Gabe Goodhart <[email protected]>
* fix: Correctly size the shared memory buffer and assert expected size relationships
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
* refactor: Compute block offsets once rather than once per token
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
* feat: Use local variable for state recursion
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
* feat: Use a secondary simd_sum instead of a for loop
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
* feat: Add assertion and comment about relationship between simd size and num simd groups
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
* feat: Parallelize over d_state for mamba-1
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
* feat: Parallel sum in SSM_CONV
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
* Revert "feat: Parallel sum in SSM_CONV"
After discussion with @compilade, the size of the parallelism here is not
worth the cost in complexity or overhead of the parallel for.
https://github.com/ggml-org/llama.cpp/pull/14743#discussion_r2223395357
This reverts commit 16bc059660c1c59e566628201c0ca2c20c9f4bc3.
Signed-off-by: Gabe Goodhart <[email protected]>
* refactor: Simplify shared memory sizing
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
Co-Authored-By: Georgi Gerganov <[email protected]>
---------
Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
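
The commit message above describes a two-stage reduction: each SIMD group sums its own slice with simd_sum, the per-group partial sums go into shared threadgroup memory, and the partials are then summed to produce y. The following is a minimal CPU-side C++ sketch of that pattern only; the function name two_stage_sum, the fixed 32-lane group width, and the driver values are assumptions for illustration, not code from this commit.

#include <cassert>
#include <cstdio>
#include <vector>

// Two-stage reduction mirroring the kernel's strategy:
// stage 1: each "SIMD group" of 32 lanes reduces its own slice (the simd_sum),
// stage 2: the per-group partial sums (the shared-memory buckets) are reduced.
float two_stage_sum(const std::vector<float> & vals) {
    const int simd_width = 32;                      // lanes per SIMD group on Apple GPUs
    const int n          = (int) vals.size();       // plays the role of d_state
    const int n_groups   = (n + simd_width - 1) / simd_width;

    std::vector<float> shared(n_groups, 0.0f);      // one bucket per SIMD group

    // Stage 1: per-group partial sums (what simd_sum produces within each SIMD group).
    for (int g = 0; g < n_groups; ++g) {
        float partial = 0.0f;
        for (int lane = 0; lane < simd_width; ++lane) {
            const int i = g*simd_width + lane;
            if (i < n) {
                partial += vals[i];
            }
        }
        shared[g] = partial;
    }

    // Stage 2: sum the buckets (the secondary simd_sum done by SIMD group 0 in the kernel).
    float total = 0.0f;
    for (int g = 0; g < n_groups; ++g) {
        total += shared[g];
    }
    return total;
}

int main() {
    std::vector<float> state_times_C(128, 0.5f);    // stand-in for state*C over d_state = 128
    const float y = two_stage_sum(state_times_C);
    printf("y = %f\n", y);
    assert(y == 64.0f);                             // exact for these power-of-two values
    return 0;
}

On the GPU the two loops collapse into two simd_sum calls separated by a threadgroup barrier, which is the structure visible in the kernel diff below.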
@@ -528,6 +528,7 @@ typedef struct {
     int64_t  n_group;
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
+    int64_t  s_off;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
@@ -3141,6 +3141,7 @@ static int ggml_metal_encode_node(
                 /*.n_group      =*/ n_group,
                 /*.n_seq_tokens =*/ n_seq_tokens,
                 /*.n_seqs       =*/ n_seqs,
+                /*.s_off        =*/ ggml_nelements(src1) * sizeof(float),
                 /*.nb01         =*/ nb01,
                 /*.nb02         =*/ nb02,
                 /*.nb03         =*/ nb03,
@@ -3169,12 +3170,22 @@ static int ggml_metal_encode_node(
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
                 [encoder setBytes:&args length:sizeof(args) atIndex:8];
 
+                // One shared memory bucket for each simd group in the threadgroup
+                // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
+                // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                if (d_state >= 32) {
+                    GGML_ASSERT((int64_t)(d_state / 32) <= 32);
+                    const int64_t shmem_size = 32;
+                    GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
+                    [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+                }
+
                 if (ne30 == 1) {
                     // Mamba-2
-                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 } else {
                     GGML_ASSERT(d_inner == 1);
-                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 }
             } break;
         case GGML_OP_RWKV_WKV6:
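
As a rough cross-check of the shared-memory sizing in the encoder hunk above (one float bucket per SIMD group, padded to a 16-byte multiple), here is a small C++ sketch of the arithmetic; the helper name threadgroup_mem_bytes and the fixed 32-lane SIMD width are assumptions for illustration, not part of the commit.

#include <cstdint>
#include <cstdio>

// Illustrative arithmetic only: one float bucket per SIMD group, rounded up so
// the threadgroup memory length is a multiple of 16 bytes (Metal requirement).
uint64_t threadgroup_mem_bytes(uint64_t d_state, uint64_t simd_width = 32) {
    const uint64_t n_simd_groups = (d_state + simd_width - 1) / simd_width;
    const uint64_t raw_bytes     = n_simd_groups * sizeof(float);
    return (raw_bytes + 15) / 16 * 16;  // round up to a 16-byte multiple
}

int main() {
    // d_state = 128 -> 4 buckets -> 16 bytes; the encoder change instead reserves a
    // fixed 32 floats (128 bytes), which is enough for any d_state <= 32*32 = 1024,
    // matching the GGML_ASSERT above.
    printf("%llu\n", (unsigned long long) threadgroup_mem_bytes(128));
    return 0;
}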
@@ -1823,10 +1823,16 @@ kernel void kernel_ssm_scan_f32(
         device const void * src5,
         device const void * src6,
         device      float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3 tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = 0;
     const int64_t ir = tgpig.x; // current head
     const int64_t i3 = tgpig.y; // current seq
@@ -1841,41 +1847,88 @@ kernel void kernel_ssm_scan_f32(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
 
+    const int64_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31);
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
 
         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;
 
+        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each SIMD
+        // group to compute the sum of each SIMD group, then place the result in
+        // the SIMD group's indexed bucket in the shared memory. We then sum
+        // over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        if (sgptg > 1) {
+
+            // Once per simd group, place the group sum into the shared buffer
+            if (tiisg == 0) {
+                shared[sgitg] = sumf;
+            }
+
+            // Wait for all threads in the threadgroup to reach this point. This
+            // ensures that all elements of the shared buffer are populated with the
+            // sum of the individual simd groups.
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // For simd group 0 at indices < num simd groups, extract the shared
+            // simd sum
+            sumf = 0.0f;
+            if (sgitg == 0) {
+                if (tiisg < sgptg) {
+                    sumf = shared[tiisg];
+                }
+                sumf = simd_sum(sumf);
+                if (tiisg == 0) {
+                    y[0] = sumf;
+                }
+            }
+        } else if (tiisg == 0) {
+            y[0] = sumf;
+        }
 
         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-// TODO: optimize (e.g. by parallelizing over d_state)
 kernel void kernel_ssm_scan_f32_group(
         device const void * src0,
         device const void * src1,
@@ -1885,10 +1938,16 @@ kernel void kernel_ssm_scan_f32_group(
         device const void * src5,
         device const void * src6,
         device      float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3 tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = tgpig.x;
     const int64_t ir = tgpig.y; // current head
     const int64_t i3 = tgpig.z; // current seq
@@ -1903,38 +1962,81 @@ kernel void kernel_ssm_scan_f32_group(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
 
+    const int64_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
 
         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
         const float dA = exp(dt_soft_plus * A[0]);
-        float sumf = 0.0f;
 
+        const float state = (s0 * dA) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads which
+        // are subdivided into SIMD groups of size `sgptg`. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each SIMD
+        // group to compute the sum of each SIMD group, then place the result in
+        // the SIMD group's indexed bucket in the shared memory. We then sum
+        // over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        // Once per simd group, place the group sum into the shared buffer
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
        }
 
+        // Wait for all threads in the threadgroup to reach this point. This
+        // ensures that all elements of the shared buffer are populated with the
+        // sum of the individual simd groups.
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // For simd group 0 at indices < num simd groups, extract the shared
+        // simd sum
+        sumf = 0.0f;
+        if (sgitg == 0) {
+            if (tiisg < sgptg) {
+                sumf = shared[tiisg];
+            }
+            sumf = simd_sum(sumf);
+            if (tiisg == 0) {
+                y[0] = sumf;
+            }
+        }
 
         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }
 
 kernel void kernel_rwkv_wkv6_f32(
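
For a scalar cross-check of what each (head, token) step of the Mamba-2 branch computes over d_state (the work that the threadgroup now evaluates with one thread per state element), here is a small C++ sketch; the function name ssm_scan_step, the vector types, and the driver values are assumptions for illustration, not part of the kernel.

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar sketch of one (head, token) step of the Mamba-2 SSM scan:
// softplus(dt) with an overflow guard, state update s = s*dA + B*x_dt,
// and y = sum(s*C) over d_state. In the Metal kernel each index i below
// is handled by one thread of the threadgroup, and the sum is the
// simd_sum / shared-memory reduction shown above.
float ssm_scan_step(std::vector<float> & s, const std::vector<float> & B,
                    const std::vector<float> & C, float A, float x, float dt) {
    const float dt_soft_plus = dt <= 20.0f ? std::log(1.0f + std::exp(dt)) : dt;
    const float x_dt = x * dt_soft_plus;
    const float dA   = std::exp(dt_soft_plus * A);

    float y = 0.0f;
    for (size_t i = 0; i < s.size(); ++i) {
        s[i] = s[i]*dA + B[i]*x_dt;  // recurrent state update
        y   += s[i]*C[i];            // output contribution, summed over d_state
    }
    return y;
}

int main() {
    const int d_state = 128;
    std::vector<float> s(d_state, 0.0f), B(d_state, 0.01f), C(d_state, 1.0f);
    const float y = ssm_scan_step(s, B, C, /*A=*/-1.0f, /*x=*/0.5f, /*dt=*/0.1f);
    printf("y = %f\n", y);
    return 0;
}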