Commit 43ba97c · committed by Charles Xu
1 parent: 0be0329
kleidiai: add support for get_rows (llama/14676)

* kleidiai: add support for get_rows
* apply fixes based on code review
* apply more fixes based on code review
ggml/src/ggml-cpu/CMakeLists.txt
CHANGED

@@ -496,9 +496,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.
+        set(KLEIDIAI_COMMIT_TAG "v1.11.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5 "
+        set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
ggml/src/ggml-cpu/kleidiai/kernels.cpp
CHANGED

@@ -22,9 +22,94 @@
 
 #include "kai_common.h"
 
+#include "simd-mappings.h"
+
 #include "kernels.h"
 
 #define NELEMS(x) sizeof(x) / sizeof(*x)
+
+static const size_t INT4_PER_BYTE = 2;
+static const size_t INT4_BITS = 4;
+static const int Q4_0_ZERO_POINT = 8;
+const size_t INT4_PER_UINT16 = 4;
+
+static void dequantize_row_qsi4c32pscalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t nc,
+    float *out,
+    size_t nr_pack,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    size_t group_idx = row_idx / nr_pack;
+    size_t row_in_group = row_idx % nr_pack;
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    size_t num_blocks = nc / bl;
+    const uint8_t *block_ptr = packed_group;
+
+    for (size_t b = 0; b < num_blocks; ++b) {
+        uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
+        size_t num_segments = bl / kr;
+        size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
+
+        for (size_t s = 0; s < num_segments; ++s) {
+            const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
+            const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
+            for (size_t k = 0; k < num_bytes_per_segment; ++k) {
+                uint8_t byte = qbytes[k] ^ 0x88;
+                int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
+                int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
+                out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
+                out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
+            }
+        }
+        block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
+    }
+}
+
+static void dequantize_row_qsi4c32ps1s0scalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t k,
+    float *out,
+    size_t nr,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    const size_t num_blocks = k / bl;
+    const size_t bl4 = bl / INT4_PER_UINT16;
+
+    size_t group_idx = row_idx / nr;
+    size_t row_in_group = row_idx % nr;
+
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    const uint16_t *qdata = (const uint16_t *)packed_group;
+    const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
+
+    for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
+        uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
+            uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
+
+            for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
+                int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
+                out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
+            }
+        }
+    }
+    GGML_UNUSED(kr);
+}
+
 static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #if defined(__ARM_FEATURE_SME)
 {

@@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
     },
     /* .rhs_info = */ {
-        /* .packed_size
-        /* .
+        /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+        /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+        /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+        /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
     },
     /* .required_cpu = */ CPU_FEATURE_SME,
     /* .lhs_type = */ GGML_TYPE_F32,

@@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
     },
     /* .rhs_info = */ {
-        /* .packed_size
-        /* .
+        /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+        /* .packed_stride = */ NULL,
+        /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+        /* .to_float = */ NULL,
     },
     /* .required_cpu = */ CPU_FEATURE_SME,
     /* .lhs_type = */ GGML_TYPE_F32,

@@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
     },
     /* .rhs_info = */ {
-        /* .packed_size
-        /* .
+        /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .to_float = */ dequantize_row_qsi4c32pscalef16,
     },
     /* .required_cpu = */ CPU_FEATURE_DOTPROD,
     /* .lhs_type = */ GGML_TYPE_F32,

@@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
     },
     /* .rhs_info = */ {
-        /* .packed_size
-        /* .
+        /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .to_float = */ dequantize_row_qsi4c32pscalef16,
     },
     /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
     /* .lhs_type = */ GGML_TYPE_F32,

@@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
     },
     /* .rhs_info = */ {
-        /* .packed_size
-        /* .
+        /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .to_float = */ dequantize_row_qsi4c32pscalef16,
     },
     /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
     /* .lhs_type = */ GGML_TYPE_F32,

@@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
     },
     /* .rhs_info = */ {
-        /* .packed_size
-        /* .
+        /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        /* .to_float = */ dequantize_row_qsi4c32pscalef16,
     },
     /* .required_cpu = */ CPU_FEATURE_DOTPROD,
     /* .lhs_type = */ GGML_TYPE_F32,
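For readers unfamiliar with the packed layout, here is a minimal standalone sketch (not part of the commit) of the per-nibble arithmetic both new dequantize helpers rely on: each byte carries two 4-bit quants and Q4_0 uses a zero point of 8, so each nibble becomes (nibble - 8) * scale. The input bytes and scale below are made up; dequantize_row_qsi4c32pscalef16 above additionally XORs each byte with 0x88 before splitting, to match how the KleidiAI pack routine stores the nibbles.

#include <cstdint>
#include <cstdio>

int main() {
    const int     Q4_0_ZERO_POINT = 8;
    const float   scale     = 0.5f;            // hypothetical per-block scale (f16 in the real layout)
    const uint8_t packed[2] = { 0x3A, 0xF0 };  // hypothetical packed bytes, two 4-bit values each

    for (uint8_t byte : packed) {
        int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;  // low nibble
        int x1 = (byte >> 4)   - Q4_0_ZERO_POINT;  // high nibble
        printf("%6.2f %6.2f\n", x0 * scale, x1 * scale);
    }
    return 0;
}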
ggml/src/ggml-cpu/kleidiai/kernels.h
CHANGED

@@ -71,12 +71,15 @@ struct rhs_packing_info {
         std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
         std::function<size_t(size_t n, size_t k)>
     > packed_size;
+    size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
     std::variant<
         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
                            const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
                            const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
     > pack_func;
+    void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
+                     size_t kr, size_t bl, size_t num_bytes_multiplier);
 };
 
 struct ggml_kleidiai_kernels {
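The two new rhs_packing_info members are plain function pointers rather than std::function/std::variant like the existing fields, presumably because each has a single fixed signature. A sketch of how a caller drives them, mirroring the new compute_forward_get_rows in kleidiai.cpp below (the helper name is hypothetical and the required headers are assumed to be included):

// Hypothetical helper, not part of the commit: dequantize one row of a
// repacked Q4_0 tensor back to F32 via the new rhs_packing_info members.
static void dequantize_one_row(const ggml_kleidiai_kernels & kernels,
                               const void * packed, int32_t row_idx,
                               int64_t nc, float * out) {
    const size_t nr     = kernels.gemm.get_nr();   // rows per packed group
    const size_t kr     = kernels.gemm.get_kr();
    const size_t stride = kernels.rhs_info.packed_stride(nc, nr, kr, QK4_0);
    kernels.rhs_info.to_float(packed, row_idx, nc, out, nr, stride, kr,
                              QK4_0, sizeof(uint16_t));  // f16 scale per block
}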
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
CHANGED

@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
     ggml_kleidiai_kernels * kernels;
 } static ctx = { CPU_FEATURE_NONE, NULL };
 
+static const char* cpu_feature_to_string(cpu_feature f) {
+    switch (f) {
+        case CPU_FEATURE_NONE:    return "NONE";
+        case CPU_FEATURE_DOTPROD: return "DOTPROD";
+        case CPU_FEATURE_I8MM:    return "I8MM";
+        case CPU_FEATURE_SVE:     return "SVE";
+        case CPU_FEATURE_SME:     return "SME";
+        default:                  return "UNKNOWN";
+    }
+}
+
 static void init_kleidiai_context(void) {
 
     ggml_critical_section_start();

@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
             ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
         ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+#ifndef NDEBUG
+        if (ctx.kernels) {
+            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+        }
+#endif
     }
     ggml_critical_section_end();
 }

@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
 
 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        if (op->op != GGML_OP_MUL_MAT) {
+            return false;
+        }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
         GGML_ASSERT(kernels);
         kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;

@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_kv_cache(params, dst);
             }
+        } else if (dst->op == GGML_OP_GET_ROWS) {
+            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+                return compute_forward_get_rows(params, dst);
+            }
         }
         return false;
     }

@@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }
 
     bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 

@@ -342,8 +367,49 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }
 
+    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+        GGML_ASSERT(ctx.kernels);
+
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
+        kernel_info * kernel = &ctx.kernels->gemm;
+
+        const int64_t nc = ne00;
+        const int64_t nr = ggml_nelements(src1);
+
+        const size_t block_rows = kernel->get_nr();
+        const size_t kr = kernel->get_kr();
+
+        const size_t num_bytes_multiplier = sizeof(uint16_t);
+        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const int dr = (nr + nth - 1) / nth;
+        const int ir0 = dr * ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        for (int64_t i = ir0; i < ir1; ++i) {
+            GGML_ASSERT(src1->type == GGML_TYPE_I32);
+            int64_t row_idx = ((const int32_t *)src1->data)[i];
+            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
+
+            float *out = (float *)((char *)dst->data + i * nb1);
+            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+        }
+
+        return true;
+    }
+
 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
         GGML_ASSERT(ctx.kernels);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];

@@ -351,17 +417,12 @@ public:
         size_t kr = ctx.kernels->gemm.get_kr();
         size_t sr = ctx.kernels->gemm.get_sr();
 
-#ifndef NDEBUG
-        const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
-        GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
-#endif
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
         variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
 
         return 0;
-
         GGML_UNUSED(data_size);
     }
 };

@@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
 
-    GGML_UNUSED(buffer);
     return GGML_STATUS_SUCCESS;
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,

@@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
     GGML_UNUSED(buft);
 }
 
+static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(ctx.kernels);
+
+    const size_t n = tensor->ne[1];
+    const size_t k = tensor->ne[0];
+    const size_t nr = ctx.kernels->gemm.get_nr();
+    const size_t kr = ctx.kernels->gemm.get_kr();
+
+    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
+
+    GGML_UNUSED(buft);
+}
+
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT &&
+        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
             op->src[0]->type == GGML_TYPE_Q4_0 &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
             op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+            if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
+                return false;
+            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
-            if (op->src[1]->type == GGML_TYPE_F32 &&
+            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                 ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                 return true;
             }

@@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT) {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }

@@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
             /* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
             /* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
             /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
-            /* .get_alloc_size = */
+            /* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
             /* .is_host = */ nullptr,
         },
         /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),