ggerganov commited on
Commit
a13f78c
·
1 Parent(s): 9b4460a

ggml : fix MUL_MAT_ID repack with Q8_K (llama/12544)

Browse files

* ggml : fix MUL_MAT_ID repack with Q8_K

ggml-ci

* ggml : improve repack templates

ggml-ci

ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp CHANGED
@@ -250,7 +250,7 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
250
 
251
  static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
252
 
253
- static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
254
  assert(QK8_0 == 32);
255
  assert(k % QK8_0 == 0);
256
  const int nb = k / QK8_0;
@@ -344,7 +344,7 @@ static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRIC
344
  #endif
345
  }
346
 
347
- static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
348
  assert(QK8_0 == 32);
349
  assert(k % QK8_0 == 0);
350
  const int nb = k / QK8_0;
@@ -559,7 +559,7 @@ static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
559
  #endif
560
  }
561
 
562
- static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
563
  assert(QK_K == 256);
564
  assert(k % QK_K == 0);
565
  const int nb = k / QK_K;
@@ -811,7 +811,7 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
811
  // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
812
  for (int j = 0; j < QK_K * 4; j++) {
813
  int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
814
- int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
815
  src_offset += (j % blck_size_interleave);
816
  int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
817
 
@@ -823,26 +823,25 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
823
  #endif
824
  }
825
 
826
- static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
 
 
 
827
  assert(nrow == 4);
828
  UNUSED(nrow);
829
- if (blck_size_interleave == 4) {
830
- quantize_q8_0_4x4(x, vy, n_per_row);
831
- } else if (blck_size_interleave == 8) {
832
- quantize_q8_0_4x8(x, vy, n_per_row);
833
- } else {
834
- assert(false);
835
- }
836
  }
837
 
838
- static void quantize_mat_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
839
  assert(nrow == 4);
840
  UNUSED(nrow);
841
- if (blck_size_interleave == 8) {
842
- quantize_q8_K_4x8(x, vy, n_per_row);
843
- } else {
844
- assert(false);
845
- }
 
 
846
  }
847
 
848
  static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -5276,52 +5275,50 @@ template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void *
5276
  //}
5277
 
5278
  // gemv
5279
- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
5280
  void gemv(int, float *, size_t, const void *, const void *, int, int);
5281
 
5282
- template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5283
  ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
5284
  }
5285
 
5286
- template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5287
  ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
5288
  }
5289
 
5290
- template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5291
  ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
5292
  }
5293
 
5294
- template <> void gemv<block_q4_K, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5295
  ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
5296
  }
5297
 
5298
- template <>
5299
- void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5300
  ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
5301
  }
5302
 
5303
  // gemm
5304
- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
5305
  void gemm(int, float *, size_t, const void *, const void *, int, int);
5306
 
5307
- template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5308
  ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
5309
  }
5310
 
5311
- template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5312
  ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
5313
  }
5314
 
5315
- template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5316
  ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
5317
  }
5318
 
5319
- template <> void gemm<block_q4_K, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5320
  ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
5321
  }
5322
 
5323
- template <>
5324
- void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5325
  ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
5326
  }
5327
 
@@ -5335,32 +5332,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
5335
  bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
5336
  // not realy a GGML_TYPE_Q8_0 but same size.
5337
  switch (op->op) {
5338
- case GGML_OP_MUL_MAT:
5339
- size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
5340
- return true;
5341
- case GGML_OP_MUL_MAT_ID:
5342
- size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
5343
- size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
5344
- size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
5345
- return true;
5346
- default:
5347
- // GGML_ABORT("fatal error");
5348
- break;
5349
  }
5350
  return false;
5351
  }
5352
 
5353
  bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
5354
  switch (op->op) {
5355
- case GGML_OP_MUL_MAT:
5356
- forward_mul_mat(params, op);
5357
- return true;
5358
- case GGML_OP_MUL_MAT_ID:
5359
- forward_mul_mat_id(params, op);
5360
- return true;
5361
- default:
5362
- // GGML_ABORT("fatal error");
5363
- break;
5364
  }
5365
  return false;
5366
  }
@@ -5399,17 +5396,10 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
5399
  const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
5400
 
5401
  int64_t i11_processed = 0;
5402
- if(PARAM_TYPE == GGML_TYPE_Q8_K) {
5403
- for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
5404
- quantize_mat_q8_K((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
5405
- INTER_SIZE);
5406
- }
5407
- } else {
5408
- for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
5409
- quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
5410
- INTER_SIZE);
5411
- }
5412
  }
 
5413
  i11_processed = ne11 - ne11 % 4;
5414
  for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
5415
  from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
@@ -5422,22 +5412,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
5422
  int64_t src0_start = (ith * ne01) / nth;
5423
  int64_t src0_end = ((ith + 1) * ne01) / nth;
5424
  src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
5425
- src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
5426
  if (src0_start >= src0_end) {
5427
  return;
5428
  }
5429
 
5430
  // If there are more than three rows in src1, use gemm; otherwise, use gemv.
5431
  if (ne11 > 3) {
5432
- gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
5433
- (const char *) src0->data + src0_start * nb01,
5434
- (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
 
5435
  }
5436
  for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
5437
- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
5438
- (const char *) src0->data + src0_start * nb01,
5439
- (const char *) src1_wdata + (src1_col_stride * iter), 1,
5440
- src0_end - src0_start);
 
5441
  }
5442
  }
5443
 
@@ -5452,7 +5444,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
5452
  const int ith = params->ith;
5453
  const int nth = params->nth;
5454
 
5455
- const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
5456
 
5457
  // we don't support permuted src0 or src1
5458
  GGML_ASSERT(nb00 == ggml_type_size(src0->type));
@@ -5474,7 +5466,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
5474
  const int n_ids = ids->ne[0]; // n_expert_used
5475
  const int n_as = ne02; // n_expert
5476
 
5477
- const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
5478
  const size_t nbw2 = nbw1*ne11;
5479
  const size_t nbw3 = nbw2*ne12;
5480
 
@@ -5486,12 +5478,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
5486
  GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
5487
  n_as * ne12 * sizeof(mmid_row_mapping)));
5488
 
5489
- auto wdata = (char *) params->wdata;
5490
- auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
5491
- int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
 
5492
  struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
5493
 
5494
- // src1: float32 => block_q8_0
5495
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
5496
  for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
5497
  from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
@@ -5530,34 +5523,37 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
5530
  continue;
5531
  }
5532
 
5533
- auto src0_cur = (const char *) src0->data + cur_a*nb02;
5534
 
5535
  //const int64_t nr0 = ne01; // src0 rows
5536
  const int64_t nr1 = cne1; // src1 rows
5537
 
5538
  int64_t src0_cur_start = (ith * ne01) / nth;
5539
  int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
5540
- src0_cur_start =
5541
- (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
5542
- src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
5543
 
5544
- if (src0_cur_start >= src0_cur_end) return;
 
 
 
 
 
5545
 
5546
  for (int ir1 = 0; ir1 < nr1; ir1++) {
5547
  struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
5548
- const int id = row_mapping.i1; // selected expert index
5549
 
5550
- const int64_t i11 = id % ne11;
5551
- const int64_t i12 = row_mapping.i2; // row index in src1
 
 
5552
 
5553
- const int64_t i1 = id; // selected expert index
5554
- const int64_t i2 = i12; // row
5555
 
5556
- auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
5557
 
5558
- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
5559
- ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
5560
- ne01, src0_cur + src0_cur_start * nb01,
5561
  src1_col, 1, src0_cur_end - src0_cur_start);
5562
  }
5563
  }
@@ -5578,7 +5574,7 @@ static const tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
5578
  static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
5579
 
5580
  // instance for IQ4
5581
- static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_IQ4_NL> iq4_nl_4x4_q8_0;
5582
 
5583
  } // namespace ggml::cpu::aarch64
5584
 
 
250
 
251
  static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
252
 
253
+ static void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
254
  assert(QK8_0 == 32);
255
  assert(k % QK8_0 == 0);
256
  const int nb = k / QK8_0;
 
344
  #endif
345
  }
346
 
347
+ static void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
348
  assert(QK8_0 == 32);
349
  assert(k % QK8_0 == 0);
350
  const int nb = k / QK8_0;
 
559
  #endif
560
  }
561
 
562
+ static void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
563
  assert(QK_K == 256);
564
  assert(k % QK_K == 0);
565
  const int nb = k / QK_K;
 
811
  // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
812
  for (int j = 0; j < QK_K * 4; j++) {
813
  int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
814
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
815
  src_offset += (j % blck_size_interleave);
816
  int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
817
 
 
823
  #endif
824
  }
825
 
826
+ template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
827
+ void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
828
+
829
+ template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
830
  assert(nrow == 4);
831
  UNUSED(nrow);
832
+ ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
 
 
 
 
 
 
833
  }
834
 
835
+ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
836
  assert(nrow == 4);
837
  UNUSED(nrow);
838
+ ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
839
+ }
840
+
841
+ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
842
+ assert(nrow == 4);
843
+ UNUSED(nrow);
844
+ ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
845
  }
846
 
847
  static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
 
5275
  //}
5276
 
5277
  // gemv
5278
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
5279
  void gemv(int, float *, size_t, const void *, const void *, int, int);
5280
 
5281
+ template <> void gemv<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5282
  ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
5283
  }
5284
 
5285
+ template <> void gemv<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5286
  ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
5287
  }
5288
 
5289
+ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5290
  ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
5291
  }
5292
 
5293
+ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5294
  ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
5295
  }
5296
 
5297
+ template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
 
5298
  ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
5299
  }
5300
 
5301
  // gemm
5302
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
5303
  void gemm(int, float *, size_t, const void *, const void *, int, int);
5304
 
5305
+ template <> void gemm<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5306
  ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
5307
  }
5308
 
5309
+ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5310
  ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
5311
  }
5312
 
5313
+ template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5314
  ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
5315
  }
5316
 
5317
+ template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5318
  ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
5319
  }
5320
 
5321
+ template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
 
5322
  ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
5323
  }
5324
 
 
5332
  bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
5333
  // not realy a GGML_TYPE_Q8_0 but same size.
5334
  switch (op->op) {
5335
+ case GGML_OP_MUL_MAT:
5336
+ size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
5337
+ return true;
5338
+ case GGML_OP_MUL_MAT_ID:
5339
+ size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
5340
+ size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
5341
+ size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
5342
+ return true;
5343
+ default:
5344
+ // GGML_ABORT("fatal error");
5345
+ break;
5346
  }
5347
  return false;
5348
  }
5349
 
5350
  bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
5351
  switch (op->op) {
5352
+ case GGML_OP_MUL_MAT:
5353
+ forward_mul_mat(params, op);
5354
+ return true;
5355
+ case GGML_OP_MUL_MAT_ID:
5356
+ forward_mul_mat_id(params, op);
5357
+ return true;
5358
+ default:
5359
+ // GGML_ABORT("fatal error");
5360
+ break;
5361
  }
5362
  return false;
5363
  }
 
5396
  const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
5397
 
5398
  int64_t i11_processed = 0;
5399
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
5400
+ ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
 
 
 
 
 
 
 
 
5401
  }
5402
+
5403
  i11_processed = ne11 - ne11 % 4;
5404
  for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
5405
  from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
 
5412
  int64_t src0_start = (ith * ne01) / nth;
5413
  int64_t src0_end = ((ith + 1) * ne01) / nth;
5414
  src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
5415
+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
5416
  if (src0_start >= src0_end) {
5417
  return;
5418
  }
5419
 
5420
  // If there are more than three rows in src1, use gemm; otherwise, use gemv.
5421
  if (ne11 > 3) {
5422
+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5423
+ (float *) ((char *) dst->data) + src0_start, ne01,
5424
+ (const char *) src0->data + src0_start * nb01,
5425
+ (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
5426
  }
5427
  for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
5428
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5429
+ (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
5430
+ (const char *) src0->data + src0_start * nb01,
5431
+ (const char *) src1_wdata + (src1_col_stride * iter), 1,
5432
+ src0_end - src0_start);
5433
  }
5434
  }
5435
 
 
5444
  const int ith = params->ith;
5445
  const int nth = params->nth;
5446
 
5447
+ const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
5448
 
5449
  // we don't support permuted src0 or src1
5450
  GGML_ASSERT(nb00 == ggml_type_size(src0->type));
 
5466
  const int n_ids = ids->ne[0]; // n_expert_used
5467
  const int n_as = ne02; // n_expert
5468
 
5469
+ const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
5470
  const size_t nbw2 = nbw1*ne11;
5471
  const size_t nbw3 = nbw2*ne12;
5472
 
 
5478
  GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
5479
  n_as * ne12 * sizeof(mmid_row_mapping)));
5480
 
5481
+ auto * wdata = (char *) params->wdata;
5482
+ auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
5483
+ auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
5484
+
5485
  struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
5486
 
5487
+ // src1: float32 => param type
5488
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
5489
  for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
5490
  from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
 
5523
  continue;
5524
  }
5525
 
5526
+ const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
5527
 
5528
  //const int64_t nr0 = ne01; // src0 rows
5529
  const int64_t nr1 = cne1; // src1 rows
5530
 
5531
  int64_t src0_cur_start = (ith * ne01) / nth;
5532
  int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
 
 
 
5533
 
5534
+ src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
5535
+ src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
5536
+
5537
+ if (src0_cur_start >= src0_cur_end) {
5538
+ return;
5539
+ }
5540
 
5541
  for (int ir1 = 0; ir1 < nr1; ir1++) {
5542
  struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
 
5543
 
5544
+ const int id = row_mapping.i1; // selected expert index
5545
+
5546
+ const int64_t i11 = id % ne11;
5547
+ const int64_t i12 = row_mapping.i2; // row index in src1
5548
 
5549
+ const int64_t i1 = id; // selected expert index
5550
+ const int64_t i2 = i12; // row
5551
 
5552
+ const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
5553
 
5554
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5555
+ (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
5556
+ src0_cur + src0_cur_start * nb01,
5557
  src1_col, 1, src0_cur_end - src0_cur_start);
5558
  }
5559
  }
 
5574
  static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
5575
 
5576
  // instance for IQ4
5577
+ static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
5578
 
5579
  } // namespace ggml::cpu::aarch64
5580