Spaces:
Sleeping
Sleeping
Dan Johansson
committed on
Commit
·
9b4460a
1
Parent(s):
33f8316
ggml-cpu : update KleidiAI to v1.5.0 (llama/12568)
Browse files
ggml-cpu : bug fix related to KleidiAI LHS packing
Signed-off-by: Dan Johansson <[email protected]>
ggml/src/ggml-cpu/CMakeLists.txt
CHANGED
|
@@ -357,9 +357,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|
| 357 |
|
| 358 |
# Fetch KleidiAI sources:
|
| 359 |
include(FetchContent)
|
| 360 |
-
set(KLEIDIAI_COMMIT_TAG "v1.
|
| 361 |
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
| 362 |
-
set(KLEIDIAI_ARCHIVE_MD5 "
|
| 363 |
|
| 364 |
if (POLICY CMP0135)
|
| 365 |
cmake_policy(SET CMP0135 NEW)
|
|
|
|
| 357 |
|
| 358 |
# Fetch KleidiAI sources:
|
| 359 |
include(FetchContent)
|
| 360 |
+
set(KLEIDIAI_COMMIT_TAG "v1.5.0")
|
| 361 |
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
| 362 |
+
set(KLEIDIAI_ARCHIVE_MD5 "ea22e1aefb800e9bc8c74d91633cc58e")
|
| 363 |
|
| 364 |
if (POLICY CMP0135)
|
| 365 |
cmake_policy(SET CMP0135 NEW)
|
ggml/src/ggml-cpu/kleidiai/kernels.cpp
CHANGED
|
@@ -51,11 +51,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
| 51 |
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
| 52 |
},
|
| 53 |
/* .lhs_info = */ {
|
| 54 |
-
/* .get_offset = */
|
| 55 |
-
/* .get_packed_offset = */
|
| 56 |
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
|
| 57 |
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
| 58 |
-
/* .require_aligned_m_idx = */ true,
|
| 59 |
},
|
| 60 |
/* .rhs_info = */ {
|
| 61 |
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
|
@@ -100,7 +99,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
| 100 |
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
| 101 |
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
| 102 |
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
| 103 |
-
/* .require_aligned_m_idx = */ false,
|
| 104 |
},
|
| 105 |
/* .rhs_info = */ {
|
| 106 |
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
@@ -144,7 +142,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
| 144 |
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
| 145 |
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
| 146 |
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
| 147 |
-
/* .require_aligned_m_idx = */ false,
|
| 148 |
},
|
| 149 |
/* .rhs_info = */ {
|
| 150 |
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
@@ -189,7 +186,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
| 189 |
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
| 190 |
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
| 191 |
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
| 192 |
-
/* .require_aligned_m_idx = */ false,
|
| 193 |
},
|
| 194 |
/* .rhs_info = */ {
|
| 195 |
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
@@ -233,7 +229,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
| 233 |
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
| 234 |
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
| 235 |
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
| 236 |
-
/* .require_aligned_m_idx = */ false,
|
| 237 |
},
|
| 238 |
/* .rhs_info = */ {
|
| 239 |
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
|
|
| 51 |
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
| 52 |
},
|
| 53 |
/* .lhs_info = */ {
|
| 54 |
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
| 55 |
+
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
| 56 |
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
|
| 57 |
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
|
|
|
| 58 |
},
|
| 59 |
/* .rhs_info = */ {
|
| 60 |
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
|
|
|
| 99 |
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
| 100 |
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
| 101 |
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
|
|
|
| 102 |
},
|
| 103 |
/* .rhs_info = */ {
|
| 104 |
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
|
|
| 142 |
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
| 143 |
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
| 144 |
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
|
|
|
| 145 |
},
|
| 146 |
/* .rhs_info = */ {
|
| 147 |
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
|
|
| 186 |
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
| 187 |
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
| 188 |
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
|
|
|
| 189 |
},
|
| 190 |
/* .rhs_info = */ {
|
| 191 |
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
|
|
| 229 |
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
| 230 |
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
| 231 |
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
|
|
|
| 232 |
},
|
| 233 |
/* .rhs_info = */ {
|
| 234 |
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
ggml/src/ggml-cpu/kleidiai/kernels.h
CHANGED
|
@@ -40,7 +40,6 @@ struct lhs_packing_info {
|
|
| 40 |
size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
| 41 |
void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
| 42 |
size_t lhs_stride, void* lhs_packed);
|
| 43 |
-
bool require_aligned_m_idx;
|
| 44 |
};
|
| 45 |
|
| 46 |
struct rhs_packing_info {
|
|
|
|
| 40 |
size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
| 41 |
void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
| 42 |
size_t lhs_stride, void* lhs_packed);
|
|
|
|
| 43 |
};
|
| 44 |
|
| 45 |
struct rhs_packing_info {
|
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
CHANGED
|
@@ -124,8 +124,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
| 124 |
size_t sr = kernel->get_sr();
|
| 125 |
|
| 126 |
// Calculate number of columns to be processed per thread
|
| 127 |
-
const
|
| 128 |
-
const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m;
|
| 129 |
const size_t m_start = ith * num_m_per_thread;
|
| 130 |
size_t m_to_process = num_m_per_thread;
|
| 131 |
if ((m_start + m_to_process) > m) {
|
|
@@ -135,11 +134,11 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
| 135 |
if(m_start < m) {
|
| 136 |
// Transform LHS
|
| 137 |
const size_t src_stride = src1->nb[1];
|
| 138 |
-
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(
|
| 139 |
const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
|
| 140 |
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
| 141 |
|
| 142 |
-
lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr,
|
| 143 |
}
|
| 144 |
|
| 145 |
ggml_barrier(params->threadpool);
|
|
|
|
| 124 |
size_t sr = kernel->get_sr();
|
| 125 |
|
| 126 |
// Calculate number of columns to be processed per thread
|
| 127 |
+
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
|
|
|
|
| 128 |
const size_t m_start = ith * num_m_per_thread;
|
| 129 |
size_t m_to_process = num_m_per_thread;
|
| 130 |
if ((m_start + m_to_process) > m) {
|
|
|
|
| 134 |
if(m_start < m) {
|
| 135 |
// Transform LHS
|
| 136 |
const size_t src_stride = src1->nb[1];
|
| 137 |
+
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
| 138 |
const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
|
| 139 |
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
| 140 |
|
| 141 |
+
lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
| 142 |
}
|
| 143 |
|
| 144 |
ggml_barrier(params->threadpool);
|