Spaces:
Sleeping
Sleeping
ggml : add SSM Metal kernels (llama/8546)
Browse files
* ggml : add ggml_ssm_conv metal impl
* ggml : add ssm_scan metal impl
ggml-ci
- ggml/src/ggml-metal.m +122 -0
- ggml/src/ggml-metal.metal +121 -0
- ggml/src/ggml.c +2 -2
ggml/src/ggml-metal.m
CHANGED
|
@@ -84,6 +84,8 @@ enum ggml_metal_kernel_type {
|
|
| 84 |
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
| 85 |
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
| 86 |
GGML_METAL_KERNEL_TYPE_NORM,
|
|
|
|
|
|
|
| 87 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
| 88 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
| 89 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
|
@@ -549,6 +551,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
|
| 549 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
| 550 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
| 551 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
|
|
|
|
|
|
| 552 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
| 553 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
| 554 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
|
@@ -818,6 +822,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|
| 818 |
return false;
|
| 819 |
}
|
| 820 |
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
|
|
|
|
|
|
|
|
|
| 821 |
case GGML_OP_MUL_MAT:
|
| 822 |
case GGML_OP_MUL_MAT_ID:
|
| 823 |
return ctx->support_simdgroup_reduction &&
|
|
@@ -1598,6 +1605,121 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 1598 |
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1599 |
}
|
| 1600 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1601 |
case GGML_OP_MUL_MAT:
|
| 1602 |
{
|
| 1603 |
GGML_ASSERT(ne00 == ne10);
|
|
|
|
| 84 |
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
| 85 |
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
| 86 |
GGML_METAL_KERNEL_TYPE_NORM,
|
| 87 |
+
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
| 88 |
+
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
| 89 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
| 90 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
| 91 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
|
|
|
| 551 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
| 552 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
| 553 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
| 554 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
| 555 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
| 556 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
| 557 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
| 558 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
|
|
|
| 822 |
return false;
|
| 823 |
}
|
| 824 |
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
| 825 |
+
case GGML_OP_SSM_CONV:
|
| 826 |
+
case GGML_OP_SSM_SCAN:
|
| 827 |
+
return true;
|
| 828 |
case GGML_OP_MUL_MAT:
|
| 829 |
case GGML_OP_MUL_MAT_ID:
|
| 830 |
return ctx->support_simdgroup_reduction &&
|
|
|
|
| 1605 |
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1606 |
}
|
| 1607 |
} break;
|
| 1608 |
+
case GGML_OP_SSM_CONV:
|
| 1609 |
+
{
|
| 1610 |
+
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
| 1611 |
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
| 1612 |
+
|
| 1613 |
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
| 1614 |
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
| 1615 |
+
|
| 1616 |
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
|
| 1617 |
+
|
| 1618 |
+
[encoder setComputePipelineState:pipeline];
|
| 1619 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 1620 |
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
| 1621 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 1622 |
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
| 1623 |
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
| 1624 |
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
| 1625 |
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
| 1626 |
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
| 1627 |
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
| 1628 |
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
| 1629 |
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
| 1630 |
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
|
| 1631 |
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
|
| 1632 |
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
| 1633 |
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
| 1634 |
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
|
| 1635 |
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
|
| 1636 |
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
|
| 1637 |
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
|
| 1638 |
+
|
| 1639 |
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1640 |
+
} break;
|
| 1641 |
+
case GGML_OP_SSM_SCAN:
|
| 1642 |
+
{
|
| 1643 |
+
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
| 1644 |
+
struct ggml_tensor * src4 = gf->nodes[i]->src[4];
|
| 1645 |
+
struct ggml_tensor * src5 = gf->nodes[i]->src[5];
|
| 1646 |
+
|
| 1647 |
+
GGML_ASSERT(src3);
|
| 1648 |
+
GGML_ASSERT(src4);
|
| 1649 |
+
GGML_ASSERT(src5);
|
| 1650 |
+
|
| 1651 |
+
size_t offs_src3 = 0;
|
| 1652 |
+
size_t offs_src4 = 0;
|
| 1653 |
+
size_t offs_src5 = 0;
|
| 1654 |
+
|
| 1655 |
+
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
| 1656 |
+
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
| 1657 |
+
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
|
| 1658 |
+
|
| 1659 |
+
const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
|
| 1660 |
+
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
|
| 1661 |
+
|
| 1662 |
+
const uint64_t nb30 = src3->nb[0];
|
| 1663 |
+
const uint64_t nb31 = src3->nb[1];
|
| 1664 |
+
|
| 1665 |
+
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
|
| 1666 |
+
const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
|
| 1667 |
+
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
|
| 1668 |
+
|
| 1669 |
+
const uint64_t nb40 = src4->nb[0];
|
| 1670 |
+
const uint64_t nb41 = src4->nb[1];
|
| 1671 |
+
const uint64_t nb42 = src4->nb[2];
|
| 1672 |
+
|
| 1673 |
+
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
|
| 1674 |
+
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
|
| 1675 |
+
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
|
| 1676 |
+
|
| 1677 |
+
const uint64_t nb50 = src5->nb[0];
|
| 1678 |
+
const uint64_t nb51 = src5->nb[1];
|
| 1679 |
+
const uint64_t nb52 = src5->nb[2];
|
| 1680 |
+
|
| 1681 |
+
const int64_t d_state = ne00;
|
| 1682 |
+
const int64_t d_inner = ne01;
|
| 1683 |
+
const int64_t n_seq_tokens = ne11;
|
| 1684 |
+
const int64_t n_seqs = ne02;
|
| 1685 |
+
|
| 1686 |
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
| 1687 |
+
|
| 1688 |
+
[encoder setComputePipelineState:pipeline];
|
| 1689 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 1690 |
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
| 1691 |
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
| 1692 |
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
| 1693 |
+
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
| 1694 |
+
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
| 1695 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
| 1696 |
+
|
| 1697 |
+
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
|
| 1698 |
+
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
|
| 1699 |
+
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
|
| 1700 |
+
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
|
| 1701 |
+
|
| 1702 |
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
|
| 1703 |
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
|
| 1704 |
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
|
| 1705 |
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
| 1706 |
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
| 1707 |
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
| 1708 |
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
| 1709 |
+
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
|
| 1710 |
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
|
| 1711 |
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
|
| 1712 |
+
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
|
| 1713 |
+
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
|
| 1714 |
+
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
|
| 1715 |
+
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
|
| 1716 |
+
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
|
| 1717 |
+
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
|
| 1718 |
+
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
|
| 1719 |
+
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
|
| 1720 |
+
|
| 1721 |
+
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 1722 |
+
} break;
|
| 1723 |
case GGML_OP_MUL_MAT:
|
| 1724 |
{
|
| 1725 |
GGML_ASSERT(ne00 == ne10);
|
ggml/src/ggml-metal.metal
CHANGED
|
@@ -747,6 +747,127 @@ kernel void kernel_diag_mask_inf_8(
|
|
| 747 |
}
|
| 748 |
}
|
| 749 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
kernel void kernel_norm(
|
| 751 |
device const void * src0,
|
| 752 |
device float * dst,
|
|
|
|
| 747 |
}
|
| 748 |
}
|
| 749 |
|
| 750 |
+
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
|
| 751 |
+
// TODO: optimize
|
| 752 |
+
kernel void kernel_ssm_conv_f32(
        device const  void * src0,
        device const  void * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // Computes a single output element: the dot product of ne10 consecutive
    // floats read from src0 with ne10 consecutive floats read from src1.
    //
    // The threadgroup position selects which element is produced:
    //   tgpig.x -> row   (strides: nb01 in src0, nb11 in src1, nb0 in dst)
    //   tgpig.y -> token (strides: nb00 in src0,               nb1 in dst)
    //   tgpig.z -> seq   (strides: nb02 in src0,               nb2 in dst)
    //
    // Note that the token index advances src0 by nb00 (its innermost stride), so
    // successive tokens read overlapping windows shifted by one element.
    // All strides are in bytes. Several ne*/nb* arguments and tpitg/ntg are not
    // read here; they are kept so the argument indices match the host encoder.
    const int64_t row = tgpig.x;
    const int64_t tok = tgpig.y;
    const int64_t seq = tgpig.z;

    // number of elements in the dot product
    const int64_t n_taps = ne10;

    device const float * window  = (device const float *) ((device const char *) src0 + row*nb01 + tok*nb00 + seq*nb02);
    device const float * weights = (device const float *) ((device const char *) src1 + row*nb11);
    device       float * out     = (device       float *) ((device       char *) dst  + row*nb0  + tok*nb1  + seq*nb2);

    float acc = 0.0f;

    for (int64_t k = 0; k < n_taps; ++k) {
        acc += window[k] * weights[k];
    }

    *out = acc;
}
|
| 797 |
+
|
| 798 |
+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
|
| 799 |
+
// TODO: optimize
|
| 800 |
+
kernel void kernel_ssm_scan_f32(
        device const void * src0,
        device const void * src1,
        device const void * src2,
        device const void * src3,
        device const void * src4,
        device const void * src5,
        device     float * dst,
        constant int64_t & d_state,
        constant int64_t & d_inner,
        constant int64_t & n_seq_tokens,
        constant int64_t & n_seqs,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant uint64_t & nb13,
        constant uint64_t & nb20,
        constant uint64_t & nb21,
        constant uint64_t & nb22,
        constant uint64_t & nb30,
        constant uint64_t & nb31,
        constant uint64_t & nb40,
        constant uint64_t & nb41,
        constant uint64_t & nb42,
        constant uint64_t & nb50,
        constant uint64_t & nb51,
        constant uint64_t & nb52,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // Sequential scan over n_seq_tokens tokens for one (row, sequence) pair,
    // selected by tgpig.x (row) and tgpig.y (sequence).
    //
    // For every token t and every state element i in [0, d_state):
    //   state[i] = prev[i] * exp(softplus(dt) * A[i]) + B[i] * x * softplus(dt)
    //   y       += state[i] * C[i]
    // where prev is the state read from src0 for the first token and the state
    // written by the previous iteration afterwards.
    //
    // dst receives two things:
    //   - y, addressed with the src1 strides (nb10/nb11/nb12)
    //   - the updated state, addressed with the src0 strides (nb01/nb02) at an
    //     additional byte offset of nb13
    // All strides are in bytes. d_inner, n_seqs, nb00, nb30, nb40, nb50, tpitg and
    // ntg are not read here; they are kept so the argument indices match the host
    // encoder.
    const int64_t row = tgpig.x;
    const int64_t seq = tgpig.y;

    const int64_t n_state  = d_state;
    const int64_t n_tokens = n_seq_tokens;

    // pointers that do not depend on the token index
    device const float * state_in  = (device const float *) ((device const char *) src0 + row*nb01 + seq*nb02);
    device const float * A         = (device const float *) ((device const char *) src3 + row*nb31);
    device       float * state_out = (device       float *) ((device       char *) dst  + row*nb01 + seq*nb02 + nb13);

    for (int64_t t = 0; t < n_tokens; ++t) {
        device const float * x  = (device const float *) ((device const char *) src1 + row*nb10 + t*nb11 + seq*nb12);
        device const float * dt = (device const float *) ((device const char *) src2 + row*nb20 + t*nb21 + seq*nb22);
        device const float * B  = (device const float *) ((device const char *) src4 + t*nb41 + seq*nb42);
        device const float * C  = (device const float *) ((device const char *) src5 + t*nb51 + seq*nb52);
        device       float * y  = (device       float *) ((device       char *) dst  + row*nb10 + t*nb11 + seq*nb12); // TODO: do not use src1 strides

        // after the first token, the state produced by the previous iteration is the input
        device const float * prev = t > 0 ? (device const float *) state_out : state_in;

        // softplus(dt); for inputs above 20 the value is passed through unchanged
        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
        const float x_dt         = x[0] * dt_soft_plus;

        float acc = 0.0f;

        for (int64_t i = 0; i < n_state; ++i) {
            const float state = (prev[i] * exp(dt_soft_plus * A[i])) + (B[i] * x_dt);
            acc += state * C[i];
            state_out[i] = state;
        }

        y[0] = acc;
    }
}
|
| 870 |
+
|
| 871 |
kernel void kernel_norm(
|
| 872 |
device const void * src0,
|
| 873 |
device float * dst,
|
ggml/src/ggml.c
CHANGED
|
@@ -16372,8 +16372,8 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|
| 16372 |
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
| 16373 |
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
|
| 16374 |
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
|
| 16375 |
-
|
| 16376 |
-
|
| 16377 |
|
| 16378 |
// use the output as the source for the next token-wise iterations
|
| 16379 |
if (i2 > 0) { s0 = s; }
|
|
|
|
| 16372 |
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
| 16373 |
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
|
| 16374 |
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
|
| 16375 |
+
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
| 16376 |
+
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
|
| 16377 |
|
| 16378 |
// use the output as the source for the next token-wise iterations
|
| 16379 |
if (i2 > 0) { s0 = s; }
|