ggerganov committed
Commit 8bc6274 (parent: 7dd37dc)

metal : optimize ggml_mul_mat_id (faster Mixtral PP) (llama/4725)


* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (llama/4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

* metal : optimizing ggml_mul_mat_id (wip)

* metal : minor fix

* metal : opt mul_mm_id

Files changed (2):
  1. ggml-metal.m +20 -11
  2. ggml-metal.metal +170 -35
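
In broad strokes: instead of dispatching one matrix-vector product per src1 row, the updated kernel_mul_mm_id has each threadgroup first collect the indices of the src1 rows routed to its expert (the new src1ids[512] stack array in the diff below), then run a single tiled matrix-matrix multiplication over all of them. Below is a minimal host-side C++ sketch of that grouping step; the name gather_rows_per_expert and the flat ids array are hypothetical stand-ins for the kernel's routing lookup, not part of the ggml API.

    #include <cstdint>
    #include <vector>

    // Sketch only: group src1 row indices by the expert they are routed to,
    // mirroring the src1ids[] array each threadgroup builds on its stack.
    std::vector<std::vector<int16_t>> gather_rows_per_expert(
            const int32_t * ids,  // hypothetical: ids[i1] = expert selected for src1 row i1
            int64_t         ne1,  // number of src1 rows (batch size, <= 512 in this commit)
            int             n_as) // number of experts (n_as <= 8 in this commit)
    {
        std::vector<std::vector<int16_t>> rows(n_as);
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            rows[ids[i1]].push_back((int16_t) i1);
        }
        // one tiled mat-mat multiplication per expert now covers all of its
        // rows at once, instead of one mat-vec dispatch per row
        return rows;
    }

The real kernel reads the routing ids from a 2D matrix via the idx op-param; the sketch flattens that to one id per row for clarity.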
ggml-metal.m CHANGED
@@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute(
                    }
                };
 
+               if (ggml_is_quantized(src0t)) {
+                   GGML_ASSERT(ne00 >= nth0*nth1);
+               }
+
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute(
            // TODO: make this more general
            GGML_ASSERT(n_as <= 8);
 
+           // max size of the src1ids array in the kernel stack
+           GGML_ASSERT(ne11 <= 512);
+
            struct ggml_tensor * src2 = gf->nodes[i]->src[2];
 
            const int64_t ne20 = src2 ? src2->ne[0] : 0;
@@ -1732,9 +1739,6 @@ void ggml_metal_graph_compute(
            GGML_ASSERT(!ggml_is_transposed(src2));
            GGML_ASSERT(!ggml_is_transposed(src1));
 
-           GGML_ASSERT(ne20 % 32 == 0);
-           // !!!!!!!!! TODO: this assert is probably required but not sure!
-           //GGML_ASSERT(ne20 >= 64);
            GGML_ASSERT(src1t == GGML_TYPE_F32);
 
            const uint r2 = ne12/ne22;
@@ -1742,22 +1746,22 @@ void ggml_metal_graph_compute(
 
            // find the break-even point where the matrix-matrix kernel becomes more efficient compared
            // to the matrix-vector kernel
-           int ne11_mm_min = 1;
+           int ne11_mm_min = n_as;
 
            const int idx = ((int32_t *) dst->op_params)[0];
 
            // batch size
            GGML_ASSERT(ne01 == ne11);
 
-           const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
-
            // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
            // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
            // !!!
            // TODO: for now, always use mat-vec kernels until we figure out how to improve the
            //       indirect matrix multiplication
            // !!!
-           if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
+           if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+               ne20 % 32 == 0 && ne20 >= 64 &&
+               ne11 > ne11_mm_min) {
                switch (src2->type) {
                    case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
                    case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute(
                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
-               [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+               [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
                [encoder setBytes:&r2   length:sizeof(r2)   atIndex:16];
                [encoder setBytes:&r3   length:sizeof(r3)   atIndex:17];
@@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute(
 
                [encoder setThreadgroupMemoryLength:8192 atIndex:0];
 
-               // TODO: processing one row at a time (ne11 -> 1) is not efficient
-               [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+               [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
            } else {
                int nth0 = 32;
                int nth1 = 1;
@@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute(
                        } break;
                    default:
                        {
-                           GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+                           GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
                            GGML_ASSERT(false && "not implemented");
                        }
                };
 
+               if (ggml_is_quantized(src2t)) {
+                   GGML_ASSERT(ne20 >= nth0*nth1);
+               }
+
+               const int64_t _ne1 = 1; // kernels needs a reference in constant memory
+
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
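
The host-side change above amounts to a new dispatch predicate: the indirect mat-mat path is taken only on Apple7+ GPUs and only when the shapes fit the kernel's tiles. A hedged C++ restatement of that condition follows; supports_apple7 is an illustrative stand-in for the [ctx->device supportsFamily:MTLGPUFamilyApple7] check, and the function itself is not part of ggml.

    #include <cstdint>

    // Illustrative restatement of the new host-side condition, not the actual API.
    static bool use_mul_mm_id(bool supports_apple7, int64_t ne20, int64_t ne11, int n_as) {
        const int ne11_mm_min = n_as;           // break-even point vs. the mat-vec kernel
        return supports_apple7 &&
               ne20 % 32 == 0 && ne20 >= 64 &&  // the shapes the removed asserts used to demand
               ne11 > ne11_mm_min;              // enough src1 rows to amortize the mat-mat setup
    }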
ggml-metal.metal CHANGED
@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
-//      giard against the number of rows not being divisible by
+//      guard against the number of rows not being divisible by
 //      N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32_impl(
@@ -3973,6 +3973,131 @@ void kernel_mul_mm_impl(device const uchar * src0,
     }
 }
 
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+        device const  uchar * src0,
+        device const  uchar * src1,
+        thread        short * src1ids,
+        device        float * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne02,
+        constant   uint64_t & nb01,
+        constant   uint64_t & nb02,
+        constant    int64_t & ne12,
+        constant   uint64_t & nb10,
+        constant   uint64_t & nb11,
+        constant   uint64_t & nb12,
+        constant    int64_t & ne0,
+                    int64_t   ne1,
+        constant       uint & r2,
+        constant       uint & r3,
+        threadgroup   uchar * shared_memory,
+        uint3                 tgpig[[threadgroup_position_in_grid]],
+        uint                  tiitg[[thread_index_in_threadgroup]],
+        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
+    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+    const uint r0 = tgpig.y;
+    const uint r1 = tgpig.x;
+    const uint im = tgpig.z;
+
+    if (r1 * BLOCK_SIZE_N >= ne1) return;
+
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8  ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 c_res[8];
+    for (int i = 0; i < 8; i++){
+        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    short il = (tiitg % THREAD_PER_ROW);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * im
+        + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            }
+        }
+    }
+
+    {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+                                     + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
+        if (sgitg == 0) {
+            for (int i = 0; i < n_rows; i++) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                    *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+                }
+            }
+        }
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm(device const uchar * src0,
                           device const uchar * src1,
@@ -4019,7 +4144,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
 kernel void kernel_mul_mm_id(
         device const   uchar * ids,
         device const   uchar * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
@@ -4048,18 +4173,28 @@ kernel void kernel_mul_mm_id(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+    device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
-    const int64_t bid = tgpig.z/(ne12*ne13);
+    // expert id
+    const int32_t id = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    // row indices of src1 for expert id
+    int64_t _ne1 = 0;
+    short src1ids[512];
 
-    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-        src0[id],
-        src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+    for (int64_t i1 = 0; i1 < ne1; i1++) {
+        if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
+            src1ids[_ne1++] = i1;
+        }
+    }
+
+    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+        src0s[id],
+        src1,
+        src1ids,
+        dst,
         ne00,
         ne02,
         nb01,
@@ -4069,7 +4204,7 @@ kernel void kernel_mul_mm_id(
         nb11,
         nb12,
         ne0,
-        ne1,
+        _ne1,
         r2,
         r3,
         shared_memory,
@@ -4158,7 +4293,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 typedef void (mat_mm_id_t)(
         device const   uchar * ids,
         device const   uchar * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
@@ -4207,7 +4342,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
 kernel void kernel_mul_mv_id_f32_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4251,7 +4386,7 @@ kernel void kernel_mul_mv_id_f32_f32(
     kernel_mul_mv_f32_f32_impl(
         src0[id],
         src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4276,7 +4411,7 @@ kernel void kernel_mul_mv_id_f32_f32(
 kernel void kernel_mul_mv_id_f16_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4320,7 +4455,7 @@ kernel void kernel_mul_mv_id_f16_f32(
     kernel_mul_mv_f16_f32_impl(
         src0[id],
         src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4345,7 +4480,7 @@ kernel void kernel_mul_mv_id_f16_f32(
 kernel void kernel_mul_mv_id_q8_0_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4389,7 +4524,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4408,7 +4543,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
 kernel void kernel_mul_mv_id_q4_0_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4452,7 +4587,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
     mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4471,7 +4606,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
 kernel void kernel_mul_mv_id_q4_1_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4515,7 +4650,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4534,7 +4669,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
 kernel void kernel_mul_mv_id_q5_0_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4578,7 +4713,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
     mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4597,7 +4732,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
 kernel void kernel_mul_mv_id_q5_1_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4641,7 +4776,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
     mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4660,7 +4795,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
 kernel void kernel_mul_mv_id_q2_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4704,7 +4839,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
     kernel_mul_mv_q2_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4723,7 +4858,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
 kernel void kernel_mul_mv_id_q3_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4767,7 +4902,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
     kernel_mul_mv_q3_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4786,7 +4921,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
 kernel void kernel_mul_mv_id_q4_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4830,7 +4965,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
     kernel_mul_mv_q4_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4849,7 +4984,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
 kernel void kernel_mul_mv_id_q5_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4893,7 +5028,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
     kernel_mul_mv_q5_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4912,7 +5047,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
 kernel void kernel_mul_mv_id_q6_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4956,7 +5091,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
     kernel_mul_mv_q6_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
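
As a worked example of the new dispatch grid (the shapes here are hypothetical, Mixtral-like values, not taken from the commit): with ne11 = 512 src1 rows, ne21 = 4096 expert output rows, n_as = 8 experts, and ne12 = ne13 = 1, MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) yields a 16 x 64 x 8 grid, i.e. 8192 threadgroups of 128 threads. Threadgroups whose x tile lies beyond the rows actually routed to their expert exit immediately through the "if (r1 * BLOCK_SIZE_N >= ne1) return;" guard at the top of kernel_mul_mm_id_impl, so only the tiles covering real work proceed to the tiled multiplication.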