Charles Xu committed on
Commit
9de6d81
·
1 Parent(s): 1a1acd2

ggml-cpu: Add CPU backend support for KleidiAI library (llama/11390)

Browse files

* ggml-cpu: Add CPU backend support for KleidiAI library

* Add environment variable GGML_KLEIDIAI_SME

* Add support for multithread LHS conversion

* Switch kernel selection order to dotprod and i8mm

* updates for review comments

* More updates for review comments

* Reorganize and rename KleidiAI files

* Move ggml-cpu-traits.h to source file

* Update cmake for SME build and add alignment for SME

* Remove append GGML_USE_CPU_KLEIDIAI to the GGML_CDEF_PUBLIC list

ggml/CMakeLists.txt CHANGED
@@ -102,6 +102,7 @@ endif()
102
 
103
  option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
104
  option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
 
105
  option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
106
  option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
107
  option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
 
102
 
103
  option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
104
  option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
105
+ option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
106
  option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
107
  option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
108
  option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
ggml/include/ggml-cpu.h CHANGED
@@ -95,6 +95,7 @@ extern "C" {
95
  GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
96
  GGML_BACKEND_API int ggml_cpu_has_sve (void);
97
  GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
 
98
  // other
99
  GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
100
  GGML_BACKEND_API int ggml_cpu_has_vsx (void);
 
95
  GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
96
  GGML_BACKEND_API int ggml_cpu_has_sve (void);
97
  GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
98
+ GGML_BACKEND_API int ggml_cpu_has_sme (void);
99
  // other
100
  GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
101
  GGML_BACKEND_API int ggml_cpu_has_vsx (void);
ggml/src/ggml-cpu/CMakeLists.txt CHANGED
@@ -111,14 +111,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
111
  function(check_arm_feature tag code)
112
  set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
113
  set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
114
- check_cxx_source_runs(
115
- "${code}"
116
- GGML_MACHINE_SUPPORTS_${tag}
117
- )
118
  if (GGML_MACHINE_SUPPORTS_${tag})
119
  set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
120
  else()
121
- set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
 
 
 
 
122
  endif()
123
  set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
124
  endfunction()
@@ -126,6 +127,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
126
  check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
127
  check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
128
  check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
 
129
 
130
  list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
131
  else()
@@ -150,7 +152,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
150
  if (ARM_FEATURE_RESULT)
151
  message(WARNING "Failed to get ARM features")
152
  else()
153
- foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC)
154
  string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
155
  if (NOT ${feature_pos} EQUAL -1)
156
  message(STATUS "ARM feature ${feature} enabled")
@@ -312,6 +314,94 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
312
  target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64)
313
  endif()
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}")
316
  target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES})
317
  target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})
 
111
  function(check_arm_feature tag code)
112
  set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
113
  set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
114
+ check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
 
 
 
115
  if (GGML_MACHINE_SUPPORTS_${tag})
116
  set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
117
  else()
118
+ set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
119
+ check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
120
+ if (GGML_MACHINE_SUPPORTS_no${tag})
121
+ set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
122
+ endif()
123
  endif()
124
  set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
125
  endfunction()
 
127
  check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
128
  check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
129
  check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
130
+ check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
131
 
132
  list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
133
  else()
 
152
  if (ARM_FEATURE_RESULT)
153
  message(WARNING "Failed to get ARM features")
154
  else()
155
+ foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
156
  string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
157
  if (NOT ${feature_pos} EQUAL -1)
158
  message(STATUS "ARM feature ${feature} enabled")
 
314
  target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64)
315
  endif()
316
 
317
+ if (GGML_CPU_KLEIDIAI)
318
+ message(STATUS "Using KleidiAI optimized kernels if applicable")
319
+
320
+ # Disable the KleidiAI tests
321
+ set(KLEIDIAI_BUILD_TESTS OFF)
322
+
323
+ # Fetch KleidiAI sources:
324
+ include(FetchContent)
325
+ set(KLEIDIAI_COMMIT_TAG "v1.3.0")
326
+ set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
327
+ set(KLEIDIAI_ARCHIVE_MD5 "060bd2dc64642b091f461cc8dd7426d9")
328
+
329
+ if (POLICY CMP0135)
330
+ cmake_policy(SET CMP0135 NEW)
331
+ endif()
332
+
333
+ FetchContent_Declare(KleidiAI_Download
334
+ URL ${KLEIDIAI_DOWNLOAD_URL}
335
+ DOWNLOAD_EXTRACT_TIMESTAMP NEW
336
+ URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
337
+
338
+ FetchContent_MakeAvailable(KleidiAI_Download)
339
+ FetchContent_GetProperties(KleidiAI_Download
340
+ SOURCE_DIR KLEIDIAI_SRC
341
+ POPULATED KLEIDIAI_POPULATED)
342
+
343
+ if (NOT KLEIDIAI_POPULATED)
344
+ message(FATAL_ERROR "KleidiAI source downloaded failed.")
345
+ endif()
346
+
347
+ add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
348
+
349
+ # Remove kleidiai target after fetching it
350
+ if (TARGET kleidiai)
351
+ set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
352
+ endif()
353
+
354
+ list(APPEND GGML_CPU_SOURCES
355
+ ggml-cpu/kleidiai/kleidiai.cpp
356
+ ggml-cpu/kleidiai/kernels.cpp
357
+ ggml-cpu/kleidiai/kleidiai.h
358
+ ggml-cpu/kleidiai/kernels.h
359
+ )
360
+
361
+ # KleidiAI
362
+ include_directories(
363
+ ${KLEIDIAI_SRC}/
364
+ ${KLEIDIAI_SRC}/kai/
365
+ ${KLEIDIAI_SRC}/kai/ukernels/
366
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/
367
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
368
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
369
+
370
+ set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
371
+ if (NOT ARCH_FLAGS_TEMP)
372
+ string(REGEX MATCH "-march=[^ ]+" ARCH_FLAGS_TEMP "${CMAKE_C_FLAGS}")
373
+ endif()
374
+ string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
375
+ string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
376
+ string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
377
+
378
+ set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
379
+
380
+ list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
381
+ list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
382
+ list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
383
+ list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
384
+
385
+ if (NOT DOTPROD_ENABLED MATCHES -1)
386
+ list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
387
+ list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
388
+ list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
389
+ endif()
390
+
391
+ if (NOT I8MM_ENABLED MATCHES -1)
392
+ list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
393
+ endif()
394
+
395
+ if (NOT SME_ENABLED MATCHES -1)
396
+ list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
397
+ list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
398
+ set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
399
+ endif()
400
+
401
+ set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
402
+ list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
403
+ endif()
404
+
405
  message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}")
406
  target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES})
407
  target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -112,7 +112,8 @@ struct ggml_arm_arch_features_type {
112
  int has_i8mm;
113
  int has_sve;
114
  int sve_cnt;
115
- } ggml_arm_arch_features = {-1, -1, -1, -1, 0};
 
116
  #endif
117
 
118
 
@@ -2381,15 +2382,20 @@ bool ggml_is_numa(void) {
2381
  #define HWCAP2_I8MM (1 << 13)
2382
  #endif
2383
 
 
 
 
 
2384
  static void ggml_init_arm_arch_features(void) {
2385
  #if defined(__linux__) && defined(__aarch64__)
2386
  uint32_t hwcap = getauxval(AT_HWCAP);
2387
  uint32_t hwcap2 = getauxval(AT_HWCAP2);
2388
 
2389
- ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
2390
  ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
2391
- ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
2392
- ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
 
2393
 
2394
  #if defined(__ARM_FEATURE_SVE)
2395
  ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
@@ -2412,6 +2418,11 @@ static void ggml_init_arm_arch_features(void) {
2412
  }
2413
  ggml_arm_arch_features.has_i8mm = oldp;
2414
 
 
 
 
 
 
2415
  ggml_arm_arch_features.has_sve = 0;
2416
  ggml_arm_arch_features.sve_cnt = 0;
2417
  #else
@@ -2435,6 +2446,12 @@ static void ggml_init_arm_arch_features(void) {
2435
  ggml_arm_arch_features.has_sve = 0;
2436
  ggml_arm_arch_features.sve_cnt = 0;
2437
  #endif
 
 
 
 
 
 
2438
  #endif
2439
  }
2440
  #endif
@@ -14442,6 +14459,14 @@ int ggml_cpu_get_sve_cnt(void) {
14442
  #endif
14443
  }
14444
 
 
 
 
 
 
 
 
 
14445
  void ggml_cpu_init(void) {
14446
  // needed to initialize f16 tables
14447
  {
 
112
  int has_i8mm;
113
  int has_sve;
114
  int sve_cnt;
115
+ int has_sme;
116
+ } ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
117
  #endif
118
 
119
 
 
2382
  #define HWCAP2_I8MM (1 << 13)
2383
  #endif
2384
 
2385
+ #if !defined(HWCAP2_SME)
2386
+ #define HWCAP2_SME (1 << 23)
2387
+ #endif
2388
+
2389
  static void ggml_init_arm_arch_features(void) {
2390
  #if defined(__linux__) && defined(__aarch64__)
2391
  uint32_t hwcap = getauxval(AT_HWCAP);
2392
  uint32_t hwcap2 = getauxval(AT_HWCAP2);
2393
 
2394
+ ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
2395
  ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
2396
+ ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
2397
+ ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
2398
+ ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME);
2399
 
2400
  #if defined(__ARM_FEATURE_SVE)
2401
  ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
 
2418
  }
2419
  ggml_arm_arch_features.has_i8mm = oldp;
2420
 
2421
+ if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
2422
+ oldp = 0;
2423
+ }
2424
+ ggml_arm_arch_features.has_sme = oldp;
2425
+
2426
  ggml_arm_arch_features.has_sve = 0;
2427
  ggml_arm_arch_features.sve_cnt = 0;
2428
  #else
 
2446
  ggml_arm_arch_features.has_sve = 0;
2447
  ggml_arm_arch_features.sve_cnt = 0;
2448
  #endif
2449
+
2450
+ #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
2451
+ ggml_arm_arch_features.has_sme = 1;
2452
+ #else
2453
+ ggml_arm_arch_features.has_sme = 0;
2454
+ #endif
2455
  #endif
2456
  }
2457
  #endif
 
14459
  #endif
14460
  }
14461
 
14462
+ int ggml_cpu_has_sme(void) {
14463
+ #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
14464
+ return ggml_arm_arch_features.has_sme;
14465
+ #else
14466
+ return 0;
14467
+ #endif
14468
+ }
14469
+
14470
  void ggml_cpu_init(void) {
14471
  // needed to initialize f16 tables
14472
  {
ggml/src/ggml-cpu/ggml-cpu.cpp CHANGED
@@ -14,6 +14,10 @@
14
  #include "ggml-cpu-hbm.h"
15
  #endif
16
 
 
 
 
 
17
  #if defined(__APPLE__)
18
  #include <sys/types.h>
19
  #include <sys/sysctl.h>
@@ -39,6 +43,12 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
39
  }
40
  #endif
41
 
 
 
 
 
 
 
42
  #ifdef GGML_USE_CPU_AARCH64
43
  if (ggml_backend_cpu_aarch64_buffer_type()) {
44
  bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
@@ -538,6 +548,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
538
  static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
539
  features.push_back({ "SVE_CNT", sve_cnt.c_str() });
540
  }
 
 
 
541
  if (ggml_cpu_has_riscv_v()) {
542
  features.push_back({ "RISCV_V", "1" });
543
  }
@@ -559,6 +572,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
559
  #ifdef GGML_USE_OPENMP
560
  features.push_back({ "OPENMP", "1" });
561
  #endif
 
 
 
562
  #ifdef GGML_USE_CPU_AARCH64
563
  features.push_back({ "AARCH64_REPACK", "1" });
564
  #endif
 
14
  #include "ggml-cpu-hbm.h"
15
  #endif
16
 
17
+ #ifdef GGML_USE_CPU_KLEIDIAI
18
+ #include "kleidiai/kleidiai.h"
19
+ #endif
20
+
21
  #if defined(__APPLE__)
22
  #include <sys/types.h>
23
  #include <sys/sysctl.h>
 
43
  }
44
  #endif
45
 
46
+ #ifdef GGML_USE_CPU_KLEIDIAI
47
+ if (ggml_backend_cpu_kleidiai_buffer_type()) {
48
+ bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
49
+ }
50
+ #endif
51
+
52
  #ifdef GGML_USE_CPU_AARCH64
53
  if (ggml_backend_cpu_aarch64_buffer_type()) {
54
  bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
 
548
  static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
549
  features.push_back({ "SVE_CNT", sve_cnt.c_str() });
550
  }
551
+ if (ggml_cpu_has_sme()) {
552
+ features.push_back({ "SME", "1" });
553
+ }
554
  if (ggml_cpu_has_riscv_v()) {
555
  features.push_back({ "RISCV_V", "1" });
556
  }
 
572
  #ifdef GGML_USE_OPENMP
573
  features.push_back({ "OPENMP", "1" });
574
  #endif
575
+ #ifdef GGML_USE_CPU_KLEIDIAI
576
+ features.push_back({ "KLEIDIAI", "1" });
577
+ #endif
578
  #ifdef GGML_USE_CPU_AARCH64
579
  features.push_back({ "AARCH64_REPACK", "1" });
580
  #endif
ggml/src/ggml-cpu/kleidiai/kernels.cpp ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
2
+ // SPDX-License-Identifier: MIT
3
+ //
4
+
5
+ // KleidiAI micro-kernels
6
+ #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
7
+ #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
8
+ #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
9
+ #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
10
+ #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
11
+ #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
12
+ #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
13
+ #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
14
+ #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
15
+ #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
16
+ #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
17
+ #include "kai_common.h"
18
+
19
+ #include "kernels.h"
20
+
21
+ #define NELEMS(x) sizeof(x) / sizeof(*x)
22
+ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
23
+ #if defined(__ARM_FEATURE_SME)
24
+ {
25
+ /* SME GEMM */
26
+ /* .kern_info = */ {
27
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
28
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
29
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
30
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
31
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
32
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
33
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
34
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
35
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
36
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
37
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
38
+ },
39
+ /* SME GEMV */
40
+ /* .kern_info = */ {
41
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
42
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
43
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
44
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
45
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
46
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
47
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
48
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
49
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
50
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
51
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
52
+ },
53
+ /* .lhs_info = */ {
54
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
55
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
56
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
57
+ /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
58
+ /* .require_aligned_m_idx = */ true,
59
+ },
60
+ /* .rhs_info = */ {
61
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
62
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
63
+ },
64
+ /* .required_cpu = */ CPU_FEATURE_SME,
65
+ },
66
+ #endif
67
+ #if defined(__APPLE__)
68
+ #if defined(__ARM_FEATURE_DOTPROD)
69
+ {
70
+ /* DOTPROD GEMM */
71
+ /* .kern_info = */ {
72
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
73
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
74
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
75
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
76
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
77
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
78
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
79
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
80
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
81
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
82
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
83
+ },
84
+ /* DOTPROD GEMV */
85
+ /* .kern_info = */ {
86
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
87
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
88
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
89
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
90
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
91
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
92
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
93
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
94
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
95
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
96
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
97
+ },
98
+ /* .lhs_info = */ {
99
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
100
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
101
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
102
+ /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
103
+ /* .require_aligned_m_idx = */ false,
104
+ },
105
+ /* .rhs_info = */ {
106
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
107
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
108
+ },
109
+ /* .required_cpu = */ CPU_FEATURE_DOTPROD,
110
+ },
111
+ #endif
112
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
113
+ {
114
+ /* i8mm GEMM */
115
+ /* .kern_info = */ {
116
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
117
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
118
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
119
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
120
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
121
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
122
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
123
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
124
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
125
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
126
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
127
+ },
128
+ /* i8mm GEMV */
129
+ /* .kern_info = */ {
130
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
131
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
132
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
133
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
134
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
135
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
136
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
137
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
138
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
139
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
140
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
141
+ },
142
+ /* .lhs_info = */ {
143
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
144
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
145
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
146
+ /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
147
+ /* .require_aligned_m_idx = */ false,
148
+ },
149
+ /* .rhs_info = */ {
150
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
151
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
152
+ },
153
+ /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
154
+ },
155
+ #endif
156
+ #else
157
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
158
+ {
159
+ /* i8mm GEMM */
160
+ /* .kern_info = */ {
161
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
162
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
163
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
164
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
165
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
166
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
167
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
168
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
169
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
170
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
171
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
172
+ },
173
+ /* i8mm GEMV */
174
+ /* .kern_info = */ {
175
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
176
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
177
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
178
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
179
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
180
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
181
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
182
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
183
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
184
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
185
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
186
+ },
187
+ /* .lhs_info = */ {
188
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
189
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
190
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
191
+ /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
192
+ /* .require_aligned_m_idx = */ false,
193
+ },
194
+ /* .rhs_info = */ {
195
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
196
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
197
+ },
198
+ /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
199
+ },
200
+ #endif
201
+ #if defined(__ARM_FEATURE_DOTPROD)
202
+ {
203
+ /* DOTPROD GEMM */
204
+ /* .kern_info = */ {
205
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
206
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
207
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
208
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
209
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
210
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
211
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
212
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
213
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
214
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
215
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
216
+ },
217
+ /* DOTPROD GEMV */
218
+ /* .kern_info = */ {
219
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
220
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
221
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
222
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
223
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
224
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
225
+ /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
226
+ /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
227
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
228
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
229
+ /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
230
+ },
231
+ /* .lhs_info = */ {
232
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
233
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
234
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
235
+ /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
236
+ /* .require_aligned_m_idx = */ false,
237
+ },
238
+ /* .rhs_info = */ {
239
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
240
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
241
+ },
242
+ /* .required_cpu = */ CPU_FEATURE_DOTPROD,
243
+ },
244
+ #endif
245
+ #endif
246
+ };
247
+
248
+ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
249
+ ggml_kleidiai_kernels * kernels = nullptr;
250
+
251
+ for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
252
+ if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
253
+ kernels = &gemm_gemv_kernels[i];
254
+ break;
255
+ }
256
+ }
257
+
258
+ return kernels;
259
+ }
ggml/src/ggml-cpu/kleidiai/kernels.h ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
2
+ // SPDX-License-Identifier: MIT
3
+ //
4
+
5
+ #pragma once
6
+
7
+ // Bit-flag set of Arm CPU features relevant to KleidiAI kernel selection.
+ // Values are single bits so they can be OR-combined into a feature mask.
+ enum cpu_feature {
8
+ CPU_FEATURE_NONE = 0,
9
+ CPU_FEATURE_DOTPROD = 1,
10
+ CPU_FEATURE_I8MM = 2,
11
+ CPU_FEATURE_SVE = 4,
12
+ CPU_FEATURE_SME = 8
13
+ };
14
+ // OR-assign for the enum; plain enums don't get compound bitwise ops for free.
+ inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
15
+ lhs = static_cast<cpu_feature>(lhs | rhs);
16
+ return lhs;
17
+ }
18
+ // Bitwise OR producing a combined feature mask (cast through int to combine bits).
+ inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) {
19
+ return static_cast<cpu_feature>(static_cast<int>(lhs) | static_cast<int>(rhs));
20
+ }
21
+
22
+ // Function-pointer table describing one KleidiAI matmul micro-kernel
+ // (either a GEMM or a GEMV variant). The pointers are filled from the
+ // kai_* entry points of the corresponding kernel; `bl` is the quantization
+ // block length passed through to the kai_* functions (QK4_0 at call sites).
+ struct kernel_info {
24
+ size_t (*get_m_step)(void);
25
+ size_t (*get_n_step)(void);
26
+ size_t (*get_mr)(void);
27
+ size_t (*get_nr)(void);
28
+ size_t (*get_kr)(void);
29
+ size_t (*get_sr)(void);
30
+ // Byte offsets into the packed LHS / packed RHS / destination buffers.
+ size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
31
+ size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
32
+ size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
33
+ size_t (*get_dst_size)(size_t m, size_t n);
34
+ // Runs the micro-kernel on packed operands, clamping results to [scalar_min, scalar_max].
+ void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
35
+ float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
36
+ };
36
+
37
+ // Function-pointer table for quantizing/packing the f32 LHS (activations)
+ // into the layout the selected micro-kernel expects.
+ struct lhs_packing_info {
38
+ // Byte offset of row m_idx in the unpacked f32 LHS.
+ size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
39
+ // Byte offset of row m_idx in the packed LHS buffer.
+ size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
40
+ // Total packed-LHS buffer size for an m x k input.
+ size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
41
+ void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
42
+ size_t lhs_stride, void* lhs_packed);
43
+ // When true, packing must start at an aligned m index; used by compute_forward
+ // to decide whether the LHS conversion may be split across threads.
+ bool require_aligned_m_idx;
44
+ };
45
+
46
+ // Function-pointer table for repacking the quantized RHS (weights) into the
+ // micro-kernel's packed layout; done once at tensor load time (see repack()).
+ struct rhs_packing_info {
47
+ size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
48
+ void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
49
+ const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
50
+ };
51
+
52
+ // One selectable kernel-set: a GEMM variant (m > 1), a GEMV variant (m == 1),
+ // the matching LHS/RHS packing routines, and the CPU-feature mask that must
+ // be present for the set to be usable.
+ struct ggml_kleidiai_kernels {
53
+ kernel_info gemm;
54
+ kernel_info gemv;
55
+ lhs_packing_info lhs_info;
56
+ rhs_packing_info rhs_info;
57
+
58
+ // All bits in this mask must be set in the runtime feature mask.
+ cpu_feature required_cpu;
59
+ };
60
+
61
+ // Returns the preferred kernel-set compatible with cpu_features, or nullptr.
+ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
2
+ // SPDX-License-Identifier: MIT
3
+ //
4
+ #include <arm_neon.h>
5
+ #include <assert.h>
6
+ #include <cfloat>
7
+ #include <stdint.h>
8
+ #include <string.h>
9
+ #if defined(__linux__)
10
+ #include <asm/hwcap.h>
11
+ #include <sys/auxv.h>
12
+ #elif defined(__APPLE__)
13
+ #include <string_view>
14
+ #include <sys/sysctl.h>
15
+ #include <sys/types.h>
16
+ #elif defined(_WIN32)
17
+ #include <windows.h>
18
+ #include <excpt.h>
19
+ #endif
20
+
21
+ #include "kleidiai.h"
22
+
23
+ #include "ggml-cpu.h"
24
+ #include "ggml-impl.h"
25
+ #include "ggml-backend-impl.h"
26
+ #include "ggml-threading.h"
27
+ #include "ggml-cpu-traits.h"
28
+
29
+ #include "kernels.h"
30
+
31
+ #include "kai_common.h"
32
+
33
+ #define GGML_COMMON_DECL_CPP
34
+ #include "ggml-common.h"
35
+
36
+ // Process-wide KleidiAI state: the kernel-set chosen for this CPU, or NULL
+ // when no compatible set exists (callers must check before use).
+ struct ggml_kleidiai_context {
37
+ ggml_kleidiai_kernels * kernels;
38
+ } static ctx = { NULL };
39
+
40
+ // One-time initialization of ctx.kernels. Guarded by ggml's critical section
+ // plus a static flag so concurrent callers perform the detection only once.
+ static void init_kleidiai_context(void) {
41
+
42
+ ggml_critical_section_start();
43
+ static bool initialized = false;
44
+
45
+ if (!initialized) {
46
+ initialized = true;
47
+ const char *env_var = getenv("GGML_KLEIDIAI_SME");
48
+ int sme_enabled = 0;
49
+
50
+ // Build the runtime feature mask from ggml's CPU-capability probes.
+ cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
51
+ (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
52
+ (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
53
+
54
+ if (env_var) {
55
+ sme_enabled = atoi(env_var);
56
+ }
57
+
58
+ // SME kernels are opt-in: only advertised when GGML_KLEIDIAI_SME parses
+ // to a non-zero integer AND the CPU actually reports SME support.
+ if (sme_enabled != 0) {
59
+ features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
60
+ }
61
+ ctx.kernels = ggml_kleidiai_select_kernels(features);
62
+ }
63
+ ggml_critical_section_end();
64
+ }
65
+
66
+ // Bounds-checked accessor for a tensor's extent along dimension `dim`.
+ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
67
+ GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
68
+ return tensor->ne[dim];
69
+ }
70
+
71
+ namespace ggml::cpu::kleidiai {
72
+ // Per-tensor trait object attached via tensor->extra. Stateless: a single
+ // shared instance services every tensor in a KleidiAI buffer (see
+ // get_tensor_traits below). Implements the work-size query, the threaded
+ // MUL_MAT forward pass, and the one-time RHS repack.
+ class tensor_traits : public ggml::cpu::tensor_traits {
73
+ // Reports the scratch (wdata) bytes needed for one MUL_MAT op: the packed
+ // size of the f32 LHS for the kernel variant that compute_forward will use.
+ bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
74
+ GGML_ASSERT(ctx.kernels);
75
+ // GEMV variant when the activation has a single row, otherwise GEMM.
+ kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
76
+
77
+ size_t k = op->src[0]->ne[0];
78
+ size_t m = op->src[1]->ne[1];
79
+
80
+ size_t mr = kernel->get_mr();
81
+ size_t kr = kernel->get_kr();
82
+ size_t sr = kernel->get_sr();
83
+
84
+ size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
85
+
86
+ return true;
87
+ }
88
+
89
+ // Threaded MUL_MAT: phase 1 packs the f32 LHS into params->wdata (split
+ // over threads by rows of m), a barrier synchronizes, then phase 2 runs
+ // the micro-kernel with the output columns (n) split over threads.
+ // Returns false for any op other than GGML_OP_MUL_MAT.
+ bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
90
+ if (dst->op == GGML_OP_MUL_MAT) {
91
+ const ggml_tensor * src0 = dst->src[0];
92
+ const ggml_tensor * src1 = dst->src[1];
93
+
94
+ GGML_TENSOR_BINARY_OP_LOCALS
95
+
96
+ GGML_ASSERT(ctx.kernels);
97
+ kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
98
+ lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
99
+
100
+ GGML_ASSERT(kernel);
101
+
102
+ const int ith = params->ith;
103
+ const int nth = params->nth;
104
+
105
+ // k: reduction dim; m: activation rows; n: weight rows / output cols.
+ const size_t k = ne00;
106
+ const size_t m = ne11;
107
+ const size_t n = ne01;
108
+
109
+ // Split n across threads in multiples of the kernel's n_step.
+ const size_t n_step = kernel->get_n_step();
110
+ const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
111
+ const size_t n_start = ith * num_n_per_thread;
112
+
113
+ size_t n_to_process = num_n_per_thread;
114
+ // NOTE(review): if n_start >= n (more threads than n chunks) this
+ // subtraction underflows size_t; presumably the rounding above keeps
+ // n_start < n for all active threads — verify against thread counts.
+ if ((n_start + n_to_process) > n) {
115
+ n_to_process = n - n_start;
116
+ }
117
+
118
+ const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
119
+ uint8_t * lhs_packed = (uint8_t*)params->wdata;
120
+ const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
121
+
122
+ size_t mr = kernel->get_mr();
123
+ size_t kr = kernel->get_kr();
124
+ size_t sr = kernel->get_sr();
125
+
126
+ // Calculate number of columns to be processed per thread
127
+ // Single-threaded LHS packing when the kernel needs aligned m indices
+ // and m is no larger than one mr tile; otherwise rows are split evenly.
+ const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true;
128
+ const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m;
129
+ const size_t m_start = ith * num_m_per_thread;
130
+ size_t m_to_process = num_m_per_thread;
131
+ if ((m_start + m_to_process) > m) {
132
+ m_to_process = m - m_start;
133
+ }
134
+
135
+ // Threads whose m range starts past the end do no packing work.
+ if(m_start < m) {
136
+ // Transform LHS
137
+ const size_t src_stride = src1->nb[1];
138
+ const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
139
+ const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
140
+ void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
141
+
142
+ lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
143
+ }
144
+
145
+ // All threads must finish packing before any thread reads lhs_packed.
+ ggml_barrier(params->threadpool);
146
+
147
+ // Perform the operation
148
+ const size_t dst_stride = dst->nb[1];
149
+ const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
150
+ const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
151
+ const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
152
+ const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
153
+ const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
154
+ float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
155
+
156
+ // Each thread computes all m rows for its slice of n columns.
+ kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
157
+ dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
158
+ return true;
159
+ }
160
+ return false;
161
+ }
162
+
163
+ public:
164
+ // One-time conversion of a Q4_0 weight tensor's raw data into the
+ // KleidiAI packed-RHS layout, written in place into tensor->data.
+ // Called from the buffer's set_tensor hook; always returns 0.
+ int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
165
+ GGML_ASSERT(ctx.kernels);
166
+ const size_t n = tensor->ne[1];
167
+ const size_t k = tensor->ne[0];
168
+ // Packing geometry is taken from the GEMM variant of the selected set.
+ size_t nr = ctx.kernels->gemm.get_nr();
169
+ size_t kr = ctx.kernels->gemm.get_kr();
170
+ size_t sr = ctx.kernels->gemm.get_sr();
171
+
172
+ // Debug-only check that the packed form fits in the tensor's allocation.
+ #ifndef NDEBUG
173
+ const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
174
+ GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
175
+ #endif
176
+ struct kai_rhs_pack_qs4cxs1s0_param params;
177
+ params.lhs_zero_point = 1;
178
+ params.rhs_zero_point = 8;
179
+ ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params);
180
+
181
+ return 0;
182
+
183
+ // Unreachable; silences the unused-parameter warning in NDEBUG builds.
+ GGML_UNUSED(data_size);
184
+ }
185
+ };
186
+
187
+ // Returns the single shared (stateless) trait instance for every tensor
+ // placed in a KleidiAI buffer; both parameters are intentionally unused.
+ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
188
+ static tensor_traits traits;
189
+ return &traits;
190
+ }
191
+ } // namespace ggml::cpu::kleidiai
192
+
193
+ // Buffer hook: attach the shared KleidiAI trait object to every tensor
+ // allocated in this buffer so the CPU backend dispatches through it.
+ static void ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
194
+ tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
195
+
196
+ GGML_UNUSED(buffer);
197
+ }
198
+
199
+ // Buffer hook: intercepts tensor uploads and repacks the incoming weight
+ // data into the KleidiAI layout instead of copying it verbatim. Only
+ // whole-tensor writes are supported (offset must be 0, size must match).
+ static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
200
+ const void * data, size_t offset, size_t size) {
201
+ GGML_ASSERT(offset == 0);
202
+ GGML_ASSERT(size == ggml_nbytes(tensor));
203
+
204
+ auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
205
+ auto OK = tensor_traits->repack(tensor, data, size);
206
+
207
+ GGML_ASSERT(OK == 0);
208
+ GGML_UNUSED(buffer);
209
+ }
210
+
211
+ // Buffer-type name shown in logs/device listings.
+ static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
212
+ return "CPU_KLEIDIAI";
213
+
214
+ // Unreachable; silences the unused-parameter warning.
+ GGML_UNUSED(buft);
215
+ }
216
+
217
+ // Allocates a plain CPU buffer, then rebrands it as a KleidiAI buffer by
+ // overriding the tensor-init and set-tensor hooks (which perform the
+ // repacking). get/cpy are nulled: packed data has no byte-copy semantics.
+ static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
218
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
219
+
220
+ if (buffer == nullptr) {
221
+ return nullptr;
222
+ }
223
+
224
+ buffer->buft = buft;
225
+ buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
226
+ buffer->iface.set_tensor = ggml_backend_cpu_kleidiai_buffer_set_tensor;
227
+ buffer->iface.get_tensor = nullptr;
228
+ buffer->iface.cpy_tensor = nullptr;
229
+ return buffer;
230
+ }
231
+
232
+ // Required data alignment for tensors in this buffer type (ggml's
+ // standard TENSOR_ALIGNMENT).
+ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
233
+ return TENSOR_ALIGNMENT;
234
+
235
+ // Unreachable; silences the unused-parameter warning.
+ GGML_UNUSED(buft);
236
+ }
237
+
238
+ namespace ggml::cpu::kleidiai {
239
+ // Extra-buffer-type hook the CPU backend queries to decide whether an op
+ // can run through the KleidiAI path and which trait object to dispatch to.
+ class extra_buffer_type : ggml::cpu::extra_buffer_type {
240
+ // Accepts only: 2-D Q4_0 weights living in a KleidiAI buffer, with a
+ // usable kernel-set selected, multiplied by a host-resident f32
+ // activation whose higher dims (2, 3) are 1 (i.e. no batched matmul).
+ bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
241
+ if ( op->op == GGML_OP_MUL_MAT &&
242
+ op->src[0]->type == GGML_TYPE_Q4_0 &&
243
+ op->src[0]->buffer &&
244
+ (ggml_n_dims(op->src[0]) == 2) &&
245
+ op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
246
+ ) {
247
+ // The f32 LHS must be directly readable by the CPU.
+ if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
248
+ return false;
249
+ }
250
+ if (op->src[1]->type == GGML_TYPE_F32 &&
251
+ ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
252
+ return true;
253
+ }
254
+ }
255
+ return false;
256
+ }
257
+
258
+ // Routes a supported MUL_MAT to the trait object stored on the weight
+ // tensor (set by the buffer's init_tensor hook); nullptr otherwise.
+ ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
259
+ if (op->op == GGML_OP_MUL_MAT) {
260
+ if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
261
+ return (ggml::cpu::tensor_traits *) op->src[0]->extra;
262
+ }
263
+ }
264
+ return nullptr;
265
+ }
266
+ };
267
+ } // namespace ggml::cpu::kleidiai
268
+
269
+ // Public entry point: returns the singleton KleidiAI buffer type and lazily
+ // runs CPU-feature detection / kernel selection on every call (the init is
+ // internally once-only).
+ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
270
+ // Note: this local `ctx` (the extra_buffer_type context) shadows the
+ // file-scope kernel context of the same name.
+ static ggml::cpu::kleidiai::extra_buffer_type ctx;
271
+ static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
272
+ /* .iface = */ {
273
+ /* .get_name = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
274
+ /* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
275
+ /* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
276
+ /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
277
+ /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
278
+ /* .is_host = */ nullptr,
279
+ },
280
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
281
+ /* .context = */ &ctx,
282
+ };
283
+
284
+ init_kleidiai_context();
285
+
286
+ return &ggml_backend_cpu_buffer_type_kleidiai;
287
+ }
ggml/src/ggml-cpu/kleidiai/kleidiai.h ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
2
+ // SPDX-License-Identifier: MIT
3
+ //
4
+
5
+ #pragma once
6
+
7
+ #include "ggml-alloc.h"
8
+
9
+ #ifdef __cplusplus
10
+ extern "C" {
11
+ #endif
12
+
13
+ // Buffer type whose tensors are repacked into KleidiAI layouts on upload.
+ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void);
14
+
15
+ #ifdef __cplusplus
16
+ }
17
+ #endif