Chenguang Li committed
Commit 6a9f9dc · 1 Parent(s): bdae2b3

CANN: Support MUL_MAT_ID for q8_0 and q4_0 (llama/13705)


* [CANN] Support MUL_MAT_ID for Q8_0 and Q4_0

Signed-off-by: noemotiovon <[email protected]>

* code style adjustment

Signed-off-by: noemotiovon <[email protected]>

---------

Signed-off-by: noemotiovon <[email protected]>

ggml/src/ggml-cann/aclnn_ops.cpp CHANGED
@@ -2697,14 +2697,10 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
        }
    }

-   // GroupedMatmulV2 required tensor_list.size < 128
    size_t GROUP_SIZE = 128;
+   // GroupedMatmulV2 requires tensor_list.size() < 128
-   std::vector<std::vector<aclTensor*>> src0_tensor_vec_vec;
-   std::vector<std::vector<aclTensor*>> src1_tensor_vec_vec;
-   std::vector<std::vector<aclTensor*>> dst_tensor_vec_vec;
-
-   // split and call GroupedMatmulV2
    for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
+       // split and call GroupedMatmulV2
        size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
        std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
        std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
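
GroupedMatmulV2 caps a single call at fewer than 128 tensors, so the hunk above drops the pre-built vector-of-vectors buckets and instead slices the tensor lists into groups of at most GROUP_SIZE directly inside the loop. A minimal standalone sketch of that chunking pattern, with a hypothetical process_group callback standing in for the GroupedMatmulV2 invocation:

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch: dispatch work over chunks of at most group_size items,
// mirroring the split-and-call pattern above. process_group is a
// hypothetical stand-in for the GroupedMatmulV2 call.
template <typename T, typename Fn>
void for_each_group(const std::vector<T>& items, size_t group_size, Fn process_group) {
    for (size_t i = 0; i < items.size(); i += group_size) {
        size_t end = std::min(i + group_size, items.size());
        // each split covers [i, end) of the original list
        std::vector<T> split(items.begin() + i, items.begin() + end);
        process_group(split);
    }
}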
@@ -2722,6 +2718,133 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
    return;
}

+/**
+ * @brief Performs expert-specific matrix multiplication (MoE) with
+ * quantized precision using the CANN backend.
+ *
+ * This function executes a matrix multiplication operation tailored for
+ * Mixture of Experts (MoE) models, where the input tensor is multiplied
+ * with expert-specific quantized weight matrices. It leverages the CANN
+ * backend to perform efficient low-precision computations and stores the
+ * quantized result in the destination tensor `dst`.
+ *
+ * Quantization techniques reduce memory footprint and improve performance
+ * by using lower-bit representations (e.g., int8) instead of floating-point.
+ * This function is designed to work with such formats and may incorporate
+ * optimizations like identity-based fast paths or routing masks for sparse
+ * expert selection.
+ *
+ * @param ctx The context for executing CANN backend operations.
+ * @param dst The destination tensor where the quantized MoE multiplication
+ * result will be stored.
+ *
+ * @note This function assumes quantized data types and is designed for
+ * MoE architectures with potential sparse expert routing.
+ */
+static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    // TODO: Use aclnnGroupedMatMul
+    // dst  [M, K, N, 1]
+    ggml_tensor * src0 = dst->src[0];  // src0 [D, M, A, 1]
+    ggml_tensor * src1 = dst->src[1];  // src1 [D, B, N, 1], B = K or B = 1
+    ggml_tensor * ids  = dst->src[2];  // ids  [K, N]
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // copy the expert indices from npu to cpu
+    int64_t n_as  = ne02;        // A
+    int64_t n_ids = ids->ne[0];  // K
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
+                           ACL_MEMCPY_DEVICE_TO_HOST);
+    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+    char * src0_original = (char *) src0->data;
+    char * src1_original = (char *) src1->data;
+    char * dst_original  = (char *) dst->data;
+
+    ggml_tensor src0_row = *src0;
+    ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row  = *dst;
+
+    const enum ggml_type type = dst->src[0]->type;
+    float weight_elem_size;
+    if (type == GGML_TYPE_Q4_0) {
+        weight_elem_size = float(sizeof(uint8_t)) / 2;
+    } else if (type == GGML_TYPE_Q8_0) {
+        weight_elem_size = float(sizeof(uint8_t));
+    } else {
+        GGML_ABORT("MUL_MAT_ID only supports quant types Q4_0 and Q8_0");
+    }
+
+    // src0_row [D, M, 1, 1] weight without permute
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[0] = weight_elem_size;
+    src0_row.nb[1] = weight_elem_size * ne00;
+    src0_row.nb[2] = weight_elem_size * ne00;
+    src0_row.nb[3] = weight_elem_size * ne00;
+    size_t weight_stride = ne00 * ne01 * weight_elem_size;
+    size_t weight_size   = weight_stride * ne02 * ne03;
+
+    // scale [D, M, 1, 1] -> scale && permute
+    size_t scale_elem_size = sizeof(uint16_t);
+    size_t scale_stride    = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
+
+    // src1_row [D, 1, 1, 1] -> input
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
+
+    // dst_row [M, 1, 1, 1] -> out
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
+
+    // create weight for one row
+    ggml_cann_pool_alloc weight_allocator(ctx.pool());
+    void* weight_buffer = weight_allocator.alloc(nb02);
+    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+        for (int64_t id = 0; id < n_ids; id++) {
+            // expert index
+            int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+            GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+            // If B = 1 (broadcast), always use 0; otherwise, use id.
+            int64_t i11 = (ne11 == 1 ? 0 : id);
+            int64_t i12 = iid1;
+
+            int64_t i1 = id;
+            int64_t i2 = i12;
+
+            void* src0_tmp_ptr  = src0_original + i02*weight_stride;
+            void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride;
+            void* src1_tmp_ptr  = src1_original + i11*nb11 + i12*nb12;
+            void* dst_tmp_ptr   = dst_original  + i1*nb1   + i2*nb2;
+
+            // copy one expert's weights and scales into the staging buffer
+            ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride,
+                                   ACL_MEMCPY_DEVICE_TO_DEVICE);
+            void* scale_buffer = (char*)weight_buffer + weight_stride;
+            ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride,
+                                   ACL_MEMCPY_DEVICE_TO_DEVICE);
+
+            src0_row.data  = weight_buffer;
+            src1_row.data  = src1_tmp_ptr;
+            dst_row.data   = dst_tmp_ptr;
+            dst_row.src[0] = &src0_row;
+            dst_row.src[1] = &src1_row;
+
+            ggml_cann_mul_mat(ctx, &dst_row);
+        }
+    }
+    return;
+}
+
@@ -2729,6 +2852,10 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
        case GGML_TYPE_F16:
            ggml_cann_mul_mat_id_fp(ctx, dst);
            break;
+       case GGML_TYPE_Q4_0:
+       case GGML_TYPE_Q8_0:
+           ggml_cann_mul_mat_id_quant(ctx, dst);
+           break;
        default:
            GGML_ABORT("Unsupported type for mul_mat_id");
            break;
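
Two details of ggml_cann_mul_mat_id_quant above are worth spelling out. First, the expert indices in ids live on the NPU, so they are copied device-to-host and the stream is synchronized before the CPU loop reads them. Second, the quantized src0 buffer packs all expert weight matrices back to back and appends the per-block fp16 scales after the last expert's weights, which is why expert i02's scales start at weight_size + i02 * scale_stride. A worked sketch of that offset arithmetic, assuming ne03 == 1, QK8_0 == 32, and purely illustrative dimensions:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Sketch of the expert weight/scale offsets used above. The shapes are
// illustrative assumptions, not values from a real model.
int main() {
    const int64_t ne00 = 4096;   // D: elements per row
    const int64_t ne01 = 1024;   // M: rows per expert matrix
    const int64_t n_as = 8;      // A: number of experts (ne02, with ne03 == 1)
    const int64_t qk   = 32;     // QK8_0: elements per quantization block

    const float elem_q8 = 1.0f;  // Q8_0: one byte per element
    const float elem_q4 = 0.5f;  // Q4_0: two 4-bit values packed per byte

    const size_t weight_stride = (size_t)(ne00 * ne01 * elem_q8);
    const size_t weight_size   = weight_stride * n_as;  // all experts, packed
    // one uint16_t (fp16) scale per qk-element block
    const size_t scale_stride  = (size_t)(ne00 * ne01 / qk) * sizeof(uint16_t);

    const int64_t i02 = 3;       // expert index read from ids
    // weights for expert i02, then its scales after the whole weight region
    printf("weights at +%zu bytes\n", (size_t)(i02 * weight_stride));
    printf("scales  at +%zu bytes\n", weight_size + (size_t)(i02 * scale_stride));
    printf("(a Q4_0 weight stride would be %zu bytes)\n",
           (size_t)(ne00 * ne01 * elem_q4));
    return 0;
}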
ggml/src/ggml-cann/ggml-cann.cpp CHANGED
@@ -2035,6 +2035,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
        case GGML_TYPE_F16:
        case GGML_TYPE_F32:
            return true;
+       case GGML_TYPE_Q8_0:
+       case GGML_TYPE_Q4_0:
+#ifdef ASCEND_310P
+           // Q4 && Q8 per-group quantization is not supported on 310P devices
+           return false;
+#endif
+           // only support contiguous tensors for quantized types
+           return ggml_is_contiguous(op->src[0]) &&
+                  ggml_is_contiguous(op->src[1]);
        default:
            return false;
        }
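
The supports_op change gates the new path twice: quantized MUL_MAT_ID is refused outright on 310P devices, and elsewhere only when both operands are contiguous. The same guard, restated as a standalone sketch (is_310p is a hypothetical stand-in for the compile-time ASCEND_310P check; ggml_is_contiguous is the real ggml helper):

#include "ggml.h"

// Sketch of the guard above, hoisted out of the switch for readability.
static bool cann_supports_quant_mul_mat_id(const ggml_tensor * op, bool is_310p) {
    if (is_310p) {
        // per-group Q4/Q8 kernels are not available on 310P devices
        return false;
    }
    // the quantized path assumes densely packed rows on both operands
    return ggml_is_contiguous(op->src[0]) &&
           ggml_is_contiguous(op->src[1]);
}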