Chenguang Li committed · Commit 6a9f9dc · 1 Parent(s): bdae2b3
CANN: Support MUL_MAT_ID for q8_0 and q4_0 (llama/13705)
* [CANN] Support MUL_MAT_ID Q8 && Q4
Signed-off-by: noemotiovon <[email protected]>
* codestyle adjustment
Signed-off-by: noemotiovon <[email protected]>
---------
Signed-off-by: noemotiovon <[email protected]>
ggml/src/ggml-cann/aclnn_ops.cpp
CHANGED
@@ -2697,14 +2697,10 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
         }
     }

-    // GroupedMatmulV2 required tensor_list.size < 128
     size_t GROUP_SIZE = 128;
-
-    std::vector<std::vector<aclTensor*>> src1_tensor_vec_vec;
-    std::vector<std::vector<aclTensor*>> dst_tensor_vec_vec;
-
-    // split and call GroupedMatmulV2
+    // GroupedMatmulV2 required tensor_list.size < 128
     for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
+        // split and call GroupedMatmulV2
         size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
         std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
         std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
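The hunk above mostly moves comments and drops two unused vectors; the underlying pattern is that aclnnGroupedMatmulV2 accepts only a bounded tensor list, so the operand vectors are walked in windows of at most GROUP_SIZE entries. A minimal standalone sketch of that chunking loop, with std::vector<int> standing in for the aclTensor* lists (names and sizes here are illustrative, not from the patch):

// Minimal sketch (not from the patch): split a long list into windows of
// at most GROUP_SIZE entries, the same shape as the loop in the hunk above.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> tensors(300);  // pretend these are 300 aclTensor* handles
    const size_t GROUP_SIZE = 128;
    for (size_t i = 0; i < tensors.size(); i += GROUP_SIZE) {
        // half-open window [i, end) handles the final partial group with no special case
        size_t end = std::min(i + GROUP_SIZE, tensors.size());
        std::vector<int> split(tensors.begin() + i, tensors.begin() + end);
        std::printf("group starting at %zu has %zu entries\n", i, split.size());
    }
    return 0;
}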
@@ -2722,6 +2718,133 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
     return;
 }

+/**
+ * @brief Performs expert-specific matrix multiplication (MoE) with
+ * quantized precision using the CANN backend.
+ *
+ * This function executes a matrix multiplication operation tailored for
+ * Mixture of Experts (MoE) models, where the input tensor is multiplied
+ * with expert-specific quantized weight matrices. It leverages the CANN
+ * backend to perform efficient low-precision computations and stores the
+ * quantized result in the destination tensor `dst`.
+ *
+ * Quantization techniques reduce memory footprint and improve performance
+ * by using lower-bit representations (e.g., int8) instead of floating-point.
+ * This function is designed to work with such formats and may incorporate
+ * optimizations like identity-based fast paths or routing masks for sparse
+ * expert selection.
+ *
+ * @param ctx The context for executing CANN backend operations.
+ * @param dst The destination tensor where the quantized MoE multiplication
+ *            result will be stored.
+ *
+ * @note This function assumes quantized data types and is designed for
+ *       MoE architectures with potential sparse expert routing.
+ */
+static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    // TODO: Use aclnnGroupedMatMul
+    // dst   [M, K, N, 1]
+    ggml_tensor * src0 = dst->src[0];  // src0  [D, M, A, 1]
+    ggml_tensor * src1 = dst->src[1];  // src1  [D, B, N, 1], B = K or B = 1
+    ggml_tensor * ids  = dst->src[2];  // ids   [K, N]
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // copy index from npu to cpu
+    int64_t n_as  = ne02;       // A
+    int64_t n_ids = ids->ne[0]; // K
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
+                           ACL_MEMCPY_DEVICE_TO_HOST);
+    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+    char * src0_original = (char *) src0->data;
+    char * src1_original = (char *) src1->data;
+    char * dst_original  = (char *) dst->data;
+
+    ggml_tensor src0_row = *src0;
+    ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row  = *dst;
+
+    const enum ggml_type type = dst->src[0]->type;
+    float weight_elem_size;
+    if (type == GGML_TYPE_Q4_0) {
+        weight_elem_size = float(sizeof(uint8_t)) / 2;
+    } else if (type == GGML_TYPE_Q8_0) {
+        weight_elem_size = float(sizeof(uint8_t));
+    } else {
+        GGML_ABORT("MUL_MAT_ID only supports quantized types Q4_0 and Q8_0");
+    }
+
+    // src0_row [D, M, 1, 1] weight without permute
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[0] = weight_elem_size;
+    src0_row.nb[1] = weight_elem_size * ne00;
+    src0_row.nb[2] = weight_elem_size * ne00;
+    src0_row.nb[3] = weight_elem_size * ne00;
+    size_t weight_stride = ne00 * ne01 * weight_elem_size;
+    size_t weight_size   = weight_stride * ne02 * ne03;
+
+    // scale [D, M, 1, 1] -> scale && permute
+    size_t scale_elem_size = sizeof(uint16_t);
+    size_t scale_stride    = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
+
+    // src1_row [D, 1, 1, 1] -> input
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
+
+    // dst_row [M, 1, 1, 1] -> out
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
+
+    // create weight for one row
+    ggml_cann_pool_alloc weight_allocator(ctx.pool());
+    void* weight_buffer = weight_allocator.alloc(nb02);
+    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+        for (int64_t id = 0; id < n_ids; id++) {
+            // expert index
+            int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+            GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+            // If B = 1 (broadcast), always use 0; otherwise, use id.
+            int64_t i11 = (ne11 == 1 ? 0 : id);
+            int64_t i12 = iid1;
+
+            int64_t i1 = id;
+            int64_t i2 = i12;
+
+            void* src0_tmp_ptr  = src0_original + i02*weight_stride;
+            void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride;
+            void* src1_tmp_ptr  = src1_original + i11*nb11 + i12*nb12;
+            void* dst_tmp_ptr   = dst_original  + i1*nb1   + i2*nb2;
+
+            // mem cpy
+            ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride,
+                                   ACL_MEMCPY_DEVICE_TO_DEVICE);
+            void* scale_buffer = (char*)weight_buffer + weight_stride;
+            ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride,
+                                   ACL_MEMCPY_DEVICE_TO_DEVICE);
+
+            src0_row.data = weight_buffer;
+            src1_row.data = src1_tmp_ptr;
+            dst_row.data  = dst_tmp_ptr;
+            dst_row.src[0] = &src0_row;
+            dst_row.src[1] = &src1_row;
+
+            ggml_cann_mul_mat(ctx, &dst_row);
+        }
+    }
+    return;
+}
+
 void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     const enum ggml_type type = dst->src[0]->type;
     switch (type) {

@@ -2729,6 +2852,10 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         case GGML_TYPE_F16:
             ggml_cann_mul_mat_id_fp(ctx, dst);
             break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+            ggml_cann_mul_mat_id_quant(ctx, dst);
+            break;
         default:
             GGML_ABORT("Unsupported type for mul_mat_id");
             break;
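The per-expert pointer arithmetic in ggml_cann_mul_mat_id_quant relies on the usual ggml block layout for these types: Q8_0 stores one byte per weight, Q4_0 packs two weights per byte, and every block of QK8_0 = 32 weights carries one fp16 (uint16_t) scale, with the scale area placed after the whole weight area. A standalone sketch of the stride computation under assumed shapes (ne00 and ne01 are hypothetical; the real code reads them from src0):

// Standalone sketch (assumed shapes, not from the patch): per-expert weight
// and scale strides as computed in ggml_cann_mul_mat_id_quant.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne00  = 4096;   // hypothetical row width
    const int64_t ne01  = 11008;  // hypothetical row count
    const int64_t QK8_0 = 32;     // weights per quant block

    const float q8_elem = 1.0f;   // Q8_0: sizeof(uint8_t) per weight
    const float q4_elem = 0.5f;   // Q4_0: two weights packed per byte

    // bytes of packed weights per expert
    const size_t q8_weight_stride = (size_t)(ne00 * ne01 * q8_elem);
    const size_t q4_weight_stride = (size_t)(ne00 * ne01 * q4_elem);
    // one uint16_t (fp16) scale per 32-weight block, stored after the weights
    const size_t scale_stride = (size_t)(ne01 * ne00 / QK8_0) * sizeof(uint16_t);

    std::printf("Q8_0 weights per expert: %zu bytes\n", q8_weight_stride);
    std::printf("Q4_0 weights per expert: %zu bytes\n", q4_weight_stride);
    std::printf("scales per expert:       %zu bytes\n", scale_stride);
    return 0;
}

This is why the loop body can locate expert i02 with plain offsets: src0_original + i02*weight_stride for the weights and src0_original + weight_size + i02*scale_stride for the scales.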
ggml/src/ggml-cann/ggml-cann.cpp
CHANGED
@@ -2035,6 +2035,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             case GGML_TYPE_F16:
             case GGML_TYPE_F32:
                 return true;
+            case GGML_TYPE_Q8_0:
+            case GGML_TYPE_Q4_0:
+#ifdef ASCEND_310P
+                // Q4 && Q8 per group is not supported on the 310p device
+                return false;
+#endif
+                // only support contiguous for quantized types.
+                return ggml_is_contiguous(op->src[0]) &&
+                       ggml_is_contiguous(op->src[1]);
             default:
                 return false;
         }
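The new supports_op branch admits quantized MUL_MAT_ID only when both operands are contiguous. A standalone sketch of what such a check means on ggml-style ne/nb metadata; the struct and helper below are illustrative stand-ins, and ggml's real ggml_is_contiguous additionally accounts for quantized block sizes:

// Standalone sketch (own helper, hypothetical values, not the ggml API):
// a tensor is contiguous when nb[0] equals the element size and each higher
// stride equals the previous extent times the previous stride, i.e. rows
// pack back-to-back with no padding or permutation.
#include <cstdint>
#include <cstdio>

struct TensorMeta {
    int64_t ne[4];  // elements per dimension
    size_t  nb[4];  // byte stride per dimension
};

static bool is_contiguous(const TensorMeta& t, size_t elem_size) {
    if (t.nb[0] != elem_size) return false;
    for (int i = 1; i < 4; ++i) {
        if (t.nb[i] != t.nb[i - 1] * (size_t) t.ne[i - 1]) return false;
    }
    return true;
}

int main() {
    TensorMeta a = {{8, 4, 2, 1}, {4, 32, 128, 256}};  // packed fp32 tensor
    TensorMeta b = {{8, 4, 2, 1}, {4, 64, 256, 512}};  // padded rows leave gaps
    std::printf("a contiguous: %d\n", is_contiguous(a, 4));  // prints 1
    std::printf("b contiguous: %d\n", is_contiguous(b, 4));  // prints 0
    return 0;
}

Rejecting non-contiguous operands here keeps the quantized path's flat pointer arithmetic (weight_stride, scale_stride) valid without extra gather or repack steps.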