Dan Johansson committed on
Commit
9b4460a
·
1 Parent(s): 33f8316

ggml-cpu : update KleidiAI to v1.5.0 (llama/12568)

Browse files

ggml-cpu : bug fix related to KleidiAI LHS packing

Signed-off-by: Dan Johansson <[email protected]>

ggml/src/ggml-cpu/CMakeLists.txt CHANGED
@@ -357,9 +357,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
357
 
358
  # Fetch KleidiAI sources:
359
  include(FetchContent)
360
- set(KLEIDIAI_COMMIT_TAG "v1.3.0")
361
  set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
362
- set(KLEIDIAI_ARCHIVE_MD5 "060bd2dc64642b091f461cc8dd7426d9")
363
 
364
  if (POLICY CMP0135)
365
  cmake_policy(SET CMP0135 NEW)
 
357
 
358
  # Fetch KleidiAI sources:
359
  include(FetchContent)
360
+ set(KLEIDIAI_COMMIT_TAG "v1.5.0")
361
  set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
362
+ set(KLEIDIAI_ARCHIVE_MD5 "ea22e1aefb800e9bc8c74d91633cc58e")
363
 
364
  if (POLICY CMP0135)
365
  cmake_policy(SET CMP0135 NEW)
ggml/src/ggml-cpu/kleidiai/kernels.cpp CHANGED
@@ -51,11 +51,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
51
  /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
52
  },
53
  /* .lhs_info = */ {
54
- /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
55
- /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
56
  /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
57
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
58
- /* .require_aligned_m_idx = */ true,
59
  },
60
  /* .rhs_info = */ {
61
  /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
@@ -100,7 +99,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
100
  /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
101
  /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
102
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
103
- /* .require_aligned_m_idx = */ false,
104
  },
105
  /* .rhs_info = */ {
106
  /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
@@ -144,7 +142,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
144
  /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
145
  /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
146
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
147
- /* .require_aligned_m_idx = */ false,
148
  },
149
  /* .rhs_info = */ {
150
  /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
@@ -189,7 +186,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
189
  /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
190
  /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
191
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
192
- /* .require_aligned_m_idx = */ false,
193
  },
194
  /* .rhs_info = */ {
195
  /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
@@ -233,7 +229,6 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
233
  /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
234
  /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
235
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
236
- /* .require_aligned_m_idx = */ false,
237
  },
238
  /* .rhs_info = */ {
239
  /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
 
51
  /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
52
  },
53
  /* .lhs_info = */ {
54
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
55
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
56
  /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
57
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
 
58
  },
59
  /* .rhs_info = */ {
60
  /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
 
99
  /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
100
  /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
101
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
 
102
  },
103
  /* .rhs_info = */ {
104
  /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
 
142
  /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
143
  /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
144
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
 
145
  },
146
  /* .rhs_info = */ {
147
  /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
 
186
  /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
187
  /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
188
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
 
189
  },
190
  /* .rhs_info = */ {
191
  /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
 
229
  /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
230
  /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
231
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
 
232
  },
233
  /* .rhs_info = */ {
234
  /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
ggml/src/ggml-cpu/kleidiai/kernels.h CHANGED
@@ -40,7 +40,6 @@ struct lhs_packing_info {
40
  size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
41
  void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
42
  size_t lhs_stride, void* lhs_packed);
43
- bool require_aligned_m_idx;
44
  };
45
 
46
  struct rhs_packing_info {
 
40
  size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
41
  void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
42
  size_t lhs_stride, void* lhs_packed);
 
43
  };
44
 
45
  struct rhs_packing_info {
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp CHANGED
@@ -124,8 +124,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
124
  size_t sr = kernel->get_sr();
125
 
126
  // Calculate number of columns to be processed per thread
127
- const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true;
128
- const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m;
129
  const size_t m_start = ith * num_m_per_thread;
130
  size_t m_to_process = num_m_per_thread;
131
  if ((m_start + m_to_process) > m) {
@@ -135,11 +134,11 @@ class tensor_traits : public ggml::cpu::tensor_traits {
135
  if(m_start < m) {
136
  // Transform LHS
137
  const size_t src_stride = src1->nb[1];
138
- const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
139
  const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
140
  void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
141
 
142
- lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
143
  }
144
 
145
  ggml_barrier(params->threadpool);
 
124
  size_t sr = kernel->get_sr();
125
 
126
  // Calculate number of columns to be processed per thread
127
+ const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
 
128
  const size_t m_start = ith * num_m_per_thread;
129
  size_t m_to_process = num_m_per_thread;
130
  if ((m_start + m_to_process) > m) {
 
134
  if(m_start < m) {
135
  // Transform LHS
136
  const size_t src_stride = src1->nb[1];
137
+ const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
138
  const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
139
  void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
140
 
141
+ lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
142
  }
143
 
144
  ggml_barrier(params->threadpool);