Charles Xu committed
Commit 43ba97c · Parent: 0be0329

kleidiai: add support for get_rows (llama/14676)

* kleidiai: add support for get_rows

* apply fixes based on code review

* apply more fixes based on code review

ggml/src/ggml-cpu/CMakeLists.txt CHANGED
@@ -496,9 +496,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.9.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.11.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "2a8e1bb55d201557553545536489a017")
+        set(KLEIDIAI_ARCHIVE_MD5  "3fe9e5ab964c375c53839296eb71eaa2")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
ggml/src/ggml-cpu/kleidiai/kernels.cpp CHANGED
@@ -22,9 +22,94 @@
 
 #include "kai_common.h"
 
+#include "simd-mappings.h"
+
 #include "kernels.h"
 
 #define NELEMS(x) sizeof(x) / sizeof(*x)
+
+static const size_t INT4_PER_BYTE   = 2;
+static const size_t INT4_BITS       = 4;
+static const int    Q4_0_ZERO_POINT = 8;
+const size_t INT4_PER_UINT16        = 4;
+
+static void dequantize_row_qsi4c32pscalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t nc,
+    float *out,
+    size_t nr_pack,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    size_t group_idx = row_idx / nr_pack;
+    size_t row_in_group = row_idx % nr_pack;
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    size_t num_blocks = nc / bl;
+    const uint8_t *block_ptr = packed_group;
+
+    for (size_t b = 0; b < num_blocks; ++b) {
+        uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
+        size_t num_segments = bl / kr;
+        size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
+
+        for (size_t s = 0; s < num_segments; ++s) {
+            const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
+            const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
+            for (size_t k = 0; k < num_bytes_per_segment; ++k) {
+                uint8_t byte = qbytes[k] ^ 0x88;
+                int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
+                int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
+                out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
+                out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
+            }
+        }
+        block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
+    }
+}
+
+static void dequantize_row_qsi4c32ps1s0scalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t k,
+    float *out,
+    size_t nr,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    const size_t num_blocks = k / bl;
+    const size_t bl4 = bl / INT4_PER_UINT16;
+
+    size_t group_idx = row_idx / nr;
+    size_t row_in_group = row_idx % nr;
+
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    const uint16_t *qdata = (const uint16_t *)packed_group;
+    const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
+
+    for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
+        uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
+            uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
+
+            for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
+                int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
+                out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
+            }
+        }
+    }
+    GGML_UNUSED(kr);
+}
+
 static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #if defined(__ARM_FEATURE_SME)
     {
@@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .to_float      = */ dequantize_row_qsi4c32ps1s0scalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_SME,
         /* .lhs_type     = */ GGML_TYPE_F32,
@@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
-            /* .pack_func   = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .packed_stride = */ NULL,
+            /* .pack_func     = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .to_float      = */ NULL,
         },
         /* .required_cpu = */ CPU_FEATURE_SME,
         /* .lhs_type     = */ GGML_TYPE_F32,
@@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type     = */ GGML_TYPE_F32,
@@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type     = */ GGML_TYPE_F32,
@@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
        },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type     = */ GGML_TYPE_F32,
@@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type     = */ GGML_TYPE_F32,
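
Note on the nibble decode in dequantize_row_qsi4c32pscalef16: XOR-ing each packed byte with 0x88 flips the top bit of both nibbles, and subtracting Q4_0_ZERO_POINT afterwards is arithmetically the same as reading each nibble as a 4-bit two's-complement value. A minimal standalone C++ check of that equivalence (not part of the commit, no KleidiAI dependency):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Reference decode: read a 4-bit nibble as two's complement (-8..7).
    static int nibble_twos_complement(uint8_t n) {
        return n < 8 ? (int) n : (int) n - 16;
    }

    int main() {
        const int Q4_0_ZERO_POINT = 8;  // same constant the kernel subtracts
        for (int b = 0; b <= 0xFF; ++b) {
            const uint8_t byte = (uint8_t) b ^ 0x88;          // flip the top bit of both nibbles
            const int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;   // low nibble, as in the kernel
            const int x1 = (byte >> 4)   - Q4_0_ZERO_POINT;   // high nibble
            assert(x0 == nibble_twos_complement(b & 0x0F));
            assert(x1 == nibble_twos_complement((uint8_t) b >> 4));
        }
        printf("decode matches the two's-complement reading for all 256 byte values\n");
        return 0;
    }
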
ggml/src/ggml-cpu/kleidiai/kernels.h CHANGED
@@ -71,12 +71,15 @@ struct rhs_packing_info {
         std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
         std::function<size_t(size_t n, size_t k)>
     > packed_size;
+    size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
     std::variant<
         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
                            const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
                            const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
     > pack_func;
+    void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
+                     size_t kr, size_t bl, size_t num_bytes_multiplier);
 };
 
 struct ggml_kleidiai_kernels {
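
The two new members give callers a way to walk a packed RHS tensor row by row: packed_stride reports the byte stride of one packed row group, and to_float expands a single logical row back to f32. A rough usage sketch, assuming kernels.h is included, QK4_0 (the Q4_0 block size) is visible, and the entry is one of the Q4_0 kernels above; unpack_row_to_float is a hypothetical helper, not part of the commit:

    // Hypothetical helper built on the new rhs_packing_info members.
    static void unpack_row_to_float(const rhs_packing_info & info,
                                    const void * packed, int32_t row_idx, int64_t k,
                                    size_t nr, size_t kr, float * out) {
        if (!info.packed_stride || !info.to_float) {
            return; // e.g. the bf16 SME entry leaves both as NULL
        }
        const size_t row_stride = info.packed_stride(k, nr, kr, QK4_0);
        info.to_float(packed, row_idx, k, out, nr, row_stride, kr, QK4_0, sizeof(uint16_t));
    }
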
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp CHANGED
@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
     ggml_kleidiai_kernels * kernels;
 } static ctx = { CPU_FEATURE_NONE, NULL };
 
+static const char* cpu_feature_to_string(cpu_feature f) {
+    switch (f) {
+        case CPU_FEATURE_NONE:    return "NONE";
+        case CPU_FEATURE_DOTPROD: return "DOTPROD";
+        case CPU_FEATURE_I8MM:    return "I8MM";
+        case CPU_FEATURE_SVE:     return "SVE";
+        case CPU_FEATURE_SME:     return "SME";
+        default:                  return "UNKNOWN";
+    }
+}
+
 static void init_kleidiai_context(void) {
 
     ggml_critical_section_start();
@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
             ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
         ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+#ifndef NDEBUG
+        if (ctx.kernels) {
+            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+        }
+#endif
     }
     ggml_critical_section_end();
 }
@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
 
 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        if (op->op != GGML_OP_MUL_MAT) {
+            return false;
+        }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
         GGML_ASSERT(kernels);
         kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_kv_cache(params, dst);
             }
+        } else if (dst->op == GGML_OP_GET_ROWS) {
+            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+                return compute_forward_get_rows(params, dst);
+            }
         }
         return false;
     }
@@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }
 
     bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
@@ -342,8 +367,49 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }
 
+    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+        GGML_ASSERT(ctx.kernels);
+
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
+        kernel_info * kernel = &ctx.kernels->gemm;
+
+        const int64_t nc = ne00;
+        const int64_t nr = ggml_nelements(src1);
+
+        const size_t block_rows = kernel->get_nr();
+        const size_t kr = kernel->get_kr();
+
+        const size_t num_bytes_multiplier = sizeof(uint16_t);
+        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const int dr = (nr + nth - 1) / nth;
+        const int ir0 = dr * ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        for (int64_t i = ir0; i < ir1; ++i) {
+            GGML_ASSERT(src1->type == GGML_TYPE_I32);
+            int64_t row_idx = ((const int32_t *)src1->data)[i];
+            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
+
+            float *out = (float *)((char *)dst->data + i * nb1);
+            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+        }
+
+        return true;
+    }
+
 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
         GGML_ASSERT(ctx.kernels);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
@@ -351,17 +417,12 @@ public:
         size_t kr = ctx.kernels->gemm.get_kr();
         size_t sr = ctx.kernels->gemm.get_sr();
 
-#ifndef NDEBUG
-        const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
-        GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
-#endif
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
         variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
 
         return 0;
-
         GGML_UNUSED(data_size);
     }
 };
@@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
 
-    GGML_UNUSED(buffer);
     return GGML_STATUS_SUCCESS;
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
@@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
     GGML_UNUSED(buft);
 }
 
+static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(ctx.kernels);
+
+    const size_t n  = tensor->ne[1];
+    const size_t k  = tensor->ne[0];
+    const size_t nr = ctx.kernels->gemm.get_nr();
+    const size_t kr = ctx.kernels->gemm.get_kr();
+
+    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
+
+    GGML_UNUSED(buft);
+}
+
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT &&
+        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
             op->src[0]->type == GGML_TYPE_Q4_0 &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
             op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+            if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
+                return false;
+            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
-            if (op->src[1]->type == GGML_TYPE_F32 &&
+            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                 ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                 return true;
             }
@@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT) {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
            }
@@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
             /* .alloc_buffer   = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
             /* .get_alignment  = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
             /* .get_max_size   = */ nullptr, // defaults to SIZE_MAX
-            /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
+            /* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
             /* .is_host        = */ nullptr,
         },
         /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
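
compute_forward_get_rows divides the requested row indices across threads with a ceiling division, so each worker dequantizes a contiguous slice and the last slice absorbs any remainder. A self-contained sketch of just that partitioning arithmetic (variable names mirror the committed code; nr = 8 here matches the ne[0] == 8 guard in supports_op, nth = 3 is an arbitrary example):

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int nr  = 8;  // number of rows requested by GET_ROWS
        const int nth = 3;  // number of worker threads

        const int dr = (nr + nth - 1) / nth;  // rows per thread, rounded up
        for (int ith = 0; ith < nth; ++ith) {
            const int ir0 = dr * ith;
            const int ir1 = std::min(ir0 + dr, nr);
            printf("thread %d dequantizes rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0;
    }
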