hipudding committed on
Commit
7c34a03
·
1 Parent(s): 0a74031

ggml : add CANN backend (llama/0)

ggml/include/ggml-cann.h ADDED
@@ -0,0 +1,125 @@
1
+ /*
2
+ * Copyright (c) 2023-2024 The ggml authors
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ * of this software and associated documentation files (the "Software"), to
6
+ * deal in the Software without restriction, including without limitation the
7
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8
+ * sell copies of the Software, and to permit persons to whom the Software is
9
+ * furnished to do so, subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in
12
+ * all copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20
+ * IN THE SOFTWARE.
21
+ */
22
+
23
+ #pragma once
24
+
25
+ #include "ggml-backend.h"
26
+ #include "ggml.h"
27
+
28
+ #ifdef __cplusplus
29
+ extern "C" {
30
+ #endif
31
+
32
+ /**
33
+ * @brief Maximum number of CANN devices supported.
34
+ */
35
+ #define GGML_CANN_MAX_DEVICES 16
36
+
37
+ /**
38
+ * @brief Initializes the CANN backend for a specified device.
39
+ *
40
+ * This function initializes the CANN backend for the given device.
41
+ * It verifies the device index, allocates a context, and creates a backend
42
+ * instance.
43
+ *
44
+ * @param device The index of the device to initialize.
45
+ * @return A pointer to the initialized backend instance, or nullptr on failure.
46
+ */
47
+ GGML_API GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device);
48
+
49
+ /**
50
+ * @brief Checks if a given backend is a CANN backend.
51
+ *
52
+ * This function verifies if the provided backend is a CANN backend by comparing
53
+ * its GUID with the CANN backend's GUID.
54
+ *
55
+ * @param backend The backend instance to check.
56
+ * @return True if the backend is a CANN backend, false otherwise.
57
+ */
58
+ GGML_API GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend);
59
+
60
+ /**
61
+ * @brief Retrieves the CANN buffer type for a specified device.
62
+ *
63
+ * This function initializes and returns the buffer type interface associated
64
+ * with the given device. It ensures thread-safe access using a mutex.
65
+ *
66
+ * @param device The device index for which to retrieve the buffer type.
67
+ * @return A pointer to the buffer type interface for the specified device, or
68
+ * nullptr if the device index is out of range.
69
+ */
70
+ GGML_API GGML_CALL ggml_backend_buffer_type_t
71
+ ggml_backend_cann_buffer_type(int32_t device);
72
+
73
+ /**
74
+ * @brief Retrieves the number of CANN devices available.
75
+ *
76
+ * This function returns the number of CANN devices available based on
77
+ * information obtained from `ggml_cann_info()`.
78
+ *
79
+ * @return The number of CANN devices available.
80
+ */
81
+ GGML_API GGML_CALL int32_t ggml_backend_cann_get_device_count(void);
82
+
83
+ /**
84
+ * @brief Retrieves the description of a specific CANN device.
85
+ *
86
+ * This function sets the specified device, retrieves the SoC name,
87
+ * and writes it into the provided description buffer.
88
+ *
89
+ * @param device The device index to retrieve the description for.
90
+ * @param description Pointer to a buffer where the description will be written.
91
+ * @param description_size Size of the description buffer.
92
+ */
93
+ GGML_API GGML_CALL void ggml_backend_cann_get_device_description(
94
+ int32_t device, char* description, size_t description_size);
95
+
96
+ /**
97
+ * @brief Retrieves the memory information of a specific CANN device.
98
+ *
99
+ * This function sets the specified device, retrieves the free and total
100
+ * memory information of the specified type (ACL_HBM_MEM), and stores them
101
+ * in the provided pointers.
102
+ *
103
+ * @param device The device index to retrieve memory information for.
104
+ * @param free Pointer to a variable where the free memory size will be stored.
105
+ * @param total Pointer to a variable where the total memory size will be
106
+ * stored.
107
+ */
108
+ GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device,
109
+ size_t* free,
110
+ size_t* total);
111
+
112
+ /**
113
+ * @brief Set the logging callback for GGML.
114
+ *
115
+ * This function sets the logging callback and user data for logging.
116
+ *
117
+ * @param log_callback The logging callback to set.
118
+ * @param user_data User data to pass to the logging callback.
119
+ */
120
+ GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback,
121
+ void* user_data);
122
+
123
+ #ifdef __cplusplus
124
+ }
125
+ #endif
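
For context, a minimal sketch of how the public API declared above might be used from application code to enumerate devices and create a backend. It assumes only the functions declared in ggml-cann.h plus ggml_backend_free() and GGML_ASSERT from the ggml headers it includes; error handling is kept to a minimum and the snippet is illustrative, not part of the commit.

    #include "ggml-cann.h"
    #include <cstdio>

    int main() {
        // enumerate the available CANN devices
        int32_t n_devices = ggml_backend_cann_get_device_count();
        for (int32_t dev = 0; dev < n_devices; dev++) {
            char desc[128];
            size_t free_mem = 0, total_mem = 0;
            ggml_backend_cann_get_device_description(dev, desc, sizeof(desc));
            ggml_backend_cann_get_device_memory(dev, &free_mem, &total_mem);
            printf("CANN device %d: %s (%zu / %zu bytes free)\n",
                   (int) dev, desc, free_mem, total_mem);
        }

        // create a backend on device 0 and verify its type
        ggml_backend_t backend = ggml_backend_cann_init(0);
        if (backend == nullptr) {
            fprintf(stderr, "failed to initialize the CANN backend\n");
            return 1;
        }
        GGML_ASSERT(ggml_backend_is_cann(backend));

        // ... build and compute graphs with this backend ...

        ggml_backend_free(backend);  // from ggml-backend.h
        return 0;
    }
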
ggml/src/ggml-cann.cpp ADDED
@@ -0,0 +1,2020 @@
1
+ /*
2
+ * Copyright (c) 2023-2024 The ggml authors
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ * of this software and associated documentation files (the "Software"), to
6
+ * deal in the Software without restriction, including without limitation the
7
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8
+ * sell copies of the Software, and to permit persons to whom the Software is
9
+ * furnished to do so, subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in
12
+ * all copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20
+ * IN THE SOFTWARE.
21
+ */
22
+
23
+ #include "ggml-cann.h"
24
+
25
+ #include <acl/acl.h>
26
+ #include <stdarg.h>
27
+
28
+ #include <cmath>
29
+ #include <cstdio>
30
+ #include <cstring>
31
+ #include <mutex>
32
+
33
+ #include "ggml-backend-impl.h"
34
+ #include "ggml-cann/aclnn_ops.h"
35
+ #include "ggml-cann/common.h"
36
+
37
+ #define GGML_COMMON_DECL_C
38
+
39
+ #include "ggml-common.h"
40
+
41
+ /**
42
+ * @brief Default logging callback for GGML.
43
+ *
44
+ * This function is the default logging callback that logs messages to stderr.
45
+ *
46
+ * @param level The log level.
47
+ * @param msg The log message.
48
+ * @param user_data User data passed to the callback.
49
+ */
50
+ static void ggml_cann_default_log_callback(enum ggml_log_level level,
51
+ const char* msg, void* user_data) {
52
+ GGML_UNUSED(level);
53
+ GGML_UNUSED(user_data);
54
+ fprintf(stderr, "%s", msg);
55
+ }
56
+
57
+ ggml_log_callback ggml_cann_log_callback = ggml_cann_default_log_callback;
58
+ void* ggml_cann_log_user_data = NULL;
59
+
60
+ GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback,
61
+ void* user_data) {
62
+ ggml_cann_log_callback = log_callback;
63
+ ggml_cann_log_user_data = user_data;
64
+ }
65
+
66
+ #define GGML_CANN_LOG_INFO(...) ggml_cann_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
67
+ #define GGML_CANN_LOG_WARN(...) ggml_cann_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
68
+ #define GGML_CANN_LOG_ERROR(...) \
69
+ ggml_cann_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
70
+
71
+ GGML_ATTRIBUTE_FORMAT(2, 3)
72
+
73
+ /**
74
+ * @brief Log a message using the current logging callback.
75
+ *
76
+ * This function formats a log message and passes it to the current logging
77
+ * callback.
78
+ *
79
+ * @param level The log level.
80
+ * @param format The format string for the log message.
81
+ * @param ... The arguments for the format string.
82
+ */
83
+ static void ggml_cann_log(enum ggml_log_level level, const char* format, ...) {
84
+ if (ggml_cann_log_callback != NULL) {
85
+ va_list args;
86
+ va_start(args, format);
87
+ char buffer[128];
88
+ int len = vsnprintf(buffer, 128, format, args);
89
+ if (len < 128) {
90
+ ggml_cann_log_callback(level, buffer, ggml_cann_log_user_data);
91
+ } else {
92
+ // vsnprintf adds a null terminator
93
+ std::vector<char> buffer2(len + 1);
94
+ va_end(args);
95
+ va_start(args, format);
96
+ vsnprintf(&buffer2[0], buffer2.size(), format, args);
97
+ ggml_cann_log_callback(level, buffer2.data(),
98
+ ggml_cann_log_user_data);
99
+ }
100
+ va_end(args);
101
+ }
102
+ }
103
+
104
+ /**
105
+ * @brief Handles CANN errors by printing an error message and aborting.
106
+ *
107
+ * @param stmt The statement that caused the error.
108
+ * @param func The function in which the error occurred.
109
+ * @param file The file in which the error occurred.
110
+ * @param line The line number where the error occurred.
111
+ * @param msg The error message.
112
+ */
113
+ [[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
114
+ const char* file, int line, const char* msg) {
115
+ int32_t id = -1;
116
+ aclrtGetDevice(&id);
117
+
118
+ GGML_CANN_LOG_ERROR("CANN error: %s\n", msg);
119
+ GGML_CANN_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func,
120
+ file, line);
121
+ GGML_CANN_LOG_ERROR(" %s\n", stmt);
122
+ // abort with GGML_ABORT to get a stack trace
123
+ GGML_ABORT("CANN error");
124
+ }
125
+
126
+ /**
127
+ * @brief Sets the device to be used by CANN.
128
+ *
129
+ * @param device The device ID to set.
130
+ */
131
+ void ggml_cann_set_device(const int32_t device) {
132
+ // TODO: uncomment these lines after the empty-context issue is fixed.
133
+ // int current_device;
134
+ // ACL_CHECK(aclrtGetDevice(&current_device));
135
+
136
+ // if (device == current_device) {
137
+ // return;
138
+ // }
139
+ ACL_CHECK(aclrtSetDevice(device));
140
+ }
141
+
142
+ /**
143
+ * @brief Retrieves the current device ID.
144
+ *
145
+ * @return The current device ID.
146
+ */
147
+ int32_t ggml_cann_get_device() {
148
+ int32_t id;
149
+ ACL_CHECK(aclrtGetDevice(&id));
150
+ return id;
151
+ }
152
+
153
+ /**
154
+ * @brief Initialize the CANN device information.
155
+ *
156
+ * This function initializes the CANN device information by obtaining the
157
+ * device count and setting the memory allocation granularity for each device.
158
+ *
159
+ * @return A structure containing the device information.
160
+ */
161
+ static ggml_cann_device_info ggml_cann_init() {
162
+ ggml_cann_device_info info = {};
163
+
164
+ aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
165
+
166
+ if (err != ACL_SUCCESS) {
167
+ GGML_CANN_LOG_ERROR("%s: failed to initialize CANN: %s\n",
168
+ __func__, aclGetRecentErrMsg());
169
+ return info;
170
+ }
171
+
172
+ GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
173
+
174
+ for (int id = 0; id < info.device_count; ++id) {
175
+ aclrtPhysicalMemProp prop = {};
176
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
177
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
178
+ prop.memAttr = ACL_HBM_MEM_HUGE;
179
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
180
+ prop.location.id = id;
181
+ prop.reserve = 0;
182
+ ACL_CHECK(aclrtMemGetAllocationGranularity(
183
+ &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
184
+ &info.devices[id].vmm_granularity));
185
+ }
186
+
187
+ // TODO: add more device info later.
188
+ return info;
189
+ }
190
+
191
+ /**
192
+ * @brief Retrieve the CANN device information.
193
+ *
194
+ * This function returns a reference to a structure containing the CANN device
195
+ * information. The device information is initialized once and reused on
196
+ * subsequent calls.
197
+ *
198
+ * @return A reference to the structure containing the device information.
199
+ */
200
+ const ggml_cann_device_info& ggml_cann_info() {
201
+ static ggml_cann_device_info info = ggml_cann_init();
202
+ return info;
203
+ }
204
+
205
+ //#define DEBUG_CANN_MALLOC
206
+ /**
207
+ * @brief A pool of CANN buffers (legacy).
208
+ *
209
+ * This class manages a pool of CANN buffers for a specific device.
210
+ */
211
+ struct ggml_cann_pool_leg : public ggml_cann_pool {
212
+ /**
213
+ * @brief The maximum number of buffers in the pool.
214
+ */
215
+ static const int MAX_BUFFERS = 256;
216
+
217
+ /**
218
+ * @brief The device ID associated with this buffer pool.
219
+ */
220
+ int device;
221
+
222
+ /**
223
+ * @brief Structure representing a CANN buffer.
224
+ */
225
+ struct ggml_cann_buffer {
226
+ void* ptr = nullptr; ///< Pointer to the buffer memory.
227
+ size_t size = 0; ///< Size of the buffer.
228
+ };
229
+
230
+ /**
231
+ * @brief Array of CANN buffers in the pool.
232
+ */
233
+ ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
234
+
235
+ /**
236
+ * @brief Total size of all buffers in the pool.
237
+ */
238
+ size_t pool_size = 0;
239
+
240
+ /**
241
+ * @brief Constructor to initialize the buffer pool for a specific device.
242
+ *
243
+ * @param device The device ID to associate with this buffer pool.
244
+ */
245
+ explicit ggml_cann_pool_leg(int device) : device(device) {}
246
+
247
+ /**
248
+ * @brief Destructor to free all buffers in the pool.
249
+ */
250
+ ~ggml_cann_pool_leg() {
251
+ ggml_cann_set_device(device);
252
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
253
+ ggml_cann_buffer& b = buffer_pool[i];
254
+ if (b.ptr != nullptr) {
255
+ ACL_CHECK(aclrtFree(b.ptr));
256
+ pool_size -= b.size;
257
+ }
258
+ }
259
+ GGML_ASSERT(pool_size == 0);
260
+ }
261
+
262
+ /**
263
+ * @brief Allocate a buffer of the given size.
264
+ *
265
+ * @param size The size of the buffer to allocate.
266
+ * @param actual_size A pointer to a variable to receive the actual size of
267
+ * the allocated buffer.
268
+ * @return A pointer to the allocated buffer.
269
+ */
270
+ void* alloc(size_t size, size_t* actual_size) override {
271
+ #ifdef DEBUG_CANN_MALLOC
272
+ int nnz = 0;
273
+ size_t max_size = 0;
274
+ #endif
275
+ size_t best_diff = 1ull << 36;
276
+ int ibest = -1;
277
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
278
+ ggml_cann_buffer& b = buffer_pool[i];
279
+ if (b.ptr != nullptr) {
280
+ #ifdef DEBUG_CANN_MALLOC
281
+ ++nnz;
282
+ if (b.size > max_size) max_size = b.size;
283
+ #endif
284
+ if (b.size >= size) {
285
+ size_t diff = b.size - size;
286
+ if (diff < best_diff) {
287
+ best_diff = diff;
288
+ ibest = i;
289
+ if (!best_diff) {
290
+ void* ptr = b.ptr;
291
+ *actual_size = b.size;
292
+ b.ptr = nullptr;
293
+ b.size = 0;
294
+ return ptr;
295
+ }
296
+ }
297
+ }
298
+ }
299
+ }
300
+ if (ibest >= 0) {
301
+ ggml_cann_buffer& b = buffer_pool[ibest];
302
+ void* ptr = b.ptr;
303
+ *actual_size = b.size;
304
+ b.ptr = nullptr;
305
+ b.size = 0;
306
+ return ptr;
307
+ }
308
+ void* ptr;
309
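+ // no suitable cached buffer: over-allocate by ~5% and round up to a multiple of 256 bytes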
+ size_t look_ahead_size = (size_t)(1.05 * size);
310
+ look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
311
+ ggml_cann_set_device(device);
312
+ ACL_CHECK(
313
+ aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
314
+ *actual_size = look_ahead_size;
315
+ pool_size += look_ahead_size;
316
+ #ifdef DEBUG_CANN_MALLOC
317
+ GGML_CANN_LOG_INFO(
318
+ "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
319
+ "requested %u MB\n",
320
+ __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
321
+ (uint32_t)(pool_size / 1024 / 1024),
322
+ (uint32_t)(size / 1024 / 1024));
323
+ #endif
324
+ return ptr;
325
+ }
326
+
327
+ /**
328
+ * @brief Free a buffer and return it to the pool.
329
+ *
330
+ * @param ptr Pointer to the buffer to free.
331
+ * @param size Size of the buffer to free.
332
+ */
333
+ void free(void* ptr, size_t size) override {
334
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
335
+ ggml_cann_buffer& b = buffer_pool[i];
336
+ if (b.ptr == nullptr) {
337
+ b.ptr = ptr;
338
+ b.size = size;
339
+ return;
340
+ }
341
+ }
342
+ // Memory should always be buffered; it may still be needed by
343
+ // tasks in the stream.
344
+ // TODO: fix me.
345
+ GGML_ABORT("CANN buffer pool full, increase MAX_BUFFERS\n");
346
+ }
347
+ };
348
+
349
+ /**
350
+ * @brief A pool of CANN buffers with virtual memory.
351
+ *
352
+ * This class manages a pool of CANN buffers with virtual memory for a specific
353
+ * device.
354
+ */
355
+ struct ggml_cann_pool_vmm : public ggml_cann_pool {
356
+ /**
357
+ * @brief The maximum size of the virtual memory pool (32 GB).
358
+ */
359
+ static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
360
+
361
+ /**
362
+ * @brief The device ID associated with this buffer pool.
363
+ */
364
+ int device;
365
+
366
+ /**
367
+ * @brief Pointer to the start of the virtual memory pool.
368
+ */
369
+ void* pool_addr = 0;
370
+
371
+ /**
372
+ * @brief Amount of virtual memory used in the pool.
373
+ */
374
+ size_t pool_used = 0;
375
+
376
+ /**
377
+ * @brief Total size of the virtual memory pool.
378
+ */
379
+ size_t pool_size = 0;
380
+
381
+ /**
382
+ * @brief Allocation granularity for the virtual memory pool.
383
+ */
384
+ size_t granularity;
385
+
386
+ /**
387
+ * @brief Handles for the physical memory allocated.
388
+ */
389
+ std::vector<aclrtDrvMemHandle> handles;
390
+
391
+ /**
392
+ * @brief Offsets for the mapped memory regions.
393
+ */
394
+ std::vector<void*> map_offsets;
395
+
396
+ /**
397
+ * @brief Constructor to initialize the buffer pool with virtual memory for
398
+ * a specific device.
399
+ *
400
+ * @param device The device ID to associate with this buffer pool.
401
+ */
402
+ explicit ggml_cann_pool_vmm(int device)
403
+ : device(device),
404
+ granularity(ggml_cann_info().devices[device].vmm_granularity) {}
405
+
406
+ /**
407
+ * @brief Destructor to free all buffers in the virtual memory pool.
408
+ */
409
+ ~ggml_cann_pool_vmm() {
410
+ if (pool_addr != 0) {
411
+ for (auto& offset : map_offsets) {
412
+ ACL_CHECK(aclrtUnmapMem(offset));
413
+ }
414
+ for (auto& handle : handles) {
415
+ ACL_CHECK(aclrtFreePhysical(handle));
416
+ }
417
+ ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
418
+ }
419
+ }
420
+
421
+ /**
422
+ * @brief Allocate a buffer of the given size in the virtual memory pool.
423
+ *
424
+ * @param size The size of the buffer to allocate.
425
+ * @param actual_size A pointer to a variable to receive the actual size of
426
+ * the allocated buffer.
427
+ * @return A pointer to the allocated buffer.
428
+ */
429
+ void* alloc(size_t size, size_t* actual_size) override {
430
+ // round up the allocation size to the alignment to ensure that all
431
+ // allocations are aligned for all data types
432
+ const size_t alignment = 128;
433
+ size = alignment * ((size + alignment - 1) / alignment);
434
+
435
+ size_t avail = pool_size - pool_used;
436
+
437
+ if (size > avail) {
438
+ // round up to the next multiple of the granularity
439
+ size_t reserve_size = size - avail;
440
+ reserve_size =
441
+ granularity * ((reserve_size + granularity - 1) / granularity);
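+ // for example, with a 2 MB granularity a 3 MB shortfall is rounded up to a 4 MB reservation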
442
+
443
+ GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);
444
+
445
+ // allocate more physical memory
446
+ aclrtPhysicalMemProp prop = {};
447
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
448
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
449
+ prop.memAttr = ACL_HBM_MEM_HUGE;
450
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
451
+ prop.location.id = device;
452
+ prop.reserve = 0;
453
+ aclrtDrvMemHandle handle;
454
+ ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
455
+
456
+ // reserve virtual address space (if not already reserved)
457
+ if (pool_addr == 0) {
458
+ ACL_CHECK(aclrtReserveMemAddress(
459
+ &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
460
+ }
461
+
462
+ // map at the end of the pool
463
+ ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
464
+ handle, 0));
465
+
466
+ handles.push_back(handle);
467
+ map_offsets.push_back((char*)pool_addr + pool_size);
468
+
469
+ // add to the pool
470
+ pool_size += reserve_size;
471
+
472
+ // GGML_CANN_LOG_INFO("cann pool[%d]: size increased to %llu MB (
473
+ // reserved %llu MB)\n",
474
+ // device, (unsigned long long) (pool_size/1024/1024),
475
+ // (unsigned long long) (reserve_size/1024/1024));
476
+ }
477
+
478
+ GGML_ASSERT(pool_addr != 0);
479
+
480
+ void* ptr = (void*)((char*)pool_addr + pool_used);
481
+ *actual_size = size;
482
+ pool_used += size;
483
+
484
+ #ifdef DEBUG_CANN_MALLOC
485
+ GGML_CANN_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
486
+ (unsigned long long)size, (unsigned long long)ptr);
487
+ #endif
488
+ return ptr;
489
+ }
490
+
491
+ /**
492
+ * @brief Free a buffer and return it to the virtual memory pool.
493
+ *
494
+ * @param ptr Pointer to the buffer to free.
495
+ * @param size Size of the buffer to free.
496
+ */
497
+ void free(void* ptr, size_t size) override {
498
+ #ifdef DEBUG_CANN_MALLOC
499
+ GGML_CANN_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
500
+ (unsigned long long)size, (unsigned long long)ptr);
501
+ #endif
502
+
503
+ pool_used -= size;
504
+
505
+ // all deallocations must be in reverse order of the allocations
506
+ GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
507
+ }
508
+ };
509
+
510
+ /**
511
+ * @brief Create a new CANN pool for a specific device.
512
+ *
513
+ * Factory method to create a new CANN pool object based on the device type.
514
+ *
515
+ * @param device The device ID for which to create the pool.
516
+ * @return A unique pointer to the created CANN pool.
517
+ */
518
+ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
519
+ int device) {
520
+ // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
521
+ return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
522
+ }
523
+
524
+ // cann buffer
525
+ /**
526
+ * @brief Context for managing a CANN buffer associated with a specific device.
527
+ *
528
+ * This structure holds information about a CANN buffer, including the device
529
+ * ID and the pointer to the device memory allocated for the buffer.
530
+ */
531
+ struct ggml_backend_cann_buffer_context {
532
+ int32_t device; ///< The device ID associated with this buffer context.
533
+ void* dev_ptr =
534
+ nullptr; ///< Pointer to the device memory allocated for the buffer.
535
+
536
+ /**
537
+ * @brief Constructor to initialize the CANN buffer context.
538
+ *
539
+ * @param device The device ID associated with this buffer context.
540
+ * @param dev_ptr Pointer to the device memory allocated for the buffer.
541
+ */
542
+ ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
543
+ : device(device),
544
+ dev_ptr(dev_ptr) {}
545
+
546
+ /**
547
+ * @brief Destructor to free the device memory allocated for the buffer.
548
+ */
549
+ ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
550
+ };
551
+
552
+ /**
553
+ * @brief Retrieve the name associated with a CANN buffer.
554
+ *
555
+ * This function returns the name of a CANN buffer, which is stored in the
556
+ * context of the buffer.
557
+ *
558
+ * @param buffer The CANN buffer whose name is to be retrieved.
559
+ * @return A pointer to a C-string containing the name of the buffer.
560
+ */
561
+
562
+ GGML_CALL static const char* ggml_backend_cann_buffer_get_name(
563
+ ggml_backend_buffer_t buffer) {
564
+ return "CANN";
565
+
566
+ GGML_UNUSED(buffer);
567
+ }
568
+
569
+ /**
570
+ * @brief Check if a buffer is a CANN buffer.
571
+ *
572
+ * This function checks if a given buffer is a CANN buffer by comparing its
573
+ * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
574
+ *
575
+ * @param buffer The buffer to check.
576
+ * @return true if the buffer is a CANN buffer, false otherwise.
577
+ */
578
+ GGML_CALL static bool ggml_backend_buffer_is_cann(
579
+ ggml_backend_buffer_t buffer) {
580
+ return buffer->iface.get_name == ggml_backend_cann_buffer_get_name;
581
+ }
582
+
583
+ /**
584
+ * @brief Free resources associated with a CANN buffer.
585
+ *
586
+ * This function frees the resources associated with a CANN buffer, including
587
+ * its context.
588
+ *
589
+ * @param buffer The CANN buffer to free.
590
+ */
591
+ GGML_CALL static void ggml_backend_cann_buffer_free_buffer(
592
+ ggml_backend_buffer_t buffer) {
593
+ ggml_backend_cann_buffer_context* ctx =
594
+ (ggml_backend_cann_buffer_context*)buffer->context;
595
+ delete ctx;
596
+ }
597
+
598
+ /**
599
+ * @brief Retrieve the base pointer of a CANN buffer.
600
+ *
601
+ * This function returns the base pointer of a CANN buffer, which points to the
602
+ * device memory allocated for the buffer.
603
+ *
604
+ * @param buffer The CANN buffer whose base pointer is to be retrieved.
605
+ * @return A pointer to the base of the device memory allocated for the buffer.
606
+ */
607
+ GGML_CALL static void* ggml_backend_cann_buffer_get_base(
608
+ ggml_backend_buffer_t buffer) {
609
+ ggml_backend_cann_buffer_context* ctx =
610
+ (ggml_backend_cann_buffer_context*)buffer->context;
611
+ return ctx->dev_ptr;
612
+ }
613
+
614
+ /**
615
+ * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
616
+ * processing.
617
+ *
618
+ * This function transforms quantized Q4.0 tensor data into a format suitable
619
+ * for CANN processing. It extracts quantization values and scales from the
620
+ * source data and prepares them in a format expected by CANN operations.
621
+ *
622
+ * @param tensor Pointer to the tensor information.
623
+ * @param src Pointer to the source data in Q4.0 format.
624
+ * @param dst Pointer to the destination buffer where transformed data will be
625
+ * stored.
626
+ */
627
+ GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
628
+ const void* src,
629
+ void* dst) {
630
+
631
+ int64_t n_elems = ggml_nelements(tensor);
632
+ int64_t groups = n_elems / QK4_0;
633
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
634
+
635
+ uint8_t* quant_offset = (uint8_t*)dst;
636
+ uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
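+ // layout of dst: all packed 4-bit quants first (quant_bytes), followed by one 16-bit scale per group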
637
+
638
+ for (int i = 0; i < groups; i++) {
639
+ const block_q4_0* group =
640
+ (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
641
+ *scale_offset = group->d;
642
+ scale_offset++;
643
+
644
+ // 0-15
645
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
646
+ (*quant_offset) = (group->qs[j] & 0x0F);
647
+ (*quant_offset) |= ((group->qs[j + 1] << 4));
648
+ quant_offset++;
649
+ }
650
+
651
+ // 16-31
652
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
653
+ (*quant_offset) = (group->qs[j] >> 4);
654
+ (*quant_offset) |= (group->qs[j + 1] & 0xF0);
655
+ quant_offset++;
656
+ }
657
+ }
658
+
659
+ // convert each biased uint4 value to int4 by subtracting 8 (XOR 0x88 flips the high bit of every nibble)
660
+ for (quant_offset = (uint8_t*)dst;
661
+ quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
662
+ (*quant_offset) ^= 0x88;
663
+ }
664
+ }
665
+
666
+ /**
667
+ * @brief Transform CANN processed data back into quantized Q4.0 format.
668
+ *
669
+ * This function transforms CANN processed data back into quantized Q4.0 format.
670
+ * It reverses the transformation performed by
671
+ * ggml_backend_cann_transform_q4_0(), converting the data back into its
672
+ * original quantized form.
673
+ *
674
+ * @param tensor Pointer to the tensor information.
675
+ * @param src Pointer to the source buffer containing transformed data.
676
+ * @param dst Pointer to the destination buffer where the Q4.0 formatted data
677
+ * will be stored.
678
+ */
679
+ GGML_CALL static void ggml_backend_cann_transform_back_q4_0(
680
+ const ggml_tensor* tensor, void* src, void* dst) {
681
+
682
+ int64_t n_elems = ggml_nelements(tensor);
683
+ int64_t groups = n_elems / QK4_0;
684
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
685
+
686
+ uint8_t* quant_offset = (uint8_t*)src;
687
+ uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
688
+
689
+ for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
690
+ (*quant_offset) ^= 0x88;
691
+ }
692
+ quant_offset = (uint8_t*)src;
693
+
694
+ for (int i = 0; i < groups; i++) {
695
+ block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
696
+ group->d = *scale_offset;
697
+ scale_offset++;
698
+
699
+ // 0-15
700
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
701
+ group->qs[j] = ((*quant_offset) & 0x0F);
702
+ group->qs[j + 1] = ((*quant_offset) >> 4);
703
+ quant_offset++;
704
+ }
705
+
706
+ // 16-31
707
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
708
+ group->qs[j] |= ((*quant_offset) << 4);
709
+ group->qs[j + 1] |= ((*quant_offset) & 0xF0);
710
+ quant_offset++;
711
+ }
712
+ }
713
+ }
714
+
715
+ /**
716
+ * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
717
+ * processing.
718
+ *
719
+ * This function transforms quantized Q8.0 tensor data into a format suitable
720
+ * for CANN processing. It extracts quantization values and scales from the
721
+ * source data and prepares them in a format expected by CANN operations.
722
+ *
723
+ * @param tensor Pointer to the tensor information.
724
+ * @param src Pointer to the source data in Q8.0 format.
725
+ * @param dst Pointer to the destination buffer where transformed data will be
726
+ * stored.
727
+ */
728
+ GGML_CALL static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
729
+ const void* src,
730
+ void* dst) {
731
+ int64_t n_elems = ggml_nelements(tensor);
732
+ int64_t groups = n_elems / QK8_0;
733
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
734
+
735
+ uint8_t* quant_offset = (uint8_t*)dst;
736
+ uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
737
+
738
+ for (int i = 0; i < groups; i++) {
739
+ const block_q8_0* group =
740
+ (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
741
+ *scale_offset = group->d;
742
+ scale_offset++;
743
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
744
+ memcpy(quant_offset, group->qs, group_quant_size);
745
+ quant_offset += group_quant_size;
746
+ }
747
+ }
748
+
749
+ /**
750
+ * @brief Transform CANN processed data back into quantized Q8.0 format.
751
+ *
752
+ * This function transforms CANN processed data back into quantized Q8.0 format.
753
+ * It reverses the transformation performed by
754
+ * ggml_backend_cann_transform_q8_0(), converting the data back into its
755
+ * original quantized form.
756
+ *
757
+ * @param tensor Pointer to the tensor information.
758
+ * @param src Pointer to the source buffer containing transformed data.
759
+ * @param dst Pointer to the destination buffer where the Q8.0 formatted data
760
+ * will be stored.
761
+ */
762
+ GGML_CALL static void ggml_backend_cann_transform_back_q8_0(
763
+ const ggml_tensor* tensor, const void* src, void* dst) {
764
+ int64_t n_elems = ggml_nelements(tensor);
765
+ int64_t groups = n_elems / QK8_0;
766
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
767
+
768
+ const uint8_t* quant_offset = (const uint8_t*)src;
769
+ const uint16_t* scale_offset =
770
+ (const uint16_t*)((const char*)src + quant_bytes);
771
+
772
+ for (int i = 0; i < groups; i++) {
773
+ block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
774
+ group->d = *scale_offset;
775
+ scale_offset++;
776
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
777
+ memcpy(group->qs, quant_offset, group_quant_size);
778
+ quant_offset += group_quant_size;
779
+ }
780
+ }
781
+
782
+ /**
783
+ * @brief Transform tensor data based on its type for CANN processing.
784
+ *
785
+ * This function transforms tensor data based on its quantization type for CANN
786
+ * processing. It dispatches the transformation based on the tensor's type to
787
+ * specialized functions handling Q4.0 and Q8.0 formats.
788
+ *
789
+ * @param tensor Pointer to the tensor information.
790
+ * @param src Pointer to the source data to be transformed.
791
+ * @param dst Pointer to the destination buffer where transformed data will be
792
+ * stored.
793
+ */
794
+ GGML_CALL static void ggml_backend_cann_transform(ggml_tensor* tensor,
795
+ const void* src, void* dst) {
796
+ switch (tensor->type) {
797
+ case GGML_TYPE_Q4_0:
798
+ ggml_backend_cann_transform_q4_0(tensor, src, dst);
799
+ break;
800
+ case GGML_TYPE_Q8_0:
801
+ ggml_backend_cann_transform_q8_0(tensor, src, dst);
802
+ break;
803
+ default:
804
+ break;
805
+ }
806
+ }
807
+
808
+ /**
809
+ * @brief Transform CANN processed data back into tensor data based on its type.
810
+ *
811
+ * This function transforms CANN processed data back into tensor data based on
812
+ * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
813
+ * transformation based on the tensor's type to specialized functions.
814
+ *
815
+ * @param tensor Pointer to the tensor information.
816
+ * @param src Pointer to the source data containing CANN processed data.
817
+ * @param dst Pointer to the destination buffer where transformed tensor data
818
+ * will be stored.
819
+ */
820
+ GGML_CALL static void ggml_backend_cann_transform_back(
821
+ const ggml_tensor* tensor, void* src, void* dst) {
822
+ switch (tensor->type) {
823
+ case GGML_TYPE_Q4_0:
824
+ ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
825
+ break;
826
+ case GGML_TYPE_Q8_0:
827
+ ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
828
+ break;
829
+ default:
830
+ break;
831
+ }
832
+ }
833
+
834
+ /**
835
+ * @brief Check if transformation is needed for a given tensor type.
836
+ *
837
+ * This function checks if transformation is needed for a given tensor type
838
+ * to prepare data for CANN processing.
839
+ *
840
+ * @param type The tensor type to check.
841
+ * @return true if transformation is needed, false otherwise.
842
+ */
843
+ GGML_CALL static bool need_transform(ggml_type type) {
844
+ switch (type) {
845
+ case GGML_TYPE_Q4_0:
846
+ case GGML_TYPE_Q8_0:
847
+ return true;
848
+ default:
849
+ return false;
850
+ }
851
+ }
852
+
853
+ /**
854
+ * @brief Initialize a tensor using data from a CANN buffer.
855
+ *
856
+ * This function initializes a tensor using data from a CANN buffer.
857
+ * It handles special cases such as views and quantization.
858
+ *
859
+ * @param buffer The CANN buffer from which to initialize the tensor.
860
+ * @param tensor Pointer to the tensor to be initialized.
861
+ */
862
+ GGML_CALL static void ggml_backend_cann_buffer_init_tensor(
863
+ ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
864
+ if (tensor->view_src != NULL && tensor->view_offs == 0) {
865
+ GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
866
+ return;
867
+ }
868
+
869
+ // TODO: the CANN backend doesn't support quantized types yet. Just leave
870
+ // the code here.
871
+ if (ggml_is_quantized(tensor->type)) {
872
+ // Initialize padding to 0 to avoid possible NaN values
873
+ size_t original_size = ggml_nbytes(tensor);
874
+ size_t padded_size =
875
+ ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
876
+
877
+ if (padded_size > original_size && tensor->view_src == nullptr) {
878
+ size_t memset_size = padded_size - original_size;
879
+ ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
880
+ memset_size, 0, memset_size));
881
+ }
882
+ }
883
+ }
884
+
885
+ // TODO: need to handle tensors that have padding.
886
+ /**
887
+ * @brief Set tensor data in a CANN buffer.
888
+ *
889
+ * This function sets tensor data in a CANN buffer, handling transformations
890
+ * if needed based on the tensor's type.
891
+ *
892
+ * @param buffer The CANN buffer where the tensor data will be set.
893
+ * @param tensor Pointer to the tensor whose data will be set.
894
+ * @param data Pointer to the source data to be copied into the tensor.
895
+ * @param offset Offset in the source data from where to start copying.
896
+ * @param size Size of the data to be copied, in bytes.
897
+ */
898
+ GGML_CALL static void ggml_backend_cann_buffer_set_tensor(
899
+ ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
900
+ size_t offset, size_t size) {
901
+ ggml_backend_cann_buffer_context *ctx =
902
+ (ggml_backend_cann_buffer_context *)buffer->context;
903
+
904
+ ggml_cann_set_device(ctx->device);
905
+ // TODO: refer to cann(#6017); it uses the thread's default stream.
906
+ // For ACL, synchronous functions use this default stream.
907
+ // Why aclrtSynchronizeDevice?
908
+
909
+ if (!need_transform(tensor->type)) {
910
+ ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
911
+ ACL_MEMCPY_HOST_TO_DEVICE));
912
+ } else {
913
+ void *transform_buffer = malloc(size);
914
+ ggml_backend_cann_transform(tensor, data, transform_buffer);
915
+
916
+ #ifndef NDEBUG
917
+ void *check_buffer = malloc(size);
918
+ ggml_backend_cann_transform_back(tensor, transform_buffer,
919
+ check_buffer);
920
+ GGML_ASSERT(memcmp(data, check_buffer, size) == 0);
921
+ free(check_buffer);
922
+ #endif
923
+ ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
924
+ transform_buffer, size,
925
+ ACL_MEMCPY_HOST_TO_DEVICE));
926
+ free(transform_buffer);
927
+ }
928
+ }
929
+
930
+ /**
931
+ * @brief Get tensor data from a CANN buffer.
932
+ *
933
+ * This function retrieves tensor data from a CANN buffer, handling
934
+ * transformations if needed based on the tensor's type.
935
+ *
936
+ * @param buffer The CANN buffer from which to retrieve tensor data.
937
+ * @param tensor Pointer to the tensor whose data will be retrieved.
938
+ * @param data Pointer to the destination buffer where the tensor data will be
939
+ * copied.
940
+ * @param offset Offset in the destination buffer where to start copying.
941
+ * @param size Size of the data to be copied, in bytes.
942
+ */
943
+ GGML_CALL static void ggml_backend_cann_buffer_get_tensor(
944
+ ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
945
+ size_t offset, size_t size) {
946
+ ggml_backend_cann_buffer_context* ctx =
947
+ (ggml_backend_cann_buffer_context*)buffer->context;
948
+
949
+ ggml_cann_set_device(ctx->device);
950
+
951
+ if (!need_transform(tensor->type)) {
952
+ ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
953
+ ACL_MEMCPY_DEVICE_TO_HOST));
954
+ } else {
955
+ void* transform_buffer = malloc(size);
956
+ ACL_CHECK(aclrtMemcpy(transform_buffer, size,
957
+ (char*)tensor->data + offset, size,
958
+ ACL_MEMCPY_DEVICE_TO_HOST));
959
+ ggml_backend_cann_transform_back(tensor, transform_buffer, data);
960
+ free(transform_buffer);
961
+ }
962
+ }
963
+
964
+ /**
965
+ * @brief Copy tensor data between CANN buffers if possible.
966
+ *
967
+ * This function copies tensor data between CANN buffers if the source and
968
+ * destination buffers are CANN buffers and they meet the necessary conditions
969
+ * (same device or devices can access each other).
970
+ *
971
+ * @param buffer The destination CANN buffer where the tensor data will be
972
+ * copied.
973
+ * @param src Pointer to the source tensor whose data will be copied.
974
+ * @param dst Pointer to the destination tensor where the data will be copied.
975
+ * @return true if the copy operation succeeded, false otherwise.
976
+ */
977
+ GGML_CALL static bool ggml_backend_cann_buffer_cpy_tensor(
978
+ ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
979
+ if (ggml_backend_buffer_is_cann(src->buffer)) {
980
+ ggml_backend_cann_buffer_context* src_ctx =
981
+ (ggml_backend_cann_buffer_context*)src->buffer->context;
982
+ ggml_backend_cann_buffer_context* dst_ctx =
983
+ (ggml_backend_cann_buffer_context*)buffer->context;
984
+
985
+ size_t memcpy_size = ggml_nbytes(src);
986
+ // Same device.
987
+ if (src_ctx->device == dst_ctx->device) {
988
+ ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
989
+ (const char*)src->data, memcpy_size,
990
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
991
+ return true;
992
+ } else {
993
+ // Different device but can access by peer.
994
+ int32_t canAccessPeer = 0;
995
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
996
+ dst_ctx->device));
997
+ if (canAccessPeer) {
998
+ ggml_cann_set_device(src_ctx->device);
999
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
1000
+ ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
1001
+ (const char*)src->data, memcpy_size,
1002
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
1003
+ return true;
1004
+ }
1005
+ }
1006
+ }
1007
+ return false;
1008
+ }
1009
+
1010
+ /**
1011
+ * @brief Clear a CANN buffer by setting all its memory to a specified value.
1012
+ *
1013
+ * This function clears a CANN buffer by setting all its memory to a specified
1014
+ * value.
1015
+ *
1016
+ * @param buffer The CANN buffer to be cleared.
1017
+ * @param value The value to which each byte in the buffer will be set.
1018
+ */
1019
+ GGML_CALL static void ggml_backend_cann_buffer_clear(
1020
+ ggml_backend_buffer_t buffer, uint8_t value) {
1021
+ ggml_backend_cann_buffer_context* ctx =
1022
+ (ggml_backend_cann_buffer_context*)buffer->context;
1023
+
1024
+ ggml_cann_set_device(ctx->device);
1025
+ ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
1026
+ }
1027
+
1028
+ /**
1029
+ * @brief Interface for a CANN buffer in the backend.
1030
+ *
1031
+ * This structure defines function pointers to operations that can be performed
1032
+ * on a CANN buffer within the backend.
1033
+ */
1034
+ static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
1035
+ /* .get_name = */ ggml_backend_cann_buffer_get_name,
1036
+ /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
1037
+ /* .get_base = */ ggml_backend_cann_buffer_get_base,
1038
+ /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
1039
+ /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
1040
+ /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
1041
+ /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor,
1042
+ /* .clear = */ ggml_backend_cann_buffer_clear,
1043
+ /* .reset = */ NULL,
1044
+ };
1045
+
1046
+ // cann buffer type
1047
+ /**
1048
+ * @brief Structure representing context information for a specific backend
1049
+ * buffer type.
1050
+ */
1051
+ struct ggml_backend_cann_buffer_type_context {
1052
+ int32_t
1053
+ device; /**< Device identifier associated with the buffer context. */
1054
+ std::string name; /**< Name associated with the buffer context. */
1055
+ };
1056
+
1057
+ /**
1058
+ * @brief Retrieves the name associated with a CANN buffer type.
1059
+ *
1060
+ * This function returns the descriptive name associated with the specified
1061
+ * CANN buffer type context.
1062
+ *
1063
+ * @param buft Pointer to the buffer type context.
1064
+ * @return Const pointer to the C-style string containing the name.
1065
+ */
1066
+ GGML_CALL static const char* ggml_backend_cann_buffer_type_name(
1067
+ ggml_backend_buffer_type_t buft) {
1068
+ return "CANN";
1069
+
1070
+ GGML_UNUSED(buft);
1071
+ }
1072
+
1073
+ /**
1074
+ * @brief Allocates a new CANN buffer of the specified type and size.
1075
+ *
1076
+ * This function allocates a new CANN buffer on the specified device with the
1077
+ * given size.
1078
+ *
1079
+ * @param buft Pointer to the buffer type context.
1080
+ * @param size Size in bytes of the buffer to allocate.
1081
+ * @return Pointer to the allocated buffer, or nullptr if allocation fails.
1082
+ */
1083
+ GGML_CALL static ggml_backend_buffer_t
1084
+ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1085
+ size_t size) {
1086
+ ggml_backend_cann_buffer_type_context* buft_ctx =
1087
+ (ggml_backend_cann_buffer_type_context*)buft->context;
1088
+
1089
+ ggml_cann_set_device(buft_ctx->device);
1090
+
1091
+ size = std::max(size, (size_t)1);
1092
+
1093
+ void* dev_ptr;
1094
+ aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1095
+ if (err != ACL_SUCCESS) {
1096
+ GGML_CANN_LOG_ERROR(
1097
+ "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
1098
+ __func__, size / 1024.0 / 1024.0, buft_ctx->device,
1099
+ aclGetRecentErrMsg());
1100
+ return nullptr;
1101
+ }
1102
+
1103
+ ggml_backend_cann_buffer_context* ctx =
1104
+ new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1105
+
1106
+ return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
1107
+ ctx, size);
1108
+ }
1109
+
1110
+ /**
1111
+ * @brief Retrieves the memory alignment requirement for CANN buffers of this
1112
+ * type.
1113
+ *
1114
+ * This function returns the alignment requirement in bytes for memory allocated
1115
+ * by the CANN buffer type.
1116
+ *
1117
+ * @param buft Pointer to the buffer type context (unused in this
1118
+ * implementation).
1119
+ * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
1120
+ * buffers).
1121
+ */
1122
+ GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alignment(
1123
+ ggml_backend_buffer_type_t buft) {
1124
+ return 128;
1125
+
1126
+ GGML_UNUSED(buft);
1127
+ }
1128
+
1129
+ /**
1130
+ * @brief Calculates the allocation size required for a tensor in a CANN buffer.
1131
+ *
1132
+ * Computes the total allocation size needed for storing the tensor's data in a
1133
+ * CANN buffer, considering any necessary padding or adjustments for quantized
1134
+ * types.
1135
+ *
1136
+ * @param buft Pointer to the buffer type context (unused in this
1137
+ * implementation).
1138
+ * @param tensor Pointer to the tensor for which the allocation size is
1139
+ * calculated.
1140
+ * @return The total allocation size in bytes required for the tensor in the
1141
+ * CANN buffer.
1142
+ */
1143
+ GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alloc_size(
1144
+ ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
1145
+ size_t size = ggml_nbytes(tensor);
1146
+ int64_t ne0 = tensor->ne[0];
1147
+
1148
+ // The last line must be bigger than 32, because every single op deals
1149
+ // with at least 32 bytes.
1150
+ // TODO: quantized type?
1151
+ // int64_t line_size = ne0 * ggml_element_size(tensor);
1152
+ // int64_t line_size_align_32 = (line_size + 31) & ~31;
1153
+ // size += (line_size_align_32 - line_size);
1154
+
1155
+ // TODO: quantized types are not supported yet.
1156
+ // TODO: consider non-contiguous tensors.
1157
+ if (ggml_is_quantized(tensor->type)) {
1158
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
1159
+ size += ggml_row_size(
1160
+ tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1161
+ }
1162
+ }
1163
+
1164
+ return size;
1165
+
1166
+ GGML_UNUSED(buft);
1167
+ }
1168
+
1169
+ /**
1170
+ * @brief Interface for managing CANN buffer types in the GGML backend.
1171
+ *
1172
+ * Provides function pointers for allocating, querying properties, and managing
1173
+ * memory for CANN buffer types in the GGML backend.
1174
+ */
1175
+ static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
1176
+ /* .get_name = */ ggml_backend_cann_buffer_type_name,
1177
+ /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
1178
+ /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
1179
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1180
+ /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
1181
+ /* .is_host = */ NULL,
1182
+ };
1183
+
1184
+ /**
1185
+ * @brief Retrieves the CANN buffer type for a specified device.
1186
+ *
1187
+ * This function initializes and returns the buffer type interface associated
1188
+ * with the given device. It ensures thread-safe access using a mutex.
1189
+ *
1190
+ * @param device The device index for which to retrieve the buffer type.
1191
+ * @return A pointer to the buffer type interface for the specified device, or
1192
+ * nullptr if the device index is out of range.
1193
+ */
1194
+ GGML_CALL ggml_backend_buffer_type_t
1195
+ ggml_backend_cann_buffer_type(int32_t device) {
1196
+ static std::mutex mutex;
1197
+ std::lock_guard<std::mutex> lock(mutex);
1198
+
1199
+ if (device >= ggml_backend_cann_get_device_count()) {
1200
+ return nullptr;
1201
+ }
1202
+
1203
+ static ggml_backend_buffer_type
1204
+ ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1205
+
1206
+ static bool ggml_backend_cann_buffer_type_initialized = false;
1207
+
1208
+ if (!ggml_backend_cann_buffer_type_initialized) {
1209
+ for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
1210
+ ggml_backend_cann_buffer_types[i] = {
1211
+ /* .iface = */ ggml_backend_cann_buffer_type_interface,
1212
+ /* .context = */
1213
+ new ggml_backend_cann_buffer_type_context{
1214
+ i, "CANN" + std::to_string(i)},
1215
+ };
1216
+ }
1217
+ ggml_backend_cann_buffer_type_initialized = true;
1218
+ }
1219
+
1220
+ return &ggml_backend_cann_buffer_types[device];
1221
+ }
1222
+
1223
+ /**
1224
+ * @brief Computes the forward operation for a given tensor using CANN
1225
+ * operations.
1226
+ *
1227
+ * This function selects the appropriate CANN operation based on the type of
1228
+ * operation specified in the tensor and performs the computation.
1229
+ *
1230
+ * @param ctx The CANN context containing necessary resources and
1231
+ * configurations.
1232
+ * @param dst The destination tensor where the result of the computation will be
1233
+ * stored.
1234
+ * @return true if the computation was successful; false otherwise.
1235
+ */
1236
+ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1237
+ struct ggml_tensor* dst) {
1238
+ switch (dst->op) {
1239
+ case GGML_OP_REPEAT:
1240
+ ggml_cann_repeat(ctx, dst);
1241
+ break;
1242
+ case GGML_OP_GET_ROWS:
1243
+ ggml_cann_get_rows(ctx, dst);
1244
+ break;
1245
+ case GGML_OP_DUP:
1246
+ ggml_cann_dup(ctx, dst);
1247
+ break;
1248
+ case GGML_OP_ADD:
1249
+ ggml_cann_add(ctx, dst);
1250
+ break;
1251
+ case GGML_OP_ACC:
1252
+ ggml_cann_acc(ctx, dst);
1253
+ break;
1254
+ case GGML_OP_MUL:
1255
+ ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
1256
+ break;
1257
+ case GGML_OP_DIV:
1258
+ ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
1259
+ break;
1260
+ case GGML_OP_UNARY:
1261
+ switch (ggml_get_unary_op(dst)) {
1262
+ case GGML_UNARY_OP_GELU:
1263
+ ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
1264
+ ctx, dst);
1265
+ break;
1266
+ case GGML_UNARY_OP_SILU:
1267
+ ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
1268
+ ctx, dst);
1269
+ break;
1270
+ // TODO: Use faster gelu??
1271
+ case GGML_UNARY_OP_GELU_QUICK:
1272
+ ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
1273
+ ctx, dst);
1274
+ break;
1275
+ case GGML_UNARY_OP_TANH:
1276
+ ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
1277
+ ctx, dst);
1278
+ break;
1279
+ case GGML_UNARY_OP_RELU:
1280
+ ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
1281
+ ctx, dst);
1282
+ break;
1283
+ case GGML_UNARY_OP_HARDSIGMOID:
1284
+ ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
1285
+ aclnnHardsigmoid>(ctx, dst);
1286
+ break;
1287
+ case GGML_UNARY_OP_HARDSWISH:
1288
+ ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
1289
+ aclnnHardswish>(ctx, dst);
1290
+ break;
1291
+ default:
1292
+ return false;
1293
+ }
1294
+ break;
1295
+ case GGML_OP_NORM:
1296
+ ggml_cann_norm(ctx, dst);
1297
+ break;
1298
+ case GGML_OP_GROUP_NORM:
1299
+ ggml_cann_group_norm(ctx, dst);
1300
+ break;
1301
+ case GGML_OP_CONCAT:
1302
+ ggml_cann_concat(ctx, dst);
1303
+ break;
1304
+ case GGML_OP_UPSCALE:
1305
+ ggml_cann_upsample_nearest2d(ctx, dst);
1306
+ break;
1307
+ case GGML_OP_PAD:
1308
+ ggml_cann_pad(ctx, dst);
1309
+ break;
1310
+ case GGML_OP_ARANGE:
1311
+ ggml_cann_arange(ctx, dst);
1312
+ break;
1313
+ case GGML_OP_TIMESTEP_EMBEDDING:
1314
+ ggml_cann_timestep_embedding(ctx, dst);
1315
+ break;
1316
+ case GGML_OP_LEAKY_RELU:
1317
+ ggml_cann_leaky_relu(ctx, dst);
1318
+ break;
1319
+ case GGML_OP_RMS_NORM:
1320
+ ggml_cann_rms_norm(ctx, dst);
1321
+ break;
1322
+ case GGML_OP_MUL_MAT:
1323
+ ggml_cann_mul_mat(ctx, dst);
1324
+ break;
1325
+ case GGML_OP_MUL_MAT_ID:
1326
+ return false;
1327
+ case GGML_OP_SCALE:
1328
+ ggml_cann_scale(ctx, dst);
1329
+ break;
1330
+ case GGML_OP_SQR:
1331
+ ggml_cann_sqr(ctx, dst);
1332
+ break;
1333
+ case GGML_OP_CLAMP:
1334
+ ggml_cann_clamp(ctx, dst);
1335
+ break;
1336
+ case GGML_OP_CPY:
1337
+ ggml_cann_cpy(ctx, dst);
1338
+ break;
1339
+ case GGML_OP_CONT:
1340
+ ggml_cann_dup(ctx, dst);
1341
+ break;
1342
+ case GGML_OP_NONE:
1343
+ case GGML_OP_RESHAPE:
1344
+ case GGML_OP_VIEW:
1345
+ case GGML_OP_PERMUTE:
1346
+ case GGML_OP_TRANSPOSE:
1347
+ break;
1348
+ case GGML_OP_DIAG_MASK_INF:
1349
+ ggml_cann_diag_mask(ctx, dst, -INFINITY);
1350
+ break;
1351
+ case GGML_OP_SOFT_MAX:
1352
+ ggml_cann_softmax(ctx, dst);
1353
+ break;
1354
+ case GGML_OP_ROPE:
1355
+ ggml_cann_rope(ctx, dst);
1356
+ break;
1357
+ case GGML_OP_IM2COL:
1358
+ ggml_cann_im2col(ctx, dst);
1359
+ break;
1360
+ case GGML_OP_POOL_2D:
1361
+ ggml_cann_pool2d(ctx, dst);
1362
+ break;
1363
+ case GGML_OP_SUM_ROWS:
1364
+ ggml_cann_sum_rows(ctx, dst);
1365
+ break;
1366
+ case GGML_OP_ARGSORT:
1367
+ ggml_cann_argsort(ctx, dst);
1368
+ break;
1369
+ default:
1370
+ return false;
1371
+ }
1372
+
1373
+ return true;
1374
+ }
1375
+
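The templated cases above, such as ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul> and the ggml_cann_activation instantiations, pair each aclnn operator with its workspace-size query; their definitions live in aclnn_ops.cpp, whose diff is not rendered on this page. Below is a minimal, self-contained sketch of that two-phase calling pattern. The types and function names are simplified stand-ins, not the backend's actual implementation.

// sketch_two_phase.cpp -- illustrative stand-in for the aclnn call pattern, not the backend's real code.
#include <cstdint>
#include <cstdio>
#include <cstdlib>

struct acl_tensor   { const char* name; };   // stand-in for aclTensor
struct acl_executor { int unused; };         // stand-in for aclOpExecutor
using acl_stream = void*;                    // stand-in for aclrtStream

// Hypothetical signatures mirroring an aclnnXxxGetWorkspaceSize / aclnnXxx pair.
using get_ws_fn = int (*)(acl_tensor*, acl_tensor*, acl_tensor*, uint64_t*, acl_executor**);
using exec_fn   = int (*)(void*, uint64_t, acl_executor*, acl_stream);

// Generic binary-op launcher: query the workspace size, allocate it, then execute.
template <get_ws_fn get_ws, exec_fn run>
int launch_binary_op(acl_tensor* a, acl_tensor* b, acl_tensor* out, acl_stream stream) {
    uint64_t ws_size = 0;
    acl_executor* executor = nullptr;
    int ret = get_ws(a, b, out, &ws_size, &executor);   // phase 1: plan, query workspace
    if (ret != 0) return ret;
    void* workspace = ws_size > 0 ? std::malloc(ws_size) : nullptr;
    ret = run(workspace, ws_size, executor, stream);    // phase 2: launch on the stream
    // NOTE: with a truly asynchronous stream the workspace must stay alive until
    // the kernel finishes; this toy version is synchronous, so freeing here is fine.
    std::free(workspace);
    return ret;
}

// Dummy operator pair standing in for aclnnMulGetWorkspaceSize / aclnnMul.
static acl_executor g_exec;
static int fake_mul_get_ws(acl_tensor*, acl_tensor*, acl_tensor*, uint64_t* ws, acl_executor** ex) {
    *ws = 64; *ex = &g_exec; return 0;
}
static int fake_mul(void*, uint64_t ws, acl_executor*, acl_stream) {
    std::printf("executed with %llu-byte workspace\n", (unsigned long long) ws); return 0;
}

int main() {
    acl_tensor a{"a"}, b{"b"}, out{"out"};
    return launch_binary_op<fake_mul_get_ws, fake_mul>(&a, &b, &out, nullptr);
}

Keeping the workspace query and the launcher paired through template parameters is the same idea the GGML_OP_MUL / GGML_OP_DIV and unary cases rely on.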
1376
+ // backend
1377
+ /**
1378
+ * @brief Retrieves the name associated with the CANN backend.
1379
+ *
1380
+ * This function returns the name assigned to the CANN backend, which is stored
1381
+ * in the context of the provided backend structure.
1382
+ *
1383
+ * @param backend Pointer to the CANN backend structure.
1384
+ * @return A pointer to a constant string representing the backend name.
1385
+ */
1386
+ GGML_CALL static const char* ggml_backend_cann_name(ggml_backend_t backend) {
1387
+ ggml_backend_cann_context* cann_ctx =
1388
+ (ggml_backend_cann_context*)backend->context;
1389
+
1390
+ return cann_ctx->name.c_str();
1391
+ }
1392
+
1393
+ /**
1394
+ * @brief Frees resources associated with the CANN backend.
1395
+ *
1396
+ * This function releases resources associated with the CANN backend context
1397
+ * and resets the device associated with the backend to its initial state.
1398
+ *
1399
+ * @param backend Pointer to the CANN backend structure to be freed.
1400
+ */
1401
+ GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) {
1402
+ ggml_backend_cann_context* cann_ctx =
1403
+ (ggml_backend_cann_context*)backend->context;
1404
+ ACL_CHECK(aclrtSynchronizeDevice());
1405
+ ACL_CHECK(aclrtResetDevice(cann_ctx->device));
1406
+
1407
+ // finalize when the last backend is freed.
1408
+ if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
1409
+ ACL_CHECK(aclFinalize());
1410
+ }
1411
+
1412
+ delete cann_ctx;
1413
+ delete backend;
1414
+ }
1415
+
1416
+ /**
1417
+ * @brief Retrieves the default buffer type associated with the CANN backend.
1418
+ *
1419
+ * This function returns the buffer type specific to the device associated
1420
+ * with the CANN backend. It is used to allocate buffers for computations
1421
+ * performed by the backend.
1422
+ *
1423
+ * @param backend Pointer to the CANN backend structure.
1424
+ * @return Pointer to the buffer type structure for the CANN backend.
1425
+ */
1426
+ GGML_CALL static ggml_backend_buffer_type_t
1427
+ ggml_backend_cann_get_default_buffer_type(ggml_backend_t backend) {
1428
+ ggml_backend_cann_context* cann_ctx =
1429
+ (ggml_backend_cann_context*)backend->context;
1430
+
1431
+ return ggml_backend_cann_buffer_type(cann_ctx->device);
1432
+ }
1433
+
1434
+ /**
1435
+ * @brief Sets tensor data asynchronously in the CANN backend.
1436
+ *
1437
+ * This function asynchronously sets tensor data in the CANN backend. Depending
1438
+ * on the tensor type, it may perform data transformations before copying data
1439
+ * to the device.
1440
+ *
1441
+ * @param backend Pointer to the CANN backend structure.
1442
+ * @param tensor Pointer to the tensor structure to set data for.
1443
+ * @param data Pointer to the host data to copy to the tensor.
1444
+ * @param offset Offset in bytes into the tensor's data where the copy starts.
1445
+ * @param size Size of the data to copy in bytes.
1446
+ */
1447
+ GGML_CALL static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
1448
+ ggml_tensor *tensor,
1449
+ const void *data,
1450
+ size_t offset,
1451
+ size_t size) {
1452
+ ggml_backend_cann_context *cann_ctx =
1453
+ (ggml_backend_cann_context *)backend->context;
1454
+
1455
+ if (!need_transform(tensor->type)) {
1456
+ ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
1457
+ size, ACL_MEMCPY_HOST_TO_DEVICE,
1458
+ cann_ctx->stream()));
1459
+ } else {
1460
+ void *transform_buffer = malloc(size);
1461
+ ggml_backend_cann_transform(tensor, data, transform_buffer);
1462
+
1463
+ #ifndef NDEBUG
1464
+ void *check_buffer = malloc(size);
1465
+ ggml_backend_cann_transform_back(tensor, transform_buffer,
1466
+ check_buffer);
1467
+ GGML_ASSERT(memcmp(data, check_buffer, size) == 0);
1468
+ free(check_buffer);
1469
+ #endif
1470
+ ACL_CHECK(aclrtMemcpyAsync(
1471
+ (char *)tensor->data + offset, size, transform_buffer, size,
1472
+ ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
1473
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1474
+ free(transform_buffer);
1475
+ }
1476
+ }
1477
+
1478
+ GGML_CALL static void ggml_backend_cann_get_tensor_async(
1479
+ ggml_backend_t backend, const ggml_tensor *tensor, void *data,
1480
+ size_t offset, size_t size) {
1481
+ ggml_backend_cann_context *cann_ctx =
1482
+ (ggml_backend_cann_context *)backend->context;
1483
+ ggml_backend_buffer_t buf =
1484
+ tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1485
+
1486
+ GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
1487
+ "unsupported buffer type");
1488
+
1489
+ if (!need_transform(tensor->type)) {
1490
+ ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
1491
+ size, ACL_MEMCPY_DEVICE_TO_HOST,
1492
+ cann_ctx->stream()));
1493
+ } else {
1494
+ void *transform_buffer = malloc(size);
1495
+ ACL_CHECK(aclrtMemcpyAsync(
1496
+ transform_buffer, size, (char *)tensor->data + offset, size,
1497
+ ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
1498
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1499
+ ggml_backend_cann_transform_back(tensor, transform_buffer, data);
1500
+ free(transform_buffer);
1501
+ }
1502
+ }
1503
+
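For context, these two hooks are normally reached through the generic ggml-backend API rather than called directly. A small usage sketch, assuming a tensor that already lives in a CANN buffer (the helper and variable names here are illustrative):

#include "ggml.h"
#include "ggml-backend.h"
#include <vector>

// Round-trips a tensor's data through the device; with a CANN backend this
// dispatches to ggml_backend_cann_set_tensor_async / ..._get_tensor_async.
static void roundtrip_tensor(ggml_backend_t backend, ggml_tensor * t, const void * host_src) {
    ggml_backend_tensor_set_async(backend, t, host_src, 0, ggml_nbytes(t));

    std::vector<char> host_dst(ggml_nbytes(t));
    ggml_backend_tensor_get_async(backend, t, host_dst.data(), 0, ggml_nbytes(t));

    // the async copies are only guaranteed to have completed after this call
    ggml_backend_synchronize(backend);
}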
1504
+ /**
1505
+ * @brief Asynchronously copies tensor data between CANN backends.
1506
+ *
1507
+ * This function copies tensor data asynchronously between two CANN backends. It
1508
+ * checks if both tensors reside in CANN buffers and whether the devices support
1509
+ * peer-to-peer access for direct copying. If not, it returns false.
1510
+ *
1511
+ * @param backend_src Pointer to the source CANN backend structure.
1512
+ * @param backend_dst Pointer to the destination CANN backend structure.
1513
+ * @param src Pointer to the source tensor to copy data from.
1514
+ * @param dst Pointer to the destination tensor to copy data to.
1515
+ * @return true if the copy operation succeeds, false otherwise.
1516
+ */
1517
+ GGML_CALL static bool ggml_backend_cann_cpy_tensor_async(
1518
+ ggml_backend_t backend_src, ggml_backend_t backend_dst,
1519
+ const ggml_tensor* src, ggml_tensor* dst) {
1520
+ GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
1521
+ ggml_backend_is_cann(backend_dst));
1522
+
1523
+ if (!ggml_backend_buffer_is_cann(src->buffer) ||
1524
+ !ggml_backend_buffer_is_cann(dst->buffer)) {
1525
+ return false;
1526
+ }
1527
+
1528
+ ggml_backend_buffer_t buf_src =
1529
+ src->view_src ? src->view_src->buffer : src->buffer;
1530
+ ggml_backend_buffer_t buf_dst =
1531
+ dst->view_src ? dst->view_src->buffer : dst->buffer;
1532
+
1533
+ ggml_backend_cann_context* cann_ctx_src =
1534
+ (ggml_backend_cann_context*)backend_src->context;
1535
+ ggml_backend_cann_context* cann_ctx_dst =
1536
+ (ggml_backend_cann_context*)backend_dst->context;
1537
+
1538
+ size_t copy_size = ggml_nbytes(dst);
1539
+ if (backend_src != backend_dst) {
1540
+ ggml_backend_cann_buffer_context* buf_ctx_src =
1541
+ (ggml_backend_cann_buffer_context*)buf_src->context;
1542
+ ggml_backend_cann_buffer_context* buf_ctx_dst =
1543
+ (ggml_backend_cann_buffer_context*)buf_dst->context;
1544
+
1545
+ GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
1546
+ GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
1547
+
1548
+ int32_t canAccessPeer = 0;
1549
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
1550
+ cann_ctx_dst->device));
1551
+ if (!canAccessPeer) {
1552
+ return false;
1553
+ }
1554
+
1555
+ // need to enable peer access in both directions for aclrtMemcpyAsync between devices.
1556
+ ggml_cann_set_device(cann_ctx_dst->device);
1557
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
1558
+ ggml_cann_set_device(cann_ctx_src->device);
1559
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
1560
+
1561
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
1562
+ ACL_MEMCPY_DEVICE_TO_DEVICE,
1563
+ cann_ctx_src->stream()));
1564
+
1565
+ // TODO: workaround, because events do not work here yet.
1566
+ aclrtSynchronizeStream(cann_ctx_src->stream());
1567
+ } else {
1568
+ // src and dst are on the same backend
1569
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
1570
+ ACL_MEMCPY_DEVICE_TO_DEVICE,
1571
+ cann_ctx_dst->stream()));
1572
+ }
1573
+
1574
+ return true;
1575
+ }
1576
+
1577
+ /**
1578
+ * @brief Synchronizes a CANN backend.
1579
+ *
1580
+ * This function synchronizes the specified CANN backend by waiting for all
1581
+ * operations in its associated stream to complete.
1582
+ *
1583
+ * @param backend Pointer to the CANN backend structure to synchronize.
1584
+ */
1585
+ GGML_CALL static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
1586
+ ggml_backend_cann_context* cann_ctx =
1587
+ (ggml_backend_cann_context*)backend->context;
1588
+
1589
+ ggml_cann_set_device(cann_ctx->device);
1590
+
1591
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1592
+ }
1593
+
1594
+ /**
1595
+ * @brief Computes a computational graph using a CANN backend.
1596
+ *
1597
+ * This function computes the operations defined in the computational graph
1598
+ * using the specified CANN backend.
1599
+ *
1600
+ * @param backend Pointer to the CANN backend structure to use for computation.
1601
+ * @param cgraph Pointer to the computational graph structure containing nodes
1602
+ * representing operations to be computed.
1603
+ * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
1604
+ * completes successfully, otherwise an appropriate error status.
1605
+ */
1606
+ GGML_CALL static enum ggml_status ggml_backend_cann_graph_compute(
1607
+ ggml_backend_t backend, ggml_cgraph* cgraph) {
1608
+ ggml_backend_cann_context* cann_ctx =
1609
+ (ggml_backend_cann_context*)backend->context;
1610
+
1611
+ ggml_cann_set_device(cann_ctx->device);
1612
+
1613
+ for (int i = 0; i < cgraph->n_nodes; i++) {
1614
+ ggml_tensor* node = cgraph->nodes[i];
1615
+
1616
+ if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
1617
+ continue;
1618
+ }
1619
+
1620
+ bool ok = ggml_cann_compute_forward(*cann_ctx, node);
1621
+
1622
+ if (!ok) {
1623
+ GGML_CANN_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
1624
+ node->name, ggml_op_name(node->op));
1625
+ }
1626
+ GGML_ASSERT(ok);
1627
+ }
1628
+
1629
+ return GGML_STATUS_SUCCESS;
1630
+ }
1631
+
1632
+ /**
1633
+ * @brief Checks if the CANN backend supports a specific operation.
1634
+ *
1635
+ * This function checks whether the specified operation is supported by the
1636
+ * CANN backend.
1637
+ *
1638
+ * @param backend Pointer to the CANN backend structure to check support for
1639
+ * the operation.
1640
+ * @param op Pointer to the tensor representing the operation to check.
1641
+ * @return bool Returns true if the operation is supported by the backend,
1642
+ * otherwise false.
1643
+ */
1644
+ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
1645
+ const ggml_tensor* op) {
1646
+ switch (op->op) {
1647
+ case GGML_OP_UNARY:
1648
+ switch (ggml_get_unary_op(op)) {
1649
+ case GGML_UNARY_OP_GELU:
1650
+ case GGML_UNARY_OP_SILU:
1651
+ case GGML_UNARY_OP_RELU:
1652
+ case GGML_UNARY_OP_HARDSIGMOID:
1653
+ case GGML_UNARY_OP_HARDSWISH:
1654
+ case GGML_UNARY_OP_GELU_QUICK:
1655
+ case GGML_UNARY_OP_TANH:
1656
+ return true;
1657
+ default:
1658
+ return false;
1659
+ }
1660
+ case GGML_OP_MUL_MAT: {
1661
+ switch (op->src[0]->type) {
1662
+ case GGML_TYPE_F16:
1663
+ case GGML_TYPE_F32:
1664
+ case GGML_TYPE_Q8_0:
1665
+ // TODO: fix me
1666
+ // Current groupsize should not be greater than k-1 in
1667
+ // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
1668
+ case GGML_TYPE_Q4_0:
1669
+ return true;
1670
+ default:
1671
+ return false;
1672
+ }
1673
+ }
1674
+ case GGML_OP_MUL_MAT_ID:
1675
+ return false;
1676
+ // embedding
1677
+ case GGML_OP_GET_ROWS: {
1678
+ switch (op->src[0]->type) {
1679
+ case GGML_TYPE_F32:
1680
+ case GGML_TYPE_F16:
1681
+ case GGML_TYPE_Q4_0:
1682
+ case GGML_TYPE_Q8_0:
1683
+ return true;
1684
+ default:
1685
+ return false;
1686
+ }
1687
+ } break;
1688
+ case GGML_OP_CPY: {
1689
+ switch (op->type) {
1690
+ case GGML_TYPE_F32:
1691
+ case GGML_TYPE_F16:
1692
+ case GGML_TYPE_Q8_0:
1693
+ case GGML_TYPE_Q4_0:
1694
+ return true;
1695
+ default:
1696
+ return false;
1697
+ }
1698
+ }
1699
+ case GGML_OP_DUP:
1700
+ case GGML_OP_REPEAT:
1701
+ case GGML_OP_CONCAT:
1702
+ case GGML_OP_NONE:
1703
+ case GGML_OP_RESHAPE:
1704
+ case GGML_OP_VIEW:
1705
+ case GGML_OP_PERMUTE:
1706
+ case GGML_OP_TRANSPOSE:
1707
+ case GGML_OP_NORM:
1708
+ case GGML_OP_ADD:
1709
+ case GGML_OP_MUL:
1710
+ case GGML_OP_DIV:
1711
+ case GGML_OP_RMS_NORM:
1712
+ case GGML_OP_SCALE:
1713
+ case GGML_OP_SQR:
1714
+ case GGML_OP_CLAMP:
1715
+ case GGML_OP_CONT:
1716
+ case GGML_OP_DIAG_MASK_INF:
1717
+ case GGML_OP_SOFT_MAX:
1718
+ case GGML_OP_ROPE:
1719
+ case GGML_OP_IM2COL:
1720
+ case GGML_OP_POOL_2D:
1721
+ case GGML_OP_SUM_ROWS:
1722
+ case GGML_OP_ARGSORT:
1723
+ case GGML_OP_ACC:
1724
+ case GGML_OP_GROUP_NORM:
1725
+ case GGML_OP_UPSCALE:
1726
+ case GGML_OP_PAD:
1727
+ case GGML_OP_ARANGE:
1728
+ case GGML_OP_TIMESTEP_EMBEDDING:
1729
+ case GGML_OP_LEAKY_RELU:
1730
+ return true;
1731
+ default:
1732
+ return false;
1733
+ }
1734
+
1735
+ GGML_UNUSED(backend);
1736
+ }
1737
+
1738
+ /**
1739
+ * @brief Checks if the backend buffer type is associated with the CANN backend.
1740
+ *
1741
+ * This function checks whether the provided backend buffer type is associated
1742
+ * with the CANN backend based on the comparison of its name retrieval function
1743
+ * pointer.
1744
+ *
1745
+ * @param buft Pointer to the backend buffer type to check.
1746
+ * @return bool Returns true if the buffer type is associated with the CANN
1747
+ * backend, otherwise false.
1748
+ */
1749
+ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
1750
+ return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
1751
+ }
1752
+
1753
+ /**
1754
+ * @brief Checks if the CANN backend supports a specific backend buffer type.
1755
+ *
1756
+ * This function determines whether the CANN backend supports the given backend
1757
+ * buffer type by comparing the device context of the backend and buffer type.
1758
+ * It returns true if the backend context and the buffer type context refer
1759
+ * to the same device.
1760
+ *
1761
+ * @param backend Pointer to the CANN backend.
1762
+ * @param buft Pointer to the backend buffer type to check.
1763
+ * @return bool Returns true if the CANN backend supports the buffer type,
1764
+ * otherwise false.
1765
+ */
1766
+ GGML_CALL static bool ggml_backend_cann_supports_buft(
1767
+ ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
1768
+ if (ggml_backend_buft_is_cann(buft)) {
1769
+ ggml_backend_cann_context * cann_ctx =
1770
+ (ggml_backend_cann_context *)backend->context;
1771
+ ggml_backend_cann_buffer_type_context * buft_ctx =
1772
+ (ggml_backend_cann_buffer_type_context *)buft->context;
1773
+ return buft_ctx->device == cann_ctx->device;
1774
+ }
1775
+ return false;
1776
+ }
1777
+
1778
+ /**
1779
+ * @brief Determines if a tensor operation should be offloaded to the CANN
1780
+ * backend.
1781
+ *
1782
+ * This function checks if a given tensor operation should be offloaded to the
1783
+ * CANN backend based on the operation type and the size of the tensor. It
1784
+ * returns true if the second dimension (ne[1]) of the tensor is greater than or
1785
+ * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
1786
+ *
1787
+ * @param backend Pointer to the CANN backend.
1788
+ * @param op Pointer to the tensor operation to check.
1789
+ * @return bool Returns true if the operation should be offloaded, otherwise
1790
+ * false.
1791
+ */
1792
+ GGML_CALL static bool ggml_backend_cann_offload_op(ggml_backend_t backend,
1793
+ const ggml_tensor* op) {
1794
+ const int min_batch_size = 32;
1795
+ GGML_UNUSED(backend);
1796
+
1797
+ return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
1798
+ }
1799
+
1800
+ /**
1801
+ * @brief Creates a new event for the CANN backend.
1802
+ *
1803
+ * This function initializes a new event for the CANN backend by setting the
1804
+ * device and creating an ACL runtime event. The created event is then wrapped
1805
+ * in a ggml_backend_event structure and returned.
1806
+ *
1807
+ * @param backend Pointer to the CANN backend.
1808
+ * @return ggml_backend_event_t Returns a pointer to the new event structure.
1809
+ */
1810
+ static ggml_backend_event_t ggml_backend_cann_event_new(
1811
+ ggml_backend_t backend) {
1812
+ ggml_backend_cann_context* cann_ctx =
1813
+ (ggml_backend_cann_context*)backend->context;
1814
+
1815
+ ggml_cann_set_device(cann_ctx->device);
1816
+
1817
+ aclrtEvent event;
1818
+ ACL_CHECK(aclrtCreateEvent(&event));
1819
+
1820
+ return new ggml_backend_event{
1821
+ /* .backend = */ backend,
1822
+ /* .context = */ event,
1823
+ };
1824
+ }
1825
+
1826
+ /**
1827
+ * @brief Frees a CANN backend event.
1828
+ *
1829
+ * This function destroys the ACL runtime event associated with the given CANN
1830
+ * backend event and then deletes the event structure itself.
1831
+ *
1832
+ * @param event Pointer to the event structure to be freed.
1833
+ */
1834
+ static void ggml_backend_cann_event_free(ggml_backend_event_t event) {
1835
+ ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
1836
+
1837
+ delete event;
1838
+ }
1839
+
1840
+ /**
1841
+ * @brief Records an event on the CANN backend stream.
1842
+ *
1843
+ * This function records the given event on the ACL runtime stream associated
1844
+ * with the backend context.
1845
+ *
1846
+ * @param event Pointer to the event structure to be recorded.
1847
+ */
1848
+ static void ggml_backend_cann_event_record(ggml_backend_event_t event) {
1849
+ ggml_backend_cann_context* cann_ctx =
1850
+ (ggml_backend_cann_context*)event->backend->context;
1851
+
1852
+ ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
1853
+ }
1854
+
1855
+ /**
1856
+ * @brief Waits for a recorded event to complete on the CANN backend stream.
1857
+ *
1858
+ * This function makes the given backend wait for the event to complete on its
1859
+ * ACL runtime stream.
1860
+ *
1861
+ * @param backend Pointer to the backend structure.
1862
+ * @param event Pointer to the event structure that the backend needs to wait
1863
+ * for.
1864
+ */
1865
+ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
1866
+ ggml_backend_event_t event) {
1867
+ ggml_backend_cann_context* cann_ctx =
1868
+ (ggml_backend_cann_context*)backend->context;
1869
+
1870
+ if (ggml_backend_is_cann(event->backend)) {
1871
+ ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
1872
+ (aclrtEvent)event->context));
1873
+ } else {
1874
+ GGML_ABORT("fatal error");
1875
+ }
1876
+ }
1877
+
1878
+ /**
1879
+ * @brief Synchronizes the given event on the CANN backend.
1880
+ *
1881
+ * This function waits for the specified event to complete on the ACL runtime.
1882
+ *
1883
+ * @param event Pointer to the event structure to be synchronized.
1884
+ */
1885
+ static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
1886
+ ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
1887
+ }
1888
+
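These event hooks back the generic ggml-backend event API. A rough sketch of how two CANN backends on different devices could be ordered against each other without blocking the host; it assumes both graphs are already allocated on their respective backends and that the ggml-backend helper names match the ggml revision this diff targets:

#include "ggml-backend.h"

// Producer/consumer ordering across two CANN devices.
static void ordered_compute(ggml_backend_t producer, ggml_backend_t consumer,
                            ggml_cgraph * gf_producer, ggml_cgraph * gf_consumer) {
    ggml_backend_event_t ev = ggml_backend_event_new(producer);

    ggml_backend_graph_compute_async(producer, gf_producer);
    ggml_backend_event_record(ev);           // maps to ggml_backend_cann_event_record
    ggml_backend_event_wait(consumer, ev);   // consumer's stream waits on the event
    ggml_backend_graph_compute_async(consumer, gf_consumer);

    ggml_backend_synchronize(consumer);
    ggml_backend_event_free(ev);
}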
1889
+ /**
1890
+ * @brief Structure defining the interface for the CANN backend.
1891
+ *
1892
+ * This structure contains function pointers for various operations
1893
+ * supported by the CANN backend, including name retrieval, memory
1894
+ * management, tensor operations, synchronization, and event handling.
1895
+ */
1896
+ static ggml_backend_i ggml_backend_cann_interface = {
1897
+ /* .get_name = */ ggml_backend_cann_name,
1898
+ /* .free = */ ggml_backend_cann_free,
1899
+ /* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type,
1900
+ /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async,
1901
+ /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async,
1902
+ /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async,
1903
+ /* .synchronize = */ ggml_backend_cann_synchronize,
1904
+ /* .graph_plan_create = */ NULL,
1905
+ /* .graph_plan_free = */ NULL,
1906
+ /* .graph_plan_update = */ NULL,
1907
+ /* .graph_plan_compute = */ NULL,
1908
+ /* .graph_compute = */ ggml_backend_cann_graph_compute,
1909
+ /* .supports_op = */ ggml_backend_cann_supports_op,
1910
+ /* .supports_buft = */ ggml_backend_cann_supports_buft,
1911
+ /* .offload_op = */ ggml_backend_cann_offload_op,
1912
+ /* .event_new = */ ggml_backend_cann_event_new,
1913
+ /* .event_free = */ ggml_backend_cann_event_free,
1914
+ /* .event_record = */ ggml_backend_cann_event_record,
1915
+ /* .event_wait = */ ggml_backend_cann_event_wait,
1916
+ /* .event_synchronize = */ ggml_backend_cann_event_synchronize,
1917
+ };
1918
+
1919
+ /**
1920
+ * @brief Return the hardcoded GUID for the CANN backend.
1921
+ *
1922
+ * This function returns a static GUID which uniquely identifies the CANN
1923
+ * backend.
1924
+ *
1925
+ * @return A pointer to the static GUID.
1926
+ */
1927
+ static ggml_guid_t ggml_backend_cann_guid() {
1928
+ static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
1929
+ 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
1930
+ return &guid;
1931
+ }
1932
+
1933
+ GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) {
1934
+ aclInit(nullptr);
1935
+ if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
1936
+ GGML_CANN_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
1937
+ return nullptr;
1938
+ }
1939
+
1940
+ ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
1941
+ if (ctx == nullptr) {
1942
+ GGML_CANN_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
1943
+ return nullptr;
1944
+ }
1945
+
1946
+ ggml_backend_t cann_backend =
1947
+ new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
1948
+ /* .interface = */ ggml_backend_cann_interface,
1949
+ /* .context = */ ctx};
1950
+
1951
+ return cann_backend;
1952
+ }
1953
+
1954
+ GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend) {
1955
+ return backend != NULL &&
1956
+ ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
1957
+ }
1958
+
1959
+ GGML_CALL int32_t ggml_backend_cann_get_device_count() {
1960
+ return ggml_cann_info().device_count;
1961
+ }
1962
+
1963
+ GGML_CALL void ggml_backend_cann_get_device_description(
1964
+ int32_t device, char* description, size_t description_size) {
1965
+ ggml_cann_set_device(device);
1966
+ const char* soc_name = aclrtGetSocName();
1967
+ snprintf(description, description_size, "%s", soc_name);
1968
+ }
1969
+
1970
+ GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
1971
+ size_t* total) {
1972
+ ggml_cann_set_device(device);
1973
+ ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
1974
+ }
1975
+
1976
+ // backend registry
1977
+ /**
1978
+ * @brief Initializes a CANN backend based on the provided parameters.
1979
+ *
1980
+ * This function initializes a CANN backend using the device index and then
1981
+ * initializes the backend using `ggml_backend_cann_init`.
1982
+ *
1983
+ * @param params Parameters for initialization (unused in this implementation).
1984
+ * @param user_data User data containing the device index to initialize the
1985
+ * backend.
1986
+ * @return ggml_backend_t The initialized CANN backend.
1987
+ */
1988
+ GGML_CALL static ggml_backend_t ggml_backend_reg_cann_init(const char* params,
1989
+ void* user_data) {
1990
+ ggml_backend_t cann_backend =
1991
+ ggml_backend_cann_init((int)(intptr_t)user_data);
1992
+ return cann_backend;
1993
+
1994
+ GGML_UNUSED(params);
1995
+ }
1996
+
1997
+ extern "C" GGML_CALL int ggml_backend_cann_reg_devices();
1998
+
1999
+ /**
2000
+ * @brief Registers CANN (Ascend) devices as backend options.
2001
+ *
2002
+ * This function initializes ACL, retrieves the number of available CANN
2003
+ * devices, and registers each device as a backend option using
2004
+ * `ggml_backend_register`. Each device is given a unique name based on
2005
+ * `GGML_CANN_NAME` followed by its index.
2006
+ *
2007
+ * @return int The number of CANN devices registered.
2008
+ */
2009
+ GGML_CALL int ggml_backend_cann_reg_devices() {
2010
+ uint32_t device_count = ggml_backend_cann_get_device_count();
2011
+ // initialization
2012
+ for (uint32_t i = 0; i < device_count; i++) {
2013
+ char name[128];
2014
+ snprintf(name, sizeof(name), "CANN%d", i);
2015
+ ggml_backend_register(name, ggml_backend_reg_cann_init,
2016
+ ggml_backend_cann_buffer_type(i),
2017
+ (void*)(intptr_t)i);
2018
+ }
2019
+ return device_count;
2020
+ }
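Taken together, the pieces above are enough to run a graph on an Ascend device through the generic ggml-backend API. A minimal end-to-end sketch, not part of the diff; it assumes the usual ggml headers and that at least one CANN device is present:

// cann_add_example.cpp -- illustrative only: c = a + b on the first CANN device.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cann.h"
#include <cstdio>

int main() {
    ggml_backend_t backend = ggml_backend_cann_init(0);
    if (backend == nullptr) { std::fprintf(stderr, "no CANN device\n"); return 1; }

    // build a tiny graph: only tensor metadata is stored on the host (no_alloc)
    ggml_init_params params = { /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
                                /*.mem_buffer =*/ nullptr,
                                /*.no_alloc   =*/ true };
    ggml_context * ctx = ggml_init(params);
    ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_tensor * c = ggml_add(ctx, a, b);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);

    // place the tensors in a CANN buffer and upload the inputs
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
    const float a_data[4] = {1, 2, 3, 4}, b_data[4] = {10, 20, 30, 40};
    ggml_backend_tensor_set(a, a_data, 0, sizeof(a_data));
    ggml_backend_tensor_set(b, b_data, 0, sizeof(b_data));

    // runs through ggml_backend_cann_graph_compute, then read back the result
    ggml_backend_graph_compute(backend, gf);
    float c_data[4];
    ggml_backend_tensor_get(c, c_data, 0, sizeof(c_data));
    std::printf("%g %g %g %g\n", c_data[0], c_data[1], c_data[2], c_data[3]);

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}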
ggml/src/ggml-cann/Doxyfile ADDED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cann/acl_tensor.cpp ADDED
@@ -0,0 +1,175 @@
1
+ /*
2
+ * Copyright (c) 2023-2024 The ggml authors
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ * of this software and associated documentation files (the "Software"), to
6
+ * deal in the Software without restriction, including without limitation the
7
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8
+ * sell copies of the Software, and to permit persons to whom the Software is
9
+ * furnished to do so, subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in
12
+ * all copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20
+ * IN THE SOFTWARE.
21
+ */
22
+
23
+ #include "acl_tensor.h"
24
+
25
+ #include <algorithm>
26
+ #include <cstring>
27
+
28
+ aclDataType ggml_cann_type_mapping(ggml_type type) {
29
+ switch (type) {
30
+ case GGML_TYPE_F32:
31
+ return ACL_FLOAT;
32
+ case GGML_TYPE_F16:
33
+ return ACL_FLOAT16;
34
+ case GGML_TYPE_I8:
35
+ return ACL_INT8;
36
+ case GGML_TYPE_I16:
37
+ return ACL_INT16;
38
+ case GGML_TYPE_I32:
39
+ return ACL_INT32;
40
+ case GGML_TYPE_Q4_0:
41
+ return ACL_INT4;
42
+ case GGML_TYPE_Q8_0:
43
+ return ACL_INT8;
44
+ default:
45
+ return ACL_DT_UNDEFINED;
46
+ }
47
+ return ACL_DT_UNDEFINED;
48
+ }
49
+
50
+ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
51
+ size_t* nb, int64_t dims, aclFormat format,
52
+ size_t offset) {
53
+ // If the tensor is broadcast, up to GGML_MAX_DIMS additional dimensions
54
+ // will be added.
55
+ int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
56
+
57
+ int64_t acl_storage_len = 0;
58
+ if (ne == nullptr) {
59
+ acl_storage_len = ggml_nbytes(tensor);
60
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
61
+ acl_ne[i] = tensor->ne[i];
62
+ // ACL strides are expressed in elements, not bytes.
63
+ acl_stride[i] = tensor->nb[i] / ggml_element_size(tensor);
64
+ }
65
+ } else {
66
+ // With bcast
67
+ for (int i = 0; i < dims; i++) {
68
+ acl_storage_len += (ne[i] - 1) * nb[i];
69
+ acl_ne[i] = ne[i];
70
+ acl_stride[i] = nb[i] / ggml_element_size(tensor);
71
+ }
72
+ }
73
+
74
+ // Reverse ne and stride.
75
+ int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims);
76
+ std::reverse(acl_ne, acl_ne + final_dims);
77
+ std::reverse(acl_stride, acl_stride + final_dims);
78
+
79
+ aclTensor* acl_tensor = aclCreateTensor(
80
+ acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
81
+ offset / ggml_element_size(tensor), format, &acl_storage_len, 1,
82
+ tensor->data);
83
+
84
+ return acl_tensor;
85
+ }
86
+
87
+ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) {
88
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
89
+ if (t1->ne[i] != t0->ne[i] && t1->ne[i] != 1) {
90
+ return true;
91
+ }
92
+ }
93
+ return false;
94
+ }
95
+
96
+ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
97
+ const ggml_tensor* src1,
98
+ int64_t* bcast_src0_ne,
99
+ int64_t* bcast_src1_ne, size_t* bcast_src0_nb,
100
+ size_t* bcast_src1_nb) {
101
+ GGML_ASSERT(ggml_can_repeat(src1, src0));
102
+ int bcast_dim_cnt = 0;
103
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
104
+ int64_t nr = src0->ne[i] / src1->ne[i];
105
+ bcast_src0_ne[bcast_dim_cnt] = src0->ne[i] / nr;
106
+ bcast_src1_ne[bcast_dim_cnt] = src1->ne[i];
107
+ bcast_src0_nb[bcast_dim_cnt] = src0->nb[i];
108
+ bcast_src1_nb[bcast_dim_cnt] = src1->nb[i];
109
+ bcast_dim_cnt++;
110
+ if (nr != 1) {
111
+ // Need to add an extra dim.
112
+ bcast_src0_ne[bcast_dim_cnt] = nr;
113
+ bcast_src1_ne[bcast_dim_cnt] = 1;
114
+ bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] *
115
+ bcast_src0_ne[bcast_dim_cnt - 1];
116
+ bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] *
117
+ bcast_src1_ne[bcast_dim_cnt - 1];
118
+ bcast_dim_cnt++;
119
+ }
120
+ }
121
+ return bcast_dim_cnt;
122
+ }
123
+
124
+ int64_t ggml_cann_get_mulmat_bcast_shape(
125
+ const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
126
+ const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
127
+ int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
128
+ size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb) {
129
+ // input and dst should have the same shape, except for the first two dims.
130
+ GGML_ASSERT(input_ne[2] == dst_ne[2]);
131
+ GGML_ASSERT(input_ne[3] == dst_ne[3]);
132
+
133
+ int bcast_dim_cnt = 0;
134
+
135
+ // For mul_mat, a dimension needs to be added before the dimension that
136
+ // weight needs to be expanded to satisfy the bcast rule of matrix
137
+ // multiplication.
138
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
139
+ int64_t nr = input_ne[i] / weight_ne[i];
140
+ // Do not use bcast in the first two dimensions because we only support
141
+ // the bcast batch dimension. Just copy them.
142
+ if (i < 2 || nr == 1) {
143
+ bcast_input_ne[bcast_dim_cnt] = input_ne[i];
144
+ bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];
145
+ bcast_dst_ne[bcast_dim_cnt] = dst_ne[i];
146
+
147
+ bcast_input_nb[bcast_dim_cnt] = input_nb[i];
148
+ bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];
149
+ bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
150
+ bcast_dim_cnt++;
151
+ } else {
152
+ // Need to add an extra dim.
153
+ bcast_input_ne[bcast_dim_cnt] = nr;
154
+ bcast_dst_ne[bcast_dim_cnt] = nr;
155
+ bcast_weight_ne[bcast_dim_cnt] = 1;
156
+ bcast_input_nb[bcast_dim_cnt] = input_nb[i];
157
+ bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
158
+ bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];
159
+ bcast_dim_cnt++;
160
+
161
+ bcast_input_ne[bcast_dim_cnt] = input_ne[i] / nr;
162
+ bcast_dst_ne[bcast_dim_cnt] = dst_ne[i] / nr;
163
+ bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];
164
+ bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] *
165
+ bcast_input_ne[bcast_dim_cnt - 1];
166
+ bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] *
167
+ bcast_dst_ne[bcast_dim_cnt - 1];
168
+ bcast_weight_nb[bcast_dim_cnt] =
169
+ bcast_weight_nb[bcast_dim_cnt - 1] *
170
+ bcast_weight_ne[bcast_dim_cnt - 1];
171
+ bcast_dim_cnt++;
172
+ }
173
+ }
174
+ return bcast_dim_cnt;
175
+ }
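To make the dimension-splitting in ggml_cann_get_bcast_shape concrete, here is a small sketch of what it produces for a (4, 6) tensor broadcast against a (4, 3) tensor. It assumes the snippet is built inside this backend's source tree so that acl_tensor.h and the CANN headers are available:

#include "ggml.h"
#include "acl_tensor.h"
#include <cstdio>

int main() {
    ggml_init_params params = { ggml_tensor_overhead() * 4, nullptr, /*no_alloc*/ true };
    ggml_context * ctx = ggml_init(params);
    ggml_tensor * src0 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 6);
    ggml_tensor * src1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);

    int64_t ne0[GGML_MAX_DIMS * 2], ne1[GGML_MAX_DIMS * 2];
    size_t  nb0[GGML_MAX_DIMS * 2], nb1[GGML_MAX_DIMS * 2];
    int64_t dims = ggml_cann_get_bcast_shape(src0, src1, ne0, ne1, nb0, nb1);

    // dim 1 of src0 is 2x dim 1 of src1, so that dimension is split in two:
    // expected: dims == 5, ne0 = {4, 3, 2, 1, 1}, ne1 = {4, 3, 1, 1, 1}
    for (int64_t i = 0; i < dims; i++) {
        std::printf("dim %lld: src0 %lld, src1 %lld\n", (long long) i,
                    (long long) ne0[i], (long long) ne1[i]);
    }
    ggml_free(ctx);
    return 0;
}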
ggml/src/ggml-cann/acl_tensor.h ADDED
@@ -0,0 +1,258 @@
1
+ /*
2
+ * Copyright (c) 2023-2024 The ggml authors
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ * of this software and associated documentation files (the "Software"), to
6
+ * deal in the Software without restriction, including without limitation the
7
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8
+ * sell copies of the Software, and to permit persons to whom the Software is
9
+ * furnished to do so, subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in
12
+ * all copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20
+ * IN THE SOFTWARE.
21
+ */
22
+
23
+ #ifndef CANN_ACL_TENSOR_H
24
+ #define CANN_ACL_TENSOR_H
25
+
26
+ #include <algorithm>
27
+ #include <cstring>
28
+
29
+ #include <aclnn/aclnn_base.h>
30
+ #include "common.h"
31
+
32
+ /**
33
+ * @brief Maps a ggml_type to its corresponding aclDataType.
34
+ *
35
+ * @details This function takes a ggml_type as input and returns the corresponding
36
+ * aclDataType. It supports mapping for various ggml_types. If the input type
37
+ * does not match any of the predefined ggml_types, the function returns
38
+ * ACL_DT_UNDEFINED.
39
+ *
40
+ * @param type The ggml_type to be mapped.
41
+ * @return The corresponding aclDataType. If the input type is not recognized,
42
+ * ACL_DT_UNDEFINED is returned.
43
+ */
44
+ aclDataType ggml_cann_type_mapping(ggml_type type);
45
+
46
+ /**
47
+ * @brief Creates an ACL tensor from a ggml_tensor with optional shape.
48
+ *
49
+ * @details This function creates an ACL tensor based on the properties of the
50
+ * provided ggml_tensor. It supports a custom shape by adjusting dimensions
51
+ * and strides accordingly. If a custom shape is applied, additional
52
+ * dimensions and strides are calculated based on the provided parameters.
53
+ *
54
+ * @param tensor Pointer to the ggml_tensor to be converted to ACL tensor.
55
+ * @param ne Pointer to an array containing dimensions. Defaults to nullptr
56
+ * if no custom shape is applied.
57
+ * @param nb Pointer to an array containing strides. Defaults to nullptr
58
+ * if no custom shape is applied.
59
+ * @param dims Number of dimensions in the tensor. Defaults to 0 if no custom
60
+ * shape is applied.
61
+ * @param format ACL tensor format. Defaults to ACL_FORMAT_ND.
62
+ * @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
63
+ * @return Pointer to the created ACL tensor.
64
+ */
65
+ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = nullptr,
66
+ size_t* nb = nullptr, int64_t dims = 0,
67
+ aclFormat format = ACL_FORMAT_ND,
68
+ size_t offset = 0);
69
+
70
+ /**
71
+ * @brief Template for creating an ACL tensor from provided parameters. typename TYPE
72
+ * should be size_t or float.
73
+ *
74
+ * @details This function creates an ACL tensor using the provided data pointer,
75
+ * data type, dimensions, strides, format, offset, and additional parameters.
76
+ * It calculates necessary dimensions and strides based on the provided ne and nb
77
+ * arrays, adjusting them for the ACL tensor creation. The ACL storage length
78
+ * is also calculated based on the provided dimensions and strides.
79
+ *
80
+ * @param data_ptr Pointer to the data buffer for the ACL tensor.
81
+ * @param dtype ACL data type of the tensor.
82
+ * @param type_size Size of each element in the tensor data buffer.
83
+ * @param ne Pointer to an array containing tensor dimensions.
84
+ * @param nb Pointer to an array containing tensor strides.
85
+ * @param dims Number of dimensions of the tensor.
86
+ * @param format ACL tensor format. Defaults to ACL_FORMAT_ND.
87
+ * @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
88
+ * @return Pointer to the created ACL tensor.
89
+ */
90
+ template<typename TYPE>
91
+ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
92
+ TYPE type_size, int64_t* ne, TYPE* nb,
93
+ int64_t dims,
94
+ aclFormat format = ACL_FORMAT_ND,
95
+ size_t offset = 0) {
96
+ int64_t tmp_ne[GGML_MAX_DIMS * 2];
97
+ int64_t tmp_stride[GGML_MAX_DIMS * 2];
98
+
99
+ memcpy(tmp_ne, ne, dims * sizeof(int64_t));
100
+ for (int i = 0; i < dims; i++) {
101
+ tmp_stride[i] = nb[i] / type_size;
102
+ }
103
+
104
+ std::reverse(tmp_ne, tmp_ne + dims);
105
+ std::reverse(tmp_stride, tmp_stride + dims);
106
+
107
+ int64_t acl_storage_len = 0;
108
+ for (int i = 0; i < dims; i++) {
109
+ acl_storage_len += (ne[i] - 1) * nb[i];
110
+ }
111
+
112
+ aclTensor* acl_tensor =
113
+ aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size,
114
+ format, &acl_storage_len, 1, data_ptr);
115
+
116
+ return acl_tensor;
117
+ }
118
+
119
+ /**
120
+ * @brief Checks if tensors require broadcasting based on their shapes.
121
+ *
122
+ * @details This function determines if two ggml_tensors need to be broadcasted for
123
+ * element-wise operations. Broadcasting is necessary if the shapes of the
124
+ * tensors are not identical and no dimension in either tensor equals 1.
125
+ *
126
+ * @param t0 Pointer to the first ggml_tensor.
127
+ * @param t1 Pointer to the second ggml_tensor.
128
+ * @return True if broadcasting is needed, False otherwise.
129
+ *
130
+ * @remarks This function iterates over the dimensions of t0 and t1. It checks if each
131
+ * dimension in t1 differs from t0's corresponding dimension and is not equal
132
+ * to 1. If such a dimension is found, broadcasting is required to align t1
133
+ * with t0 for element-wise operations.
134
+ */
135
+ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1);
136
+
137
+ /**
138
+ * @brief Computes broadcast shapes and strides for two ggml_tensors.
139
+ *
140
+ * @details This function calculates the broadcast shapes and strides for two ggml_tensors,
141
+ * following the broadcasting rules similar to numpy. It adjusts dimensions and
142
+ * strides to ensure compatibility for element-wise operations where one tensor
143
+ * can be broadcasted to match the shape of another tensor.
144
+ *
145
+ * @param src0 Pointer to the first ggml_tensor.
146
+ * @param src1 Pointer to the second ggml_tensor.
147
+ * @param bcast_ne_src0 Output array to store broadcasted dimensions for src0.
148
+ * @param bcast_ne_src1 Output array to store broadcasted dimensions for src1.
149
+ * @param bcast_nb_src0 Output array to store broadcasted strides for src0.
150
+ * @param bcast_nb_src1 Output array to store broadcasted strides for src1.
151
+ * @return Number of dimensions in the broadcasted shape.
152
+ *
153
+ * @pre ggml_can_repeat(src1, src0) must return true, indicating src1 can be broadcasted
154
+ * to match src0.
155
+ *
156
+ * @remarks This function iterates over the dimensions of src0 and src1, calculating the
157
+ * necessary broadcast dimensions and strides. If a dimension requires broadcasting
158
+ * (i.e., its size in src1 is smaller than in src0), an additional dimension is
159
+ * added with size calculated to match src0's dimension. This adjustment ensures
160
+ * that src1 can be element-wise broadcasted to src0's shape.
161
+ *
162
+ * How it works:
163
+ *
164
+ * if dim0 has padding.
165
+ * a -> (2, 2) padding = 2
166
+ * a: [[1, 2, *, *]
167
+ * [2, 3, *, *]]
168
+ * nb = (8, 4, 2)
169
+ *
170
+ * if a should bcast with b -> (2, 4)
171
+ * b' -> (2, 2, 2)
172
+ * b : [[1, 2, 3, 4, *, *]
173
+ * [5, 6, 7, 8, *, *]]
174
+ * nb = (12, 6, 1)
175
+ *
176
+ * after bcast:
177
+ * a' -> (2, 1, 2)
178
+ * a': [[[1, 2], *, *]
179
+ * [[2, 3], *, *]]
180
+ * nb = (8, 4, 2, 1)
181
+ *
182
+ * b' : [[[1, 2], [3, 4], *, *]
183
+ * [[5, 6], [7, 8], *, *]]
184
+ * nb = (12, 6, 2, 1)
185
+ *
186
+ *
187
+ * dim1 in a inserted dim, should add nb for dim1,
188
+ * and all other nb moves to next in order.
189
+ */
190
+ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* src1,
191
+ int64_t* bcast_ne_src0, int64_t* bcast_ne_src1,
192
+ size_t* bcast_nb_src0, size_t* bcast_nb_src1);
193
+
194
+ // Bcast macro to avoid duplicate code.
195
+ #define BCAST_SHAPE(src0, src1) \
196
+ int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2]; \
197
+ int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2]; \
198
+ size_t bcast_##src0##_nb[GGML_MAX_DIMS * 2]; \
199
+ size_t bcast_##src1##_nb[GGML_MAX_DIMS * 2]; \
200
+ int64_t bcast_dims = ggml_cann_get_bcast_shape( \
201
+ src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, bcast_##src0##_nb, \
202
+ bcast_##src1##_nb);
203
+
204
+ #define BCAST_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
205
+
206
+ /**
207
+ * @brief Calculates broadcast shapes for matrix multiplication.
208
+ *
209
+ * @details This function computes the broadcast shapes required for matrix multiplication
210
+ * based on the input, weight, and destination tensor shapes. It ensures that the
211
+ * dimensions of weight tensors are expanded appropriately to satisfy matrix
212
+ * multiplication broadcast rules.
213
+ *
214
+ * @param input_ne Array containing the dimensions of the input tensor.
215
+ * @param weight_ne Array containing the dimensions of the weight tensor.
216
+ * @param dst_ne Array containing the dimensions of the destination tensor.
217
+ * @param input_nb Array containing the strides of the input tensor.
218
+ * @param weight_nb Array containing the strides of the weight tensor.
219
+ * @param dst_nb Array containing the strides of the destination tensor.
220
+ * @param bcast_input_ne Output array for broadcasted input tensor dimensions.
221
+ * @param bcast_weight_ne Output array for broadcasted weight tensor dimensions.
222
+ * @param bcast_dst_ne Output array for broadcasted destination tensor dimensions.
223
+ * @param bcast_input_nb Output array for broadcasted input tensor strides.
224
+ * @param bcast_weight_nb Output array for broadcasted weight tensor strides.
225
+ * @param bcast_dst_nb Output array for broadcasted destination tensor strides.
226
+ * @return The number of dimensions in the broadcasted tensors.
227
+ *
228
+ * @remarks This function iterates over the tensor dimensions and calculates the broadcast
229
+ * shapes needed for matrix multiplication. It ensures that dimensions where
230
+ * weight tensor requires expansion are appropriately handled to conform with
231
+ * broadcasting rules.
232
+ * @note Compared with ggml_cann_get_bcast_shape, mul_mat broadcasting needs to
233
+ * add the new dimension before the broadcast dimension.
234
+ * @sa ggml_cann_get_bcast_shape
235
+ */
236
+ int64_t ggml_cann_get_mulmat_bcast_shape(
237
+ const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
238
+ const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
239
+ int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
240
+ size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb);
241
+
242
+ // Bcast macro to avoid duplicate code.
243
+ #define BCAST_MUL_MAT_SHAPE(input, weight, dst) \
244
+ int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2]; \
245
+ int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2]; \
246
+ int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2]; \
247
+ size_t bcast_##input##_nb[GGML_MAX_DIMS * 2]; \
248
+ size_t bcast_##weight##_nb[GGML_MAX_DIMS * 2]; \
249
+ size_t bcast_##dst##_nb[GGML_MAX_DIMS * 2]; \
250
+ int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape( \
251
+ input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, \
252
+ bcast_##input##_ne, bcast_##weight##_ne, bcast_##dst##_ne, \
253
+ bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb);
254
+
255
+ #define BCAST_MUL_MAT_PARAM(tensor) \
256
+ bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
257
+
258
+ #endif // CANN_ACL_TENSOR_H
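As a usage note, the broadcast helpers and macros above are meant to be combined with ggml_cann_create_tensor inside the element-wise operators; the actual implementations are in aclnn_ops.cpp, whose diff is not rendered on this page. A sketch of the intended pattern:

#include "ggml.h"
#include "acl_tensor.h"

// Illustrative only: creating broadcast-aware aclTensor views for a binary op.
static void make_acl_operands(ggml_tensor * dst, aclTensor ** acl_src0, aclTensor ** acl_src1) {
    ggml_tensor * src0 = dst->src[0];
    ggml_tensor * src1 = dst->src[1];

    if (ggml_cann_need_bcast(src0, src1)) {
        // BCAST_SHAPE declares bcast_src0_ne/nb, bcast_src1_ne/nb and bcast_dims
        BCAST_SHAPE(src0, src1)
        *acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
        *acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
    } else {
        // same shape: use the tensors' own ne/nb
        *acl_src0 = ggml_cann_create_tensor(src0);
        *acl_src1 = ggml_cann_create_tensor(src1);
    }
}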
ggml/src/ggml-cann/aclnn_ops.cpp ADDED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cann/aclnn_ops.h ADDED
@@ -0,0 +1,592 @@
1
+ #ifndef CANN_ACLNN_OPS
2
+ #define CANN_ACLNN_OPS
3
+
4
+ /**
5
+ * @file aclnn_ops.h
6
+ * @brief Declarations of the CANN (aclnn) operator wrappers used by the ggml
7
+ * CANN backend: element-wise operations, activations, normalizations and
8
+ * other tensor operations.
9
+ * @author hipudding <[email protected]>
10
+ * @author wangshuai09 <[email protected]>
11
+ * @date July 15, 2024
12
+ *
13
+ * Copyright (c) 2023-2024 The ggml authors
14
+ *
15
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
16
+ * of this software and associated documentation files (the "Software"), to
17
+ * deal in the Software without restriction, including without limitation the
18
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
19
+ * sell copies of the Software, and to permit persons to whom the Software is
20
+ * furnished to do so, subject to the following conditions:
21
+ *
22
+ * The above copyright notice and this permission notice shall be included in
23
+ * all copies or substantial portions of the Software.
24
+ *
25
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
30
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
31
+ * IN THE SOFTWARE.
32
+ */
33
+
34
+ #include <aclnnop/aclnn_add.h>
35
+ #include <aclnnop/aclnn_arange.h>
36
+ #include <aclnnop/aclnn_argsort.h>
37
+ #include <aclnnop/aclnn_cat.h>
38
+ #include <aclnnop/aclnn_clamp.h>
39
+ #include <aclnnop/aclnn_div.h>
40
+ #include <aclnnop/aclnn_gelu.h>
41
+ #include <aclnnop/aclnn_hardsigmoid.h>
42
+ #include <aclnnop/aclnn_hardswish.h>
43
+ #include <aclnnop/aclnn_leaky_relu.h>
44
+ #include <aclnnop/aclnn_mul.h>
45
+ #include <aclnnop/aclnn_relu.h>
46
+ #include <aclnnop/aclnn_silu.h>
47
+ #include <aclnnop/aclnn_tanh.h>
48
+ #include "acl_tensor.h"
49
+ #include "common.h"
50
+
51
+ /**
52
+ * @brief Repeats a ggml tensor along each dimension to match the dimensions
53
+ * of another tensor.
54
+ *
55
+ * @details This function repeats the elements of a source ggml tensor along
56
+ * each dimension to create a destination tensor with the specified
57
+ * dimensions. The operation is performed using the ACL backend and
58
+ * executed asynchronously on the device.
59
+ *
60
+ * @param ctx The CANN context used for operations.
61
+ * @param dst The ggml tensor representing the destination, which op is
62
+ * GGML_OP_REPEAT and specifies the desired dimensions.
63
+ */
64
+ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
65
+
66
+ /**
67
+ * @brief Adds two ggml tensors using the CANN backend.
68
+ *
69
+ * @details This function performs an element-wise addition of two tensors. In
70
+ * case the tensors do not have the same shape, one or both tensors
71
+ * will be broadcasted to match the shape of the other before the
72
+ * addition is performed. The formula for the operation is given by:
73
+ * \f[
74
+ * \text{dst} = \text{acl_src0} + \alpha \cdot \text{acl_src1}
75
+ * \f]
76
+ *
77
+ * @param ctx The CANN context used for operations.
78
+ * @param dst The ggml tensor representing the destination, result of the
79
+ * addition is stored at dst->data, and dst->op is `GGML_OP_ADD`
80
+ */
81
+ void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst);
82
+
83
+ /**
84
+ * @brief Applies the Leaky ReLU activation function to a tensor using the CANN
85
+ * backend.
86
+ *
87
+ * @details This function computes the Leaky ReLU activation for each element of
88
+ * the input tensor. The Leaky ReLU function allows a small gradient
89
+ * when the unit is not active (i.e., when the input is negative). The
90
+ * Leaky ReLU function is defined as:
91
+ * \f[
92
+ * \text{dst} = \max(0, src) + \text{negativeSlope} \cdot \min(0,
93
+ * src)
94
+ * \f]
95
+ * `negativeSlope` is in dst->params.
96
+ *
97
+ * @param ctx The CANN context used for operations.
98
+ * @param dst The destination tensor where the result of the Leaky ReLU
99
+ * activation is stored, which op is `GGML_OP_LEAKY_RELU`
100
+ */
101
+ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
102
+
103
+ /**
104
+ * @brief Concatenates multiple tensors along a specified dimension using the
105
+ * CANN backend.
106
+ *
107
+ * @param ctx The CANN context used for operations.
108
+ * @param tensorList A pointer to the list of tensors to be concatenated.
109
+ * @param dst The destination tensor where the result of the
110
+ * concatenation is stored. dst->op is `GGML_OP_CONCAT`.
111
+ * @param concat_dim The dimension along which the tensors are concatenated.
112
+ *
113
+ * @attention tensorList length should be 2 and the dimension used for concat
114
+ * defaults to 1.
115
+ */
116
+ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
117
+
118
+ /**
119
+ * @brief Generates a sequence of evenly spaced values within a specified
120
+ * interval for a ggml tensor using the CANN backend.
121
+ *
122
+ * @details This function creates a sequence of numbers over a specified
123
+ * interval, starting from `start`, ending before `stop`, and
124
+ * incrementing by `step`. The sequence is stored in the destination
125
+ * tensor `dst`.
126
+ *
127
+ * @param ctx The CANN context used for operations.
128
+ * @param dst The destination tensor where the generated sequence will be stored.
129
+ * `start`, 'stop' and 'step' are in dst->op_params and dst->op is
130
+ * `GGML_OP_ARANGE`.
131
+ */
132
+ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst);
133
+
134
+ /**
135
+ * @brief Computes the square of the elements of a ggml tensor using the CANN
136
+ * backend.
137
+ * @details The function sets the second source tensor of the destination
138
+ * tensor `dst` to be equal to the first source tensor. This is
139
+ * effectively squaring the elements since the multiplication becomes
140
+ * `element * element`.
141
+ * @param ctx The CANN context used for operations.
142
+ * @param dst The destination tensor where the squared values will be stored,
143
+ * which dst->op is `GGML_OP_SQR`.
144
+ */
145
+ void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst);
146
+
147
+ /**
148
+ * @brief Applies a clamp operation to the elements of a ggml tensor using the
149
+ * CANN backend.
150
+ *
151
+ * @details This function clamps the elements of the input tensor `src` to a
152
+ * specified range defined by `min` and `max` values. The result is
153
+ * stored in the destination tensor `dst`. The operation is defined as:
154
+ * \f[
155
+ * y = \max(\min(x, max\_value), min\_value)
156
+ * \f]
157
+ * where `x` is an element of the input tensor, and `y` is the
158
+ * corresponding element in the output tensor.
159
+ * @param ctx The CANN context used for operations.
160
+ * @param dst The destination tensor where the clamped values will be stored.
161
+ * dst->op is `GGML_OP_CLAMP`, `min` and `max` value is in dst->params.
162
+ */
163
+ void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst);
164
+
165
+ /**
166
+ * @brief Scales the elements of a ggml tensor by a constant factor using the
167
+ * CANN backend.
168
+ *
169
+ * @details This function multiplies each element of the input tensor `src` by
170
+ * a scaling factor `scale`, storing the result in the destination
171
+ * tensor `dst`. The operation is defined as:
172
+ * \f[
173
+ * dst = src \times scale
174
+ * \f]
175
+ *
176
+ * @param ctx The CANN context used for operations.
177
+ * @param dst The destination tensor where the scaled values will be stored.
178
+ * dst->op is `GGML_OP_SCALE` and `scale` value is in dst->params.
179
+ */
180
+ void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst);
181
+
182
+ /**
183
+ * @brief Sorts the elements of a ggml tensor and returns the indices that
184
+ * would sort the tensor using the CANN backend.
185
+ *
186
+ * @details This function performs an argsort operation on the input tensor
187
+ * `src`. It sorts the elements of `src` in either ascending or
188
+ * descending order, depending on the `GGML_SORT_ORDER_DESC`,
189
+ * and returns the indices that would sort the original tensor.
190
+ *
191
+ * @param ctx The CANN context used for operations.
192
+ * @param dst The destination tensor where the sorted indices will be stored.
193
+ * dst->op is `GGML_OP_ARGSORT`.
194
+ */
195
+ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
196
+
197
+ /**
198
+ * @brief Computes the Layer Normalization for a ggml tensor using the CANN
199
+ * backend.
200
+ *
201
+ * @details This function applies the Layer Normalization operation on the
202
+ * input tensor `src` and stores the result in the destination tensor
203
+ * `dst`. Layer Normalization normalizes the features at each sample in
204
+ * a mini-batch independently. It is commonly used in neural networks
205
+ * to normalize the activations of a layer by adjusting and scaling
206
+ * the outputs.
207
+ * The operation is defined as:
208
+ * \f[
209
+ * \text { out }=\frac{x-\mathrm{E}[x]}{\sqrt{\text{Var}[x]+eps}}
210
+ * \f]
211
+ * `Var` defaults to dst->ne[0]. `eps` is in dst->params.
212
+ *
213
+ * @param ctx The CANN context used for operations.
214
+ * @param dst The destination tensor where the normalized values will be stored.
215
+ * @attention `Var` defaults to dst->ne[0].
216
+ */
217
+ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
218
+
219
+ /**
220
+ * @brief Computes the Group Normalization for a ggml tensor using the CANN
221
+ * backend.
222
+ *
223
+ * @details This function applies the Group Normalization operation on the input
224
+ * tensor `src` and stores the result in the destination tensor `dst`.
225
+ * Group Normalization divides the channels into groups and normalizes
226
+ * the features within each group across spatial locations.
227
+ * It is commonly used in convolutional neural networks to improve
228
+ * training stability and performance.
229
+ * The operation is defined as:
230
+ * \f[
231
+ * \text { out }=\frac{x-\mathrm{E}[x]}{\sqrt{\text{Var}[x]+eps}}
232
+ * \f]
233
+ *
234
+ * @param ctx The CANN context used for operations.
235
+ * @param dst The destination tensor where the normalized values will be stored.
236
+ * `n_groups` is in dst->op_params; it splits the channel dimension into `n_groups` groups.
237
+ * dst->op is `GGML_OP_GROUP_NORM`.
238
+ *
239
+ * @attention eps defaults to 1e-6f.
240
+ */
241
+ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
242
+
243
+ /**
244
+ * @brief Computes the accumulation of tensors using the CANN backend.
245
+ *
246
+ * @details This function performs an accumulation operation on two tensors.
247
+ * Depending on the `inplace` flag, it either updates the destination
248
+ * tensor `dst` in place by adding `alpha * src1` to it, or it creates
249
+ * a new tensor as the result of `src0 + alpha * src1` and stores it in
250
+ * `dst`.
251
+ * The operation is defined as:
252
+ * \f[
253
+ * dst = src0 + alpha \times src1
254
+ * \f]
255
+ * If `inplace` is `true`, `src0` is the same tensor as `dst`.
256
+ * @param ctx The CANN context used for operations.
257
+ * @param dst The destination tensor where the accumulated values will be stored.
258
+ * `inplace` is in dst->op_params, and dst->op is `GGML_OP_ACC`.
259
+ */
260
+ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst);
261
+
262
+ /**
263
+ * @brief Computes the sum of elements along the last dimension of a ggml tensor
264
+ * using the CANN backend.
265
+ *
266
+ * @details This function performs a reduction sum operation along the last
267
+ * dimension of the input tensor `src`. The result of the sum is stored
268
+ * in the destination tensor `dst`.
269
+ *
270
+ * @param ctx The CANN context used for operations.
271
+ * @param dst The destination tensor where the reduced values will be stored.
272
+ * dst->op is `GGML_OP_SUM_ROWS`.
273
+ *
274
+ * @attention `reduce_dims` defaults to 3, which means the last dimension.
275
+ */
276
+ void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
277
+
278
+ /**
279
+ * @brief Upsamples a ggml tensor using nearest neighbor interpolation using
280
+ * the CANN backend.
281
+ *
282
+ * @details This function performs upsampling of the input tensor `src` using
283
+ * nearest neighbor interpolation. The upsampling is applied to the
284
+ * height and width dimensions (last two dimensions) of the tensor. The
285
+ * result is stored in the destination tensor `dst`, which must have
286
+ * the appropriate dimensions for the upsampled output.
287
+ *
288
+ * @param ctx The CANN context used for operations.
289
+ * @param dst The destination tensor where the upsampled values will be stored.
290
+ * dst->op is `GGML_OP_UPSCALE`.
291
+ */
292
+ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
293
+ ggml_tensor* dst);
294
+
295
+ /**
296
+ * @brief Pads a ggml tensor to match the dimensions of the destination tensor
297
+ * using the CANN backend.
298
+ *
299
+ * @details This function pads the input tensor `src` so that it matches the
300
+ * dimensions of the destination tensor `dst`. The amount of padding
301
+ * is calculated based on the difference in sizes between `src` and
302
+ * `dst` along each dimension. The padded tensor is stored in `dst`.
303
+ *
304
+ * @param ctx The CANN context used for operations.
305
+ * @param dst The destination tensor, which specifies the target dimensions for
306
+ * padding. dst->op is `GGML_OP_PAD`.
307
+ */
308
+ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst);
309
+
310
+ /**
311
+ * @brief Executes a 2D pooling operation on a ggml tensor using the CANN
312
+ * backend.
313
+ *
314
+ * @details This function dispatches the execution of a 2D pooling operation on
315
+ * the input tensor `dst`. The type of pooling (average or max) is
316
+ * determined by the `op` parameter, which is read from the operation
317
+ * parameters of `dst`. The function supports average pooling
318
+ * (`GGML_OP_POOL_AVG`) and max pooling (`GGML_OP_POOL_MAX`). If an
319
+ * invalid operation is encountered, the function asserts a failure.
320
+ *
321
+ * @param ctx The CANN context used for operations.
322
+ * @param dst The destination tensor on which the pooling operation is to be
323
+ * performed. dst->op is `GGML_OP_POOL_2D`.
324
+ */
325
+ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
326
+
327
+ /**
328
+ * @brief Duplicates a ggml tensor using the CANN backend.
329
+ *
330
+ * @details This function duplicates the contents of the source tensor `src` to
331
+ * the destination tensor `dst`. The function supports various tensor
332
+ * types and configurations, including handling of extra data, type
333
+ * conversions, and special cases for contiguous and non-contiguous
334
+ * tensors.
335
+ *
336
+ * @param ctx The CANN context used for operations.
337
+ * @param dst The destination tensor where the duplicated data will be stored.
338
+ * dst->op is `GGML_OP_DUP`
339
+ *
340
+ * @attention Only FP16/FP32 are supported. The case where src and dst have
341
+ * different shapes and dst is non-contiguous is not supported.
342
+ * @note This function needs to be simplified.
343
+ */
344
+ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst);
345
+
346
+ /**
347
+ * @brief Computes the Root Mean Square (RMS) normalization of a ggml tensor
348
+ * using the CANN backend.
349
+ *
350
+ * @details This function applies RMS normalization to the input tensor `src`
351
+ * and stores the result in the destination tensor `dst`. RMS
352
+ * normalization involves computing the root mean square of the input
353
+ * tensor along a specified dimension and then dividing each element of
354
+ * the tensor by this value, adjusted by a small epsilon value to
355
+ * prevent division by zero.
356
+ * The operation is defined as:
357
+ * \f[
358
+ * \text{RmsNorm}\left(x_i\right)=\frac{x_i}{\text{Rms}(\mathbf{x})} g_i,
359
+ * \quad \text { where } \text{Rms}(\mathbf{x})=\sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2+e p s}
360
+ * \f]
361
+ * `eps` is in dst->op_params.
362
+ * @param ctx The CANN context used for operations.
363
+ * @param dst The destination tensor where the normalized values will be stored.
364
+ * dst->op is `GGML_OP_RMS_NORM`.
365
+ */
366
+ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
367
+
368
+ /**
369
+ * @brief Applies a diagonal mask to the tensor with a specified value.
370
+ *
371
+ * @details This function creates a mask tensor filled with ones, then applies
372
+ * an upper triangular and lower triangular operation to it based on
373
+ * the number of past elements specified. Afterward, it adds the masked
374
+ * tensor to the destination tensor in-place.
375
+ *
376
+ * @param ctx The backend CANN context used for operations.
377
+ * @param dst The destination tensor where the result will be stored. dst->op is
378
+ * `GGML_OP_DIAG_MASK`
379
+ * @param value The value to use for masking.
380
+ */
381
+ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float value);
382
+
383
+ /**
384
+ * @brief Performs an image-to-column transformation on the input tensor.
385
+ *
386
+ * @details This function takes an input tensor and applies an image-to-column
387
+ * operation, converting spatial dimensions into column-like
388
+ * structures suitable for convolutional operations. It supports both
389
+ * half-precision (F16) and single-precision (F32) floating-point data
390
+ * types.
391
+ *
392
+ * @param ctx The backend CANN context for executing operations.
393
+ * @param dst The destination tensor that stores the result of the operation.
394
+ * dst->op is `GGML_OP_IM2COL`.
395
+ */
396
+ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst);
397
+
398
+ /**
399
+ * @brief Computes time step embeddings using sine and cosine functions.
400
+ *
401
+ * @details This function calculates time step embeddings by applying sine and
402
+ * cosine transformations to a given input tensor, which is typically
403
+ * used in temporal models like diffusion models or transformers to
404
+ * encode time information effectively.
405
+ *
406
+ * @param ctx The backend CANN context for executing operations.
407
+ * @param dst The destination tensor where the result of the embedding operation
408
+ * will be stored. dst->op is `GGML_OP_TIMESTEP_EMBEDDING`.
409
+ */
410
+ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* dst);
411
+
412
+ // @see ggml_cann_dup.
413
+ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst);
414
+
415
+ /**
416
+ * @brief Computes the softmax activation with optional masking.
417
+ *
418
+ * @details This function computes the softmax activation over the input tensor,
419
+ * optionally applying a mask and scaling factor. It supports both FP16
420
+ * and FP32 data types and can handle masking by broadcasting the mask
421
+ * across rows if necessary.
422
+ * The function performs the following steps:
423
+ * 1. Multiplies the input tensor by a scale factor.
424
+ * 2. Optionally casts the mask tensor to FP32 if it is in FP16 format.
425
+ * 3. Broadcasts the mask tensor if its dimensions do not match the
426
+ * input tensor's dimensions.
427
+ * 4. Adds the mask to the scaled input tensor.
428
+ * 5. Applies the softmax activation function along the specified
429
+ * dimension.
430
+ *
431
+ * @param ctx The backend CANN context for executing operations.
432
+ * @param dst The destination tensor where the result will be stored. dst->op is
433
+ * `GGML_OP_SOFTMAX`.
434
+ */
435
+ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
436
+
437
+ /**
438
+ * @brief Extracts specific rows from a tensor based on indices.
439
+ *
440
+ * @details This function retrieves rows from a source tensor src0 according to
441
+ * the indices provided in another tensor src1 and stores the result in
442
+ * a destination tensor (\p dst). It supports different data types
443
+ * including F32, F16, Q4_0, and Q8_0.
444
+ *
445
+ * @param ctx The backend CANN context for executing operations.
446
+ * @param dst The destination tensor where the extracted rows will be stored.
447
+ * dst->op is `GGML_OP_GET_ROWS`.
448
+ */
449
+ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
450
+
451
+ /**
452
+ * @brief Executes matrix multiplication for the given tensor.
453
+ *
454
+ * @details This function performs matrix multiplication on the source tensors
455
+ * associated with the destination tensor. It supports matrix
456
+ * multiplication for F32, F16, and Q8_0 tensors.
457
+ *
458
+ * @param ctx The backend CANN context for executing operations.
459
+ * @param dst The destination tensor for storing the result of the matrix
460
+ * multiplication. dst->op is `GGML_OP_MUL_MAT`.
461
+ */
462
+ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
463
+
464
+ /**
465
+ * @brief Applies Rotary Positional Embedding (RoPE) to the input tensor.
466
+ *
467
+ * @details This function implements the RoPE mechanism, which is a method to
468
+ * encode positional information into sequence data, particularly
469
+ * useful in transformer models. It supports both F32 and F16 data
470
+ * types.
471
+ *
472
+ * @param ctx The backend CANN context for executing operations.
473
+ * @param dst The destination tensor where the RoPE-transformed data will be
474
+ * stored. dst->op is `GGML_OP_ROPE`.
475
+ *
476
+ * @note The function currently does not support cases where the n_dims is less
477
+ * than the input tensor's first dimension.
478
+ * @note The function currently does not support cases where the freq_factors is
479
+ * not NULL.
480
+ * @note The function currently does not support cases where the ext_factor is
481
+ * not equal to 0.
482
+ * @note The function currently does not support cases where the freq_scale is
483
+ * not equal to 1.
484
+ */
485
+ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
486
+
487
+ template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
488
+ aclTensor*, uint64_t*, aclOpExecutor**),
489
+ aclnnStatus execute(void*, uint64_t, aclOpExecutor*, aclrtStream)>
490
+ void ggml_cann_mul_div(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
491
+ ggml_tensor* src0 = dst->src[0];
492
+ ggml_tensor* src1 = dst->src[1];
493
+ GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
494
+
495
+ aclTensor* acl_src0;
496
+ aclTensor* acl_src1;
497
+ aclTensor* acl_dst;
498
+
499
+ // Need bcast
500
+ if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) {
501
+ BCAST_SHAPE(src0, src1)
502
+ acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
503
+ acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
504
+ acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0));
505
+ } else {
506
+ acl_src0 = ggml_cann_create_tensor(src0);
507
+ acl_src1 = ggml_cann_create_tensor(src1);
508
+ acl_dst = ggml_cann_create_tensor(dst);
509
+ }
510
+
511
+ uint64_t workspaceSize = 0;
512
+ aclOpExecutor* executor;
513
+ void* workspaceAddr = nullptr;
514
+
515
+ ACL_CHECK(getWorkspaceSize(acl_src0, acl_src1, acl_dst, &workspaceSize,
516
+ &executor));
517
+ if (workspaceSize > 0) {
518
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
519
+ workspaceAddr = workspace_allocator.get();
520
+ }
521
+
522
+ aclrtStream main_stream = ctx.stream();
523
+ ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream));
524
+
525
+ ACL_CHECK(aclDestroyTensor(acl_src0));
526
+ ACL_CHECK(aclDestroyTensor(acl_src1));
527
+ ACL_CHECK(aclDestroyTensor(acl_dst));
528
+ }
529
+
530
+ // Activation functions template.
531
+ template <aclnnStatus getWorkspaceSize(const aclTensor*, aclTensor*, uint64_t*,
532
+ aclOpExecutor**),
533
+ aclnnStatus execute(void*, uint64_t, aclOpExecutor*,
534
+ const aclrtStream)>
535
+ void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
536
+ ggml_tensor* src = dst->src[0];
537
+
538
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
539
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
540
+
541
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
542
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
543
+
544
+ uint64_t workspaceSize = 0;
545
+ aclOpExecutor* executor;
546
+ void* workspaceAddr = nullptr;
547
+
548
+ ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
549
+ if (workspaceSize > 0) {
550
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
551
+ workspaceAddr = workspace_allocator.get();
552
+ }
553
+
554
+ aclrtStream main_stream = ctx.stream();
555
+ ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream));
556
+
557
+ ACL_CHECK(aclDestroyTensor(acl_src));
558
+ ACL_CHECK(aclDestroyTensor(acl_dst));
559
+ }
560
+
561
+ // Activation functions template for const aclTensors.
562
+ template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
563
+ uint64_t*, aclOpExecutor**),
564
+ aclnnStatus execute(void*, uint64_t, aclOpExecutor*,
565
+ const aclrtStream)>
566
+ void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
567
+ ggml_tensor* src = dst->src[0];
568
+
569
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
570
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
571
+
572
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
573
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
574
+
575
+ uint64_t workspaceSize = 0;
576
+ aclOpExecutor* executor;
577
+ void* workspaceAddr = nullptr;
578
+
579
+ ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
580
+ if (workspaceSize > 0) {
581
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
582
+ workspaceAddr = workspace_allocator.get();
583
+ }
584
+
585
+ aclrtStream main_stream = ctx.stream();
586
+ ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream));
587
+
588
+ ACL_CHECK(aclDestroyTensor(acl_src));
589
+ ACL_CHECK(aclDestroyTensor(acl_dst));
590
+ }
591
+
592
+ #endif // CANN_ACLNN_OPS
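The two templates above only do work once they are instantiated with a matching pair of aclnn entry points. The sketch below shows how such an instantiation could look; the aclnnMul/aclnnDiv/aclnnRelu symbols are assumed to come from CANN's aclnn operator library with the workspace-size/execute signatures the templates expect, the "aclnn_ops.h" include name is inferred from the include guard, and the dispatcher itself is illustrative rather than the backend's actual op routing.

// Editorial sketch, not part of the commit. Assumes the aclnn op headers for
// Mul/Div/Relu are available alongside the declarations above.
#include "aclnn_ops.h"

static void cann_dispatch_example(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    switch (dst->op) {
        case GGML_OP_MUL:
            // element-wise multiply with broadcasting, via the mul/div template
            ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
            break;
        case GGML_OP_DIV:
            ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
            break;
        case GGML_OP_UNARY:
            if (ggml_get_unary_op(dst) == GGML_UNARY_OP_RELU) {
                // single-input activation via the activation template
                ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(ctx, dst);
            }
            break;
        default:
            break;
    }
}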
ggml/src/ggml-cann/common.h ADDED
@@ -0,0 +1,282 @@
1
+ /*
2
+ * Copyright (c) 2023-2024 The ggml authors
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ * of this software and associated documentation files (the "Software"), to
6
+ * deal in the Software without restriction, including without limitation the
7
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8
+ * sell copies of the Software, and to permit persons to whom the Software is
9
+ * furnished to do so, subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in
12
+ * all copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20
+ * IN THE SOFTWARE.
21
+ */
22
+
23
+ #ifndef CANN_COMMON_H
24
+ #define CANN_COMMON_H
25
+
26
+ #include <acl/acl.h>
27
+
28
+ #include <cstdio>
29
+ #include <iostream>
30
+ #include <map>
31
+ #include <memory>
32
+ #include <string>
33
+ #include <vector>
34
+
35
+ #include "../include/ggml-cann.h"
36
+ #include "../include/ggml.h"
37
+
38
+ #define MATRIX_ROW_PADDING 512
39
+ #define GGML_CANN_MAX_STREAMS 8
40
+
41
+ /**
42
+ * @brief Handles CANN-related errors by printing an error message and
43
+ * terminating the program.
44
+ * @param stmt The statement that caused the error.
45
+ * @param func The function in which the error occurred.
46
+ * @param file The file in which the error occurred.
47
+ * @param line The line number at which the error occurred.
48
+ * @param msg The error message.
49
+ */
50
+ [[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
51
+ const char* file, int line, const char* msg);
52
+
53
+ /**
54
+ * @brief Checks the result of a CANN function call and invokes the error
55
+ * handler if the call fails.
56
+ * @param stmt The CANN function call to check.
57
+ * @param success The success code that indicates the call was successful.
58
+ * @param error_fn The function to call to retrieve the error message.
59
+ */
60
+ #define ACL_CHECK_GEN(stmt, success, error_fn) \
61
+ do { \
62
+ int err_code = (stmt); \
63
+ if (err_code != (success)) { \
64
+ ggml_cann_error(#stmt, __func__, __FILE__, __LINE__, error_fn()); \
65
+ } \
66
+ } while (0)
67
+
68
+ #define ACL_CHECK(stmt) ACL_CHECK_GEN(stmt, 0, aclGetRecentErrMsg)
69
+
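+ // Usage sketch (editorial, not part of the commit): ACL_CHECK wraps any ACL
+ // call whose success code is 0 and terminates with file/line context on
+ // failure, e.g.
+ //     ACL_CHECK(aclrtSetDevice(device));
+ //     ACL_CHECK(aclrtCreateStream(&stream));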
70
+ /**
71
+ * @brief Contains information about CANN devices.
72
+ */
73
+ struct ggml_cann_device_info {
74
+ /**
75
+ * @brief Number of CANN devices available.
76
+ */
77
+ int32_t device_count;
78
+
79
+ /**
80
+ * @brief Information about a single CANN device.
81
+ */
82
+ struct cann_device_info {
83
+ int cc; /**< Compute capability. */
84
+ size_t smpb; /**< Maximum shared memory per block. */
85
+ bool vmm; /**< Virtual memory support. */
86
+ size_t vmm_granularity; /**< Granularity of virtual memory. */
87
+ size_t total_vram; /**< Total video RAM available on the device. */
88
+ };
89
+
90
+ cann_device_info devices[GGML_CANN_MAX_DEVICES] =
91
+ {}; /**< Array of CANN device information. */
92
+ };
93
+
94
+ const ggml_cann_device_info& ggml_cann_info();
95
+
96
+ void ggml_cann_set_device(int32_t device);
97
+ int32_t ggml_cann_get_device();
98
+
99
+ /**
100
+ * @brief Abstract base class for memory pools used by CANN.
101
+ */
102
+ struct ggml_cann_pool {
103
+ /**
104
+ * @brief Virtual destructor for the memory pool.
105
+ */
106
+ virtual ~ggml_cann_pool() = default;
107
+
108
+ /**
109
+ * @brief Allocates memory from the pool.
110
+ *
111
+ * @param size The size of the memory block to allocate.
112
+ * @param actual_size Pointer to a variable where the actual allocated size
113
+ * will be stored.
114
+ * @return Pointer to the allocated memory block.
115
+ */
116
+ virtual void* alloc(size_t size, size_t* actual_size) = 0;
117
+
118
+ /**
119
+ * @brief Frees a previously allocated memory block.
120
+ *
121
+ * @param ptr Pointer to the memory block to free.
122
+ * @param size Size of the memory block to free.
123
+ * @note All CANN operators run asynchronously. Make sure the memory is
124
+ * still available until the operator has finished.
125
+ */
126
+ virtual void free(void* ptr, size_t size) = 0;
127
+ };
128
+
129
+ /**
130
+ * @brief RAII wrapper for managing memory allocations from a CANN memory pool.
131
+ */
132
+ struct ggml_cann_pool_alloc {
133
+ ggml_cann_pool* pool = nullptr; /**< Pointer to the memory pool. */
134
+ void* ptr = nullptr; /**< Pointer to the allocated memory block. */
135
+ size_t actual_size = 0; /**< Actual size of the allocated memory block. */
136
+
137
+ /**
138
+ * @brief Default constructor.
139
+ */
140
+ ggml_cann_pool_alloc() = default;
141
+
142
+ /**
143
+ * @brief Constructor that initializes the memory pool.
144
+ * @param pool Reference to the memory pool.
145
+ */
146
+ explicit ggml_cann_pool_alloc(ggml_cann_pool& pool) : pool(&pool) {}
147
+
148
+ /**
149
+ * @brief Constructor that initializes the memory pool and allocates memory.
150
+ * @param pool Reference to the memory pool.
151
+ * @param size Size of the memory block to allocate.
152
+ */
153
+ ggml_cann_pool_alloc(ggml_cann_pool& pool, size_t size) : pool(&pool) {
154
+ alloc(size);
155
+ }
156
+
157
+ /**
158
+ * @brief Destructor that frees the allocated memory block.
159
+ */
160
+ ~ggml_cann_pool_alloc() {
161
+ if (ptr != nullptr) {
162
+ pool->free(ptr, actual_size);
163
+ }
164
+ }
165
+
166
+ /**
167
+ * @brief Allocates memory from the pool.
168
+ * @param size Size of the memory block to allocate.
169
+ * @return Pointer to the allocated memory block.
170
+ */
171
+ void* alloc(size_t size) {
172
+ GGML_ASSERT(pool != nullptr);
173
+ GGML_ASSERT(ptr == nullptr);
174
+ ptr = pool->alloc(size, &this->actual_size);
175
+ return ptr;
176
+ }
177
+
178
+ /**
179
+ * @brief Allocates memory from a specific memory pool.
180
+ * @param pool Reference to the memory pool.
181
+ * @param size Size of the memory block to allocate.
182
+ * @return Pointer to the allocated memory block.
183
+ */
184
+ void* alloc(ggml_cann_pool& pool, size_t size) {
185
+ this->pool = &pool;
186
+ return alloc(size);
187
+ }
188
+
189
+ /**
190
+ * @brief Gets the pointer to the allocated memory block.
191
+ * @return Pointer to the allocated memory block.
192
+ */
193
+ void* get() { return ptr; }
194
+
195
+ // Deleted copy constructor
196
+ ggml_cann_pool_alloc(const ggml_cann_pool_alloc&) = delete;
197
+
198
+ // Deleted move constructor
199
+ ggml_cann_pool_alloc(ggml_cann_pool_alloc&&) = delete;
200
+
201
+ // Deleted copy assignment operator
202
+ ggml_cann_pool_alloc& operator=(const ggml_cann_pool_alloc&) = delete;
203
+
204
+ // Deleted move assignment operator
205
+ ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
206
+ };
207
+
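+ // Usage sketch (editorial, not part of the commit): the RAII wrapper returns
+ // its block to the pool when it leaves scope, so temporary workspaces cannot
+ // leak even on early return:
+ //     ggml_cann_pool_alloc ws(some_pool, workspace_size);
+ //     void* buf = ws.get();   // valid until `ws` is destroyed
+ //     ...                     // enqueue work that uses `buf`
+ //     // destructor calls some_pool.free(buf, actual_size)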
208
+ /**
209
+ * @brief Context for managing CANN backend operations.
210
+ */
211
+ struct ggml_backend_cann_context {
212
+ int32_t device; /**< Device ID. */
213
+ std::string name; /**< Name of the device. */
214
+ aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
215
+
216
+ aclrtStream streams[GGML_CANN_MAX_STREAMS] = {
217
+ {nullptr}}; /**< Array of streams for the device. */
218
+
219
+ /**
220
+ * @brief Constructor for initializing the context with a given device.
221
+ * @param device Device ID.
222
+ */
223
+ explicit ggml_backend_cann_context(int device)
224
+ : device(device), name("CANN" + std::to_string(device)) {}
225
+
226
+ /**
227
+ * @brief Destructor for cleaning up resources.
228
+ */
229
+ ~ggml_backend_cann_context() {
230
+ if (copy_event != nullptr) {
231
+ ACL_CHECK(aclrtDestroyEvent(copy_event));
232
+ }
233
+ for (int i = 0; i < GGML_CANN_MAX_STREAMS; ++i) {
234
+ if (streams[i] != nullptr) {
235
+ ACL_CHECK(aclrtDestroyStream(streams[i]));
236
+ }
237
+ }
238
+ }
239
+
240
+ /**
241
+ * @brief Get or create a stream for a given index.
242
+ * @param stream Index of the stream.
243
+ * @return The stream corresponding to the given index.
244
+ */
245
+ aclrtStream stream(int stream) {
246
+ if (streams[stream] == nullptr) {
247
+ ggml_cann_set_device(device);
248
+ ACL_CHECK(aclrtCreateStream(&streams[stream]));
249
+ }
250
+ return streams[stream];
251
+ }
252
+
253
+ /**
254
+ * @brief Get or create the default stream (index 0).
255
+ * @return The default stream.
256
+ */
257
+ aclrtStream stream() { return stream(0); }
258
+
259
+ // TODO: each stream should have a memory pool.
260
+ std::unique_ptr<ggml_cann_pool>
261
+ mem_pool; /**< Memory pool for the device. */
262
+
263
+ /**
264
+ * @brief Create a new memory pool for a given device.
265
+ * @param device Device ID.
266
+ * @return A unique pointer to the new memory pool.
267
+ */
268
+ static std::unique_ptr<ggml_cann_pool> new_pool_for_device(int device);
269
+
270
+ /**
271
+ * @brief Get or create the memory pool for the context.
272
+ * @return Reference to the memory pool.
273
+ */
274
+ ggml_cann_pool& pool() {
275
+ if (mem_pool == nullptr) {
276
+ mem_pool = new_pool_for_device(device);
277
+ }
278
+ return *mem_pool;
279
+ }
280
+ };
281
+
282
+ #endif // CANN_COMMON_H
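A minimal sketch of how the context, stream, and pool declared above are expected to fit together at runtime; the include path is an assumption and the function is illustrative, not part of the commit:

#include "ggml-cann/common.h"   // assumed include path for the header above

void cann_workspace_example(int device, size_t workspace_size) {
    ggml_backend_cann_context ctx(device);   // names itself "CANN<device>"
    aclrtStream stream = ctx.stream();       // lazily creates stream 0 on `device`

    // Scratch memory from the per-device pool, released when `ws` leaves scope.
    ggml_cann_pool_alloc ws(ctx.pool(), workspace_size);
    void* buf = ws.get();
    (void) buf;   // ... enqueue kernels on `stream` that read/write `buf` ...

    // Per the pool's @note, operators run asynchronously, so synchronize before
    // `ws` is destroyed and the buffer is handed back to the pool.
    ACL_CHECK(aclrtSynchronizeStream(stream));
}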
ggml/src/ggml-cann/kernels/CMakeLists.txt ADDED
@@ -0,0 +1,33 @@
1
+ if (NOT SOC_TYPE)
2
+ set (SOC_TYPE "Ascend910B3")
3
+ endif()
4
+
5
+ file(GLOB SRC_FILES
6
+ get_row_f32.cpp
7
+ get_row_f16.cpp
8
+ get_row_q4_0.cpp
9
+ get_row_q8_0.cpp
10
+ quantize_f32_q8_0.cpp
11
+ quantize_f16_q8_0.cpp
12
+ quantize_float_to_q4_0.cpp
13
+ dup.cpp
14
+ )
15
+
16
+ string(TOLOWER ${SOC_TYPE} SOC_VERSION)
17
+ set(ASCEND_CANN_PACKAGE_PATH ${CANN_INSTALL_DIR})
18
+ set(RUN_MODE "npu" CACHE STRING "run mode: npu/sim")
19
+
20
+ if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
21
+ set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
22
+ elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
23
+ set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
24
+ else()
25
+ message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the compiler package is installed.")
26
+ endif()
27
+ include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
28
+
29
+ ascendc_library(ascendc_kernels STATIC
30
+ ${SRC_FILES}
31
+ )
32
+
33
+ # ascendc_compile_definitions(ascendc_kernels PRIVATE -DASCENDC_DUMP)
ggml/src/ggml-cann/kernels/ascendc_kernels.h ADDED
@@ -0,0 +1,19 @@
1
+ #ifndef ASCENDC_KERNELS_H
2
+ #define ASCENDC_KERNELS_H
3
+
4
+ #include "aclrtlaunch_ascendc_get_row_f32.h"
5
+ #include "aclrtlaunch_ascendc_get_row_f16.h"
6
+ #include "aclrtlaunch_ascendc_get_row_q8_0.h"
7
+ #include "aclrtlaunch_ascendc_get_row_q4_0.h"
8
+
9
+ #include "aclrtlaunch_ascendc_quantize_f32_q8_0.h"
10
+ #include "aclrtlaunch_ascendc_quantize_f16_q8_0.h"
11
+ #include "aclrtlaunch_ascendc_quantize_f16_to_q4_0.h"
12
+ #include "aclrtlaunch_ascendc_quantize_f32_to_q4_0.h"
13
+
14
+ #include "aclrtlaunch_ascendc_dup_by_rows_fp16.h"
15
+ #include "aclrtlaunch_ascendc_dup_by_rows_fp32.h"
16
+ #include "aclrtlaunch_ascendc_dup_by_rows_fp32_to_fp16.h"
17
+ #include "aclrtlaunch_ascendc_dup_by_rows_fp16_to_fp32.h"
18
+
19
+ #endif // ASCENDC_KERNELS_H
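Each include above is generated by ascendc_library from the matching kernel source and declares a host-side launch wrapper. The sketch below shows the expected call shape for one of them; the (block_dim, stream, kernel-args...) convention and the requirement that all pointers already live in device global memory are assumptions based on the usual AscendC launch pattern, not a verbatim copy of the backend code:

#include "ascendc_kernels.h"

// Hypothetical helper (editorial): every pointer is assumed to be a device
// (global-memory) address, matching the nine GM_ADDR parameters of
// ascendc_get_row_f32.
static void launch_get_row_f32_example(aclrtStream stream, uint32_t block_dim,
                                       void* src, void* indices, void* dst,
                                       void* src_ne, void* src_nb,
                                       void* idx_ne, void* idx_nb,
                                       void* dst_ne, void* dst_nb) {
    aclrtlaunch_ascendc_get_row_f32(block_dim, stream,
                                    src, indices, dst,
                                    src_ne, src_nb,
                                    idx_ne, idx_nb,
                                    dst_ne, dst_nb);
}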
ggml/src/ggml-cann/kernels/dup.cpp ADDED
@@ -0,0 +1,223 @@
1
+ #include "kernel_operator.h"
2
+
3
+ #include <cmath>
4
+
5
+ using namespace AscendC;
6
+
7
+ #define BUFFER_NUM 2
8
+
9
+ template <typename SRC_T, typename DST_T>
10
+ class DupByRows {
11
+ public:
12
+ __aicore__ inline DupByRows() {}
13
+ __aicore__ inline void init(GM_ADDR src, GM_ADDR dst, int64_t *input_ne_ub,
14
+ size_t *input_nb_ub) {
15
+ /* Dup by rows when src is contiguous in the first dimension and dst is
16
+ contiguous; each kernel processes one row.
17
+ */
18
+
19
+ // Input has four dims.
20
+ int64_t op_block_num = GetBlockNum();
21
+ int64_t op_block_idx = GetBlockIdx();
22
+
23
+ // param
24
+ num_rows = input_ne_ub[1] * input_ne_ub[2] * input_ne_ub[3];
25
+ num_elem = input_ne_ub[0];
26
+
27
+ // index for (ne[1], ne[2], ne[3]): (idx_ne1, idx_ne2, idx_ne3)
28
+ idx_ne3 = op_block_idx / (input_ne_ub[1] * input_ne_ub[2]);
29
+ idx_ne2 = (op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2]))
30
+ / (input_ne_ub[1]);
31
+ idx_ne1 = op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2])
32
+ - idx_ne2 * input_ne_ub[1];
33
+
34
+ // src may not be contiguous in dims [1,2,3], so the stride is derived from ne & nb
35
+ src_stride = input_nb_ub[3] * idx_ne3 + input_nb_ub[2] * idx_ne2
36
+ + input_nb_ub[1] * idx_ne1;
37
+
38
+ // dst is contiguous
39
+ dst_stride = op_block_idx * (input_ne_ub[0] * sizeof(DST_T));
40
+
41
+ src_gm.SetGlobalBuffer(reinterpret_cast<__gm__ SRC_T *>(src +
42
+ src_stride));
43
+ dst_gm.SetGlobalBuffer(reinterpret_cast<__gm__ DST_T *>(dst +
44
+ dst_stride));
45
+
46
+ pipe.InitBuffer(src_queue, BUFFER_NUM, (sizeof(SRC_T) * num_elem +
47
+ 32 - 1) / 32 * 32);
48
+ pipe.InitBuffer(dst_queue, BUFFER_NUM, (sizeof(DST_T) * num_elem +
49
+ 32 - 1) / 32 * 32);
50
+ }
51
+
52
+ __aicore__ inline void copy_in() {
53
+ LocalTensor<SRC_T> src_local = src_queue.AllocTensor<SRC_T>();
54
+
55
+ DataCopyExtParams dataCopyParams;
56
+ dataCopyParams.blockCount = 1;
57
+ dataCopyParams.blockLen = num_elem * sizeof(SRC_T);
58
+ DataCopyPadExtParams<SRC_T> padParams;
59
+ DataCopyPad(src_local, src_gm, dataCopyParams, padParams);
60
+
61
+ src_queue.EnQue(src_local);
62
+ }
63
+
64
+ __aicore__ inline void copy_out() {
65
+ LocalTensor<DST_T> dst_local = dst_queue.DeQue<DST_T>();
66
+
67
+ DataCopyExtParams dataCopyParams;
68
+ dataCopyParams.blockCount = 1;
69
+ dataCopyParams.blockLen = num_elem * sizeof(DST_T);
70
+ DataCopyPad(dst_gm, dst_local, dataCopyParams);
71
+
72
+ dst_queue.FreeTensor(dst_local);
73
+ }
74
+
75
+ __aicore__ inline void dup() {
76
+ // main process, copy one row data from src to dst.
77
+ copy_in();
78
+
79
+ LocalTensor<SRC_T> src_local = src_queue.DeQue<SRC_T>();
80
+ LocalTensor<DST_T> dst_local = dst_queue.AllocTensor<DST_T>();
81
+
82
+ int32_t BLOCK_NUM = 32 / sizeof(DST_T);
83
+ DataCopy(dst_local, src_local, (num_elem + BLOCK_NUM - 1)
84
+ / BLOCK_NUM * BLOCK_NUM);
85
+ dst_queue.EnQue<DST_T>(dst_local);
86
+
87
+ src_queue.FreeTensor(src_local);
88
+ copy_out();
89
+ }
90
+
91
+ __aicore__ inline void dup_with_cast() {
92
+ // main process, copy one row data from src to dst.
93
+ // cast dtype from src to dst.
94
+ copy_in();
95
+
96
+ LocalTensor<SRC_T> src_local = src_queue.DeQue<SRC_T>();
97
+ LocalTensor<DST_T> dst_local = dst_queue.AllocTensor<DST_T>();
98
+
99
+ Cast(dst_local, src_local, RoundMode::CAST_NONE, num_elem);
100
+ dst_queue.EnQue<DST_T>(dst_local);
101
+
102
+ src_queue.FreeTensor(src_local);
103
+ copy_out();
104
+ }
105
+
106
+ private:
107
+
108
+ TPipe pipe;
109
+ GlobalTensor<SRC_T> src_gm;
110
+ GlobalTensor<DST_T> dst_gm;
111
+
112
+ int64_t num_rows;
113
+ int64_t num_elem;
114
+ int64_t idx_ne3;
115
+ int64_t idx_ne2;
116
+ int64_t idx_ne1;
117
+ int64_t src_stride;
118
+ int64_t dst_stride;
119
+
120
+ TQue<QuePosition::VECIN, BUFFER_NUM> src_queue;
121
+ TQue<QuePosition::VECOUT, BUFFER_NUM> dst_queue;
122
+ };
123
+
124
+ template <typename T>
125
+ __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
126
+ auto gm_ptr = (__gm__ uint8_t *)gm;
127
+ auto ub_ptr = (uint8_t *)(ub);
128
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
129
+ *ub_ptr = *gm_ptr;
130
+ }
131
+ }
132
+
133
+ extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16(
134
+ GM_ADDR src_gm,
135
+ GM_ADDR dst_gm,
136
+ GM_ADDR input_ne_gm,
137
+ GM_ADDR input_nb_gm,
138
+ GM_ADDR output_ne_gm,
139
+ GM_ADDR output_nb_gm) {
140
+
141
+ int64_t input_ne_ub[4];
142
+ size_t input_nb_ub[4];
143
+ int64_t output_ne_ub[4];
144
+ size_t output_nb_ub[4];
145
+
146
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
147
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
148
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
149
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
150
+
151
+ DupByRows<half, half> op;
152
+ op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
153
+ op.dup();
154
+ }
155
+
156
+ extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32(
157
+ GM_ADDR src_gm,
158
+ GM_ADDR dst_gm,
159
+ GM_ADDR input_ne_gm,
160
+ GM_ADDR input_nb_gm,
161
+ GM_ADDR output_ne_gm,
162
+ GM_ADDR output_nb_gm) {
163
+ int64_t input_ne_ub[4];
164
+ size_t input_nb_ub[4];
165
+ int64_t output_ne_ub[4];
166
+ size_t output_nb_ub[4];
167
+
168
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
169
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
170
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
171
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
172
+
173
+ DupByRows<float_t, float_t> op;
174
+ op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
175
+ op.dup();
176
+ }
177
+
178
+ extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32_to_fp16(
179
+ GM_ADDR src_gm,
180
+ GM_ADDR dst_gm,
181
+ GM_ADDR input_ne_gm,
182
+ GM_ADDR input_nb_gm,
183
+ GM_ADDR output_ne_gm,
184
+ GM_ADDR output_nb_gm) {
185
+
186
+ int64_t input_ne_ub[4];
187
+ size_t input_nb_ub[4];
188
+ int64_t output_ne_ub[4];
189
+ size_t output_nb_ub[4];
190
+
191
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
192
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
193
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
194
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
195
+
196
+ DupByRows<float_t, half> op;
197
+ op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
198
+ op.dup_with_cast();
199
+ }
200
+
201
+ extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16_to_fp32(
202
+ GM_ADDR src_gm,
203
+ GM_ADDR dst_gm,
204
+ GM_ADDR input_ne_gm,
205
+ GM_ADDR input_nb_gm,
206
+ GM_ADDR output_ne_gm,
207
+ GM_ADDR output_nb_gm) {
208
+
209
+ // copy params from gm to ub.
210
+ int64_t input_ne_ub[4];
211
+ size_t input_nb_ub[4];
212
+ int64_t output_ne_ub[4];
213
+ size_t output_nb_ub[4];
214
+
215
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
216
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
217
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
218
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
219
+
220
+ DupByRows<half, float_t> op;
221
+ op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
222
+ op.dup_with_cast();
223
+ }
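The InitBuffer calls in DupByRows::init round each row buffer up to a multiple of 32 bytes, since unified-buffer copies move 32-byte blocks. A small host-side illustration of that rounding (editorial, not part of the commit):

#include <cstddef>
#include <cstdio>

// Mirrors the (size + 32 - 1) / 32 * 32 expression used in DupByRows::init.
static size_t ub_buffer_bytes(size_t num_elem, size_t elem_size) {
    return (num_elem * elem_size + 32 - 1) / 32 * 32;
}

int main() {
    std::printf("%zu\n", ub_buffer_bytes(100, 2)); // 100 fp16 -> 200 -> 224 bytes
    std::printf("%zu\n", ub_buffer_bytes(100, 4)); // 100 fp32 -> 400 -> 416 bytes
    return 0;
}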
ggml/src/ggml-cann/kernels/get_row_f16.cpp ADDED
@@ -0,0 +1,186 @@
1
+ #include "kernel_operator.h"
2
+
3
+ // optimize me. Use a template to avoid duplicated code.
4
+ using namespace AscendC;
5
+
6
+ #define BUFFER_NUM 2
7
+
8
+ class GET_ROW_F16 {
9
+ public:
10
+ __aicore__ inline GET_ROW_F16() {}
11
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
12
+ int64_t *input_ne_ub, size_t *input_nb_ub,
13
+ int64_t *indices_ne_ub, size_t *indices_nb_ub,
14
+ int64_t *output_ne_ub, size_t *output_nb_ub) {
15
+ // TODO, use template for F16/f32
16
+ int64_t op_block_num = GetBlockNum();
17
+ int64_t op_block_idx = GetBlockIdx();
18
+
19
+ for (int i = 0; i < 4; i++) {
20
+ input_ne[i] = input_ne_ub[i];
21
+ input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
22
+
23
+ indices_ne[i] = indices_ne_ub[i];
24
+ indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
25
+
26
+ output_ne[i] = output_ne_ub[i];
27
+ output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
28
+ }
29
+
30
+ // Indices has two dims. n_elements = all rows should get.
31
+ // dr, all rows should this thread get.
32
+ uint64_t n_elements =
33
+ indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
34
+ dr = n_elements / op_block_num;
35
+
36
+ uint64_t tails = n_elements % op_block_num;
37
+ if (op_block_idx < tails) {
38
+ dr += 1;
39
+ ir = dr * op_block_idx;
40
+ } else {
41
+ ir = dr * op_block_idx + tails;
42
+ }
43
+
44
+ input_gm.SetGlobalBuffer((__gm__ half *)input);
45
+ indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
46
+ output_gm.SetGlobalBuffer((__gm__ float *)output);
47
+
48
+ uint64_t input_local_buffer_size = ((input_ne[0] * sizeof(half) + 31)
49
+ & ~31);
50
+ uint64_t output_local_buffer_size = ((input_ne[0] * sizeof(float) + 31)
51
+ & ~31);
52
+
53
+ local_buffer_elems = input_local_buffer_size / sizeof(half);
54
+
55
+ // TODO, consider long row that can't put in UB.
56
+ // All data should asign to 32. It's ok because all data is align to 32.
57
+ pipe.InitBuffer(input_queue, BUFFER_NUM, input_local_buffer_size);
58
+ pipe.InitBuffer(output_queue, BUFFER_NUM, output_local_buffer_size);
59
+ }
60
+
61
+ __aicore__ inline void copy_in(uint32_t offset, size_t len) {
62
+ LocalTensor<half> input_local = input_queue.AllocTensor<half>();
63
+ size_t tail = len % 32;
64
+ len = len & ~31;
65
+ DataCopy(input_local, input_gm[offset], len);
66
+ if(tail != 0) {
67
+ DataCopyExtParams dataCopyParams;
68
+ dataCopyParams.blockCount = 1;
69
+ dataCopyParams.blockLen = tail * sizeof(half);
70
+ DataCopyPadExtParams<half> padParams;
71
+ DataCopyPad(input_local[len], input_gm[offset + len],
72
+ dataCopyParams, padParams);
73
+ }
74
+ input_queue.EnQue(input_local);
75
+ }
76
+
77
+ __aicore__ inline void copy_out(uint32_t offset, size_t len) {
78
+ LocalTensor<float> output_local = output_queue.DeQue<float>();
79
+ size_t tail = len % 32;
80
+ len = len & ~31;
81
+ DataCopy(output_gm[offset], output_local, len);
82
+ if(tail != 0) {
83
+ DataCopyExtParams dataCopyParams;
84
+ dataCopyParams.blockCount = 1;
85
+ dataCopyParams.blockLen = tail * sizeof(float);
86
+ DataCopyPad(output_gm[offset + len], output_local[len],
87
+ dataCopyParams);
88
+ }
89
+ output_queue.FreeTensor(output_local);
90
+ }
91
+
92
+ __aicore__ inline void calculate_row(int64_t idx) {
93
+ const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
94
+ const int64_t indices_ne1_idx =
95
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
96
+ indices_ne[0];
97
+ const int64_t indices_ne0_idx =
98
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
99
+ indices_ne1_idx * indices_ne[0]);
100
+
101
+ const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
102
+ indices_ne1_idx * indices_stride[1] +
103
+ indices_ne2_idx * indices_stride[2];
104
+ const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
105
+
106
+ const int64_t input_offset = selected_row_idx * input_stride[1] +
107
+ indices_ne1_idx * input_stride[2] +
108
+ indices_ne2_idx * input_stride[3];
109
+
110
+ const int64_t output_offset = indices_ne0_idx * output_stride[1] +
111
+ indices_ne1_idx * output_stride[2] +
112
+ indices_ne2_idx * output_stride[3];
113
+
114
+ copy_in(input_offset, input_ne[0]);
115
+ LocalTensor<half> input_local = input_queue.DeQue<half>();
116
+ LocalTensor<float> output_local = output_queue.AllocTensor<float>();
117
+
118
+ Cast(output_local, input_local, RoundMode::CAST_NONE,
119
+ local_buffer_elems);
120
+ output_queue.EnQue(output_local);
121
+ copy_out(output_offset, input_ne[0]);
122
+
123
+ input_queue.FreeTensor(input_local);
124
+ }
125
+
126
+ __aicore__ inline void calculate() {
127
+ for (int64_t i = ir; i < ir + dr; i++) {
128
+ calculate_row(i);
129
+ }
130
+ }
131
+
132
+ private:
133
+ int64_t input_ne[4];
134
+ size_t input_stride[4];
135
+
136
+ int64_t indices_ne[4];
137
+ size_t indices_stride[4];
138
+
139
+ int64_t output_ne[4];
140
+ size_t output_stride[4];
141
+
142
+ size_t local_buffer_elems;
143
+
144
+ int64_t ir;
145
+ int64_t dr;
146
+
147
+ TPipe pipe;
148
+ GlobalTensor<half> input_gm;
149
+ GlobalTensor<int32_t> indices_gm;
150
+ GlobalTensor<float> output_gm;
151
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
152
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
153
+ };
154
+
155
+ template <typename T>
156
+ __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
157
+ auto gm_ptr = (__gm__ uint8_t *)gm;
158
+ auto ub_ptr = (uint8_t *)(ub);
159
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
160
+ *ub_ptr = *gm_ptr;
161
+ }
162
+ }
163
+
164
+ extern "C" __global__ __aicore__ void ascendc_get_row_f16(
165
+ GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
166
+ GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm,
167
+ GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
168
+ int64_t input_ne_ub[4];
169
+ size_t input_nb_ub[4];
170
+ int64_t indices_ne_ub[4];
171
+ size_t indices_nb_ub[4];
172
+ int64_t output_ne_ub[4];
173
+ size_t output_nb_ub[4];
174
+
175
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
176
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
177
+ copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
178
+ copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
179
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
180
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
181
+
182
+ GET_ROW_F16 op;
183
+ op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub,
184
+ indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub);
185
+ op.calculate();
186
+ }
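The dr/ir arithmetic in init() splits the n_elements rows across the launched blocks, giving the first n_elements % block_num blocks one extra row so every row is covered exactly once. A standalone sketch of the same partition (editorial, not part of the commit):

#include <cstdint>
#include <cstdio>

// Same partition as GET_ROW_F16::init: block `idx` handles rows [ir, ir + dr).
static void partition_rows(uint64_t n_elements, int64_t block_num,
                           int64_t block_idx, int64_t* ir, int64_t* dr) {
    *dr = n_elements / block_num;
    uint64_t tails = n_elements % block_num;
    if (block_idx < (int64_t) tails) {
        *dr += 1;
        *ir = *dr * block_idx;
    } else {
        *ir = *dr * block_idx + tails;
    }
}

int main() {
    // 10 rows over 4 blocks -> blocks 0,1 take 3 rows, blocks 2,3 take 2 rows.
    for (int64_t idx = 0; idx < 4; ++idx) {
        int64_t ir, dr;
        partition_rows(10, 4, idx, &ir, &dr);
        std::printf("block %lld: ir=%lld dr=%lld\n",
                    (long long) idx, (long long) ir, (long long) dr);
    }
    return 0;
}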
ggml/src/ggml-cann/kernels/get_row_f32.cpp ADDED
@@ -0,0 +1,180 @@
1
+ #include "kernel_operator.h"
2
+
3
+ // optimize me. Use a template to avoid duplicated code.
4
+ using namespace AscendC;
5
+
6
+ #define BUFFER_NUM 2
7
+
8
+ class GET_ROW_F32 {
9
+ public:
10
+ __aicore__ inline GET_ROW_F32() {}
11
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
12
+ int64_t *input_ne_ub, size_t *input_nb_ub,
13
+ int64_t *indices_ne_ub, size_t *indices_nb_ub,
14
+ int64_t *output_ne_ub, size_t *output_nb_ub) {
15
+ int64_t op_block_num = GetBlockNum();
16
+ int64_t op_block_idx = GetBlockIdx();
17
+
18
+ for (int i = 0; i < 4; i++) {
19
+ input_ne[i] = input_ne_ub[i];
20
+ input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
21
+
22
+ indices_ne[i] = indices_ne_ub[i];
23
+ indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
24
+
25
+ output_ne[i] = output_ne_ub[i];
26
+ output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
27
+ }
28
+
29
+ // Indices has two dims. n_elements = total number of rows to fetch.
30
+ // dr = number of rows this block should fetch.
31
+ uint64_t n_elements =
32
+ indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
33
+ dr = n_elements / op_block_num;
34
+
35
+ uint64_t tails = n_elements % op_block_num;
36
+ if (op_block_idx < tails) {
37
+ dr += 1;
38
+ ir = dr * op_block_idx;
39
+ } else {
40
+ ir = dr * op_block_idx + tails;
41
+ }
42
+
43
+ input_gm.SetGlobalBuffer((__gm__ float *)input);
44
+ indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
45
+ output_gm.SetGlobalBuffer((__gm__ float *)output);
46
+
47
+ uint64_t local_buffer_size = ((input_ne[0] * sizeof(float) + 31) & ~31);
48
+ local_buffer_elems = local_buffer_size / sizeof(float);
49
+
50
+ // TODO: handle rows too long to fit in the UB.
51
+ // All buffers are rounded up to 32 bytes, so every copy stays 32-byte aligned.
52
+ pipe.InitBuffer(input_queue, BUFFER_NUM, local_buffer_size);
53
+ pipe.InitBuffer(output_queue, BUFFER_NUM, local_buffer_size);
54
+ }
55
+
56
+ __aicore__ inline void copy_in(uint32_t offset, size_t len) {
57
+ LocalTensor<float> input_local = input_queue.AllocTensor<float>();
58
+ size_t tail = len % 32;
59
+ len = len & ~31;
60
+ DataCopy(input_local, input_gm[offset], len);
61
+ if(tail != 0) {
62
+ DataCopyExtParams dataCopyParams;
63
+ dataCopyParams.blockCount = 1;
64
+ dataCopyParams.blockLen = tail * sizeof(float);
65
+ DataCopyPadExtParams<float> padParams;
66
+ DataCopyPad(input_local[len], input_gm[offset + len],
67
+ dataCopyParams, padParams);
68
+ }
69
+ input_queue.EnQue(input_local);
70
+ }
71
+
72
+ __aicore__ inline void copy_out(uint32_t offset, size_t len) {
73
+ LocalTensor<float> output_local = output_queue.DeQue<float>();
74
+ size_t tail = len % 32;
75
+ len = len & ~31;
76
+ DataCopy(output_gm[offset], output_local, len);
77
+ if(tail != 0) {
78
+ DataCopyExtParams dataCopyParams;
79
+ dataCopyParams.blockCount = 1;
80
+ dataCopyParams.blockLen = tail * sizeof(float);
81
+ DataCopyPad(output_gm[offset + len], output_local[len],
82
+ dataCopyParams);
83
+ }
84
+ output_queue.FreeTensor(output_local);
85
+ }
86
+
87
+ __aicore__ inline void calculate_row(int64_t idx) {
88
+ const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
89
+ const int64_t indices_ne1_idx =
90
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
91
+ indices_ne[0];
92
+ const int64_t indices_ne0_idx =
93
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
94
+ indices_ne1_idx * indices_ne[0]);
95
+
96
+ const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
97
+ indices_ne1_idx * indices_stride[1] +
98
+ indices_ne2_idx * indices_stride[2];
99
+ const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
100
+
101
+ const int64_t input_offset = selected_row_idx * input_stride[1] +
102
+ indices_ne1_idx * input_stride[2] +
103
+ indices_ne2_idx * input_stride[3];
104
+
105
+ const int64_t output_offset = indices_ne0_idx * output_stride[1] +
106
+ indices_ne1_idx * output_stride[2] +
107
+ indices_ne2_idx * output_stride[3];
108
+
109
+ copy_in(input_offset, input_ne[0]);
110
+ LocalTensor<float> input_local = input_queue.DeQue<float>();
111
+ LocalTensor<float> output_local = output_queue.AllocTensor<float>();
112
+
113
+ DataCopy(output_local, input_local, local_buffer_elems);
114
+ output_queue.EnQue(output_local);
115
+ copy_out(output_offset, input_ne[0]);
116
+
117
+ input_queue.FreeTensor(input_local);
118
+ }
119
+
120
+ __aicore__ inline void calculate() {
121
+ for (int64_t i = ir; i < ir + dr; i++) {
122
+ calculate_row(i);
123
+ }
124
+ }
125
+
126
+ private:
127
+ int64_t input_ne[4];
128
+ size_t input_stride[4];
129
+
130
+ int64_t indices_ne[4];
131
+ size_t indices_stride[4];
132
+
133
+ int64_t output_ne[4];
134
+ size_t output_stride[4];
135
+
136
+ size_t local_buffer_elems;
137
+
138
+ int64_t ir;
139
+ int64_t dr;
140
+
141
+ TPipe pipe;
142
+ GlobalTensor<float> input_gm;
143
+ GlobalTensor<int32_t> indices_gm;
144
+ GlobalTensor<float> output_gm;
145
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
146
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
147
+ };
148
+
149
+ template <typename T>
150
+ __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
151
+ auto gm_ptr = (__gm__ uint8_t *)gm;
152
+ auto ub_ptr = (uint8_t *)(ub);
153
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
154
+ *ub_ptr = *gm_ptr;
155
+ }
156
+ }
157
+
158
+ extern "C" __global__ __aicore__ void ascendc_get_row_f32(
159
+ GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
160
+ GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm,
161
+ GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
162
+ int64_t input_ne_ub[4];
163
+ size_t input_nb_ub[4];
164
+ int64_t indices_ne_ub[4];
165
+ size_t indices_nb_ub[4];
166
+ int64_t output_ne_ub[4];
167
+ size_t output_nb_ub[4];
168
+
169
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
170
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
171
+ copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
172
+ copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
173
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
174
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
175
+
176
+ GET_ROW_F32 op;
177
+ op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub,
178
+ indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub);
179
+ op.calculate();
180
+ }
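As in the F16 variant, init() turns ggml's byte strides (nb) into element strides by dividing by nb[0], which is only valid because dimension 0 is contiguous and nb[0] equals the element size. A minimal sketch of that conversion (editorial, not part of the commit):

#include <cstddef>

// For an F32 tensor of shape (8, 4, 1, 1), nb = {4, 32, 128, 128} becomes
// element strides {1, 8, 32, 32}.
static void byte_to_elem_strides(const size_t nb[4], size_t stride[4]) {
    for (int i = 0; i < 4; i++) {
        stride[i] = nb[i] / nb[0];
    }
}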
ggml/src/ggml-cann/kernels/get_row_q4_0.cpp ADDED
@@ -0,0 +1,193 @@
1
+ #include "kernel_operator.h"
2
+
3
+ // optimize me. Use a template to avoid duplicated code.
4
+ using namespace AscendC;
5
+
6
+ #define BUFFER_NUM 2
7
+
8
+ #define QK4_0 32
9
+
10
+ class GET_ROW_Q4_0 {
11
+ public:
12
+ __aicore__ inline GET_ROW_Q4_0() {}
13
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
14
+ int64_t *input_ne_ub, int64_t *indices_ne_ub,
15
+ size_t *indices_nb_ub, int64_t *output_ne_ub,
16
+ size_t *output_nb_ub) {
17
+ int64_t op_block_num = GetBlockNum();
18
+ int64_t op_block_idx = GetBlockIdx();
19
+
20
+ for (int i = 0; i < 4; i++) {
21
+ input_ne[i] = input_ne_ub[i];
22
+ indices_ne[i] = indices_ne_ub[i];
23
+ indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
24
+ scale_ne[i] = input_ne_ub[i];
25
+ output_ne[i] = output_ne_ub[i];
26
+ output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
27
+ }
28
+
29
+ // one scale for a group.
30
+ scale_ne[0] /= QK4_0;
31
+
32
+ input_stride[0] = 1;
33
+ scale_stride[0] = 1;
34
+ output_stride[0] = 1;
35
+ for (int i = 1; i < 4; i++) {
36
+ input_stride[i] = input_stride[i - 1] * input_ne[i - 1];
37
+ scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
38
+ }
39
+
40
+ group_size_in_row = input_ne[0] / QK4_0;
41
+ int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] *
42
+ input_ne[3] / 2;
43
+
44
+ // Indices has two dims. n_elements = total number of rows to fetch.
45
+ // dr = number of rows this block should fetch.
46
+ uint64_t n_elements =
47
+ indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
48
+ dr = n_elements / op_block_num;
49
+
50
+ uint64_t tails = n_elements % op_block_num;
51
+ if (op_block_idx < tails) {
52
+ dr += 1;
53
+ ir = dr * op_block_idx;
54
+ } else {
55
+ ir = dr * op_block_idx + tails;
56
+ }
57
+
58
+ input_gm.SetGlobalBuffer((__gm__ int4b_t *)input);
59
+ scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset));
60
+ indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
61
+ output_gm.SetGlobalBuffer((__gm__ float *)output);
62
+
63
+ pipe.InitBuffer(input_queue, BUFFER_NUM, QK4_0 * sizeof(int4b_t));
64
+ pipe.InitBuffer(cast_queue, BUFFER_NUM, QK4_0 * sizeof(half));
65
+ pipe.InitBuffer(output_queue, BUFFER_NUM, QK4_0 * sizeof(float));
66
+ }
67
+
68
+ __aicore__ inline void copy_in(uint32_t offset) {
69
+ LocalTensor<int4b_t> input_local = input_queue.AllocTensor<int4b_t>();
70
+ // 32 * sizeof(int4b_t) = 16, which is not aligned to 32, why no error?
71
+ DataCopy(input_local, input_gm[offset], QK4_0);
72
+ input_queue.EnQue(input_local);
73
+ }
74
+
75
+ __aicore__ inline void copy_out(uint32_t offset) {
76
+ LocalTensor<float> output_local = output_queue.DeQue<float>();
77
+ DataCopy(output_gm[offset], output_local, QK4_0);
78
+ output_queue.FreeTensor(output_local);
79
+ }
80
+
81
+ __aicore__ inline void calculate_group(int64_t idx, int64_t group) {
82
+ const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
83
+ const int64_t indices_ne1_idx =
84
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
85
+ indices_ne[0];
86
+ const int64_t indices_ne0_idx =
87
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
88
+ indices_ne1_idx * indices_ne[0]);
89
+
90
+ const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
91
+ indices_ne1_idx * indices_stride[1] +
92
+ indices_ne2_idx * indices_stride[2];
93
+ const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
94
+
95
+ const int64_t input_offset = selected_row_idx * input_stride[1] +
96
+ indices_ne1_idx * input_stride[2] +
97
+ indices_ne2_idx * input_stride[3] +
98
+ group * QK4_0;
99
+ const int64_t scale_offset = selected_row_idx * scale_stride[1] +
100
+ indices_ne1_idx * scale_stride[2] +
101
+ indices_ne2_idx * scale_stride[3] + group;
102
+ const int64_t output_offset = indices_ne0_idx * output_stride[1] +
103
+ indices_ne1_idx * output_stride[2] +
104
+ indices_ne2_idx * output_stride[3] +
105
+ group * QK4_0;
106
+
107
+ copy_in(input_offset);
108
+ LocalTensor<int4b_t> input_local = input_queue.DeQue<int4b_t>();
109
+ LocalTensor<half> cast_local = cast_queue.AllocTensor<half>();
110
+ LocalTensor<float> output_local = output_queue.AllocTensor<float>();
111
+
112
+ // TODO: cast more data to speed up.
113
+ Cast(cast_local, input_local, RoundMode::CAST_NONE, QK4_0);
114
+ Cast(output_local, cast_local, RoundMode::CAST_NONE, QK4_0);
115
+
116
+ // Only the scale multiplication has to be done per group.
117
+ half scale = scale_gm.GetValue(scale_offset);
118
+
119
+ Muls(output_local, output_local, (float)scale, QK4_0);
120
+
121
+ input_queue.FreeTensor(input_local);
122
+ cast_queue.FreeTensor(cast_local);
123
+ output_queue.EnQue(output_local);
124
+
125
+ copy_out(output_offset);
126
+ }
127
+
128
+ __aicore__ inline void calculate() {
129
+ for (int64_t i = ir; i < ir + dr; i++) {
130
+ for (int64_t j = 0; j < group_size_in_row; j++) {
131
+ calculate_group(i, j);
132
+ }
133
+ }
134
+ }
135
+
136
+ private:
137
+ int64_t input_ne[4];
138
+ size_t input_stride[4];
139
+
140
+ int64_t scale_ne[4];
141
+ size_t scale_stride[4];
142
+
143
+ int64_t indices_ne[4];
144
+ size_t indices_stride[4];
145
+
146
+ int64_t output_ne[4];
147
+ size_t output_stride[4];
148
+
149
+ int64_t ir;
150
+ int64_t dr;
151
+
152
+ int64_t group_size_in_row;
153
+
154
+ TPipe pipe;
155
+ GlobalTensor<int4b_t> input_gm;
156
+ GlobalTensor<half> scale_gm;
157
+ GlobalTensor<int32_t> indices_gm;
158
+ GlobalTensor<float> output_gm;
159
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
160
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
161
+ TQue<QuePosition::VECIN, BUFFER_NUM> cast_queue;
162
+ };
163
+
164
+ template <typename T>
165
+ __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
166
+ auto gm_ptr = (__gm__ uint8_t *)gm;
167
+ auto ub_ptr = (uint8_t *)(ub);
168
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
169
+ *ub_ptr = *gm_ptr;
170
+ }
171
+ }
172
+
173
+ extern "C" __global__ __aicore__ void ascendc_get_row_q4_0(
174
+ GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
175
+ GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm,
176
+ GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
177
+ int64_t input_ne_ub[4];
178
+ int64_t indices_ne_ub[4];
179
+ size_t indices_nb_ub[4];
180
+ int64_t output_ne_ub[4];
181
+ size_t output_nb_ub[4];
182
+
183
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
184
+ copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
185
+ copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
186
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
187
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
188
+
189
+ GET_ROW_Q4_0 op;
190
+ op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub,
191
+ indices_nb_ub, output_ne_ub, output_nb_ub);
192
+ op.calculate();
193
+ }
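
Note: calculate_group() above recovers the 3-D position of a gathered row from the flat index idx by peeling off the indices extents one at a time. A minimal host-side sketch of the same arithmetic (illustrative names, not part of the patch):

    #include <cstdint>

    // Decompose a flat element index into (i0, i1, i2) for a tensor whose
    // first two extents are ne0 and ne1; mirrors the indices_ne*_idx math
    // in calculate_group().
    inline void unflatten_idx(int64_t idx, int64_t ne0, int64_t ne1,
                              int64_t &i0, int64_t &i1, int64_t &i2) {
        i2 = idx / (ne0 * ne1);               // slowest-varying dimension
        i1 = (idx - i2 * ne0 * ne1) / ne0;    // middle dimension
        i0 = idx - i2 * ne0 * ne1 - i1 * ne0; // fastest-varying dimension
    }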
ggml/src/ggml-cann/kernels/get_row_q8_0.cpp ADDED
@@ -0,0 +1,191 @@
+ #include "kernel_operator.h"
+
+ // TODO: use a template to avoid duplicating this code.
+ using namespace AscendC;
+
+ #define BUFFER_NUM 2
+
+ #define QK8_0 32
+
+ class GET_ROW_Q8_0 {
+ public:
+ __aicore__ inline GET_ROW_Q8_0() {}
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
+ int64_t *input_ne_ub, int64_t *indices_ne_ub,
+ size_t *indices_nb_ub, int64_t *output_ne_ub,
+ size_t *output_nb_ub) {
+ int64_t op_block_num = GetBlockNum();
+ int64_t op_block_idx = GetBlockIdx();
+
+ for (int i = 0; i < 4; i++) {
+ input_ne[i] = input_ne_ub[i];
+ indices_ne[i] = indices_ne_ub[i];
+ indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
+ scale_ne[i] = input_ne_ub[i];
+ output_ne[i] = output_ne_ub[i];
+ output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
+ }
+
+ // one scale per group.
+ scale_ne[0] /= QK8_0;
+
+ input_stride[0] = 1;
+ scale_stride[0] = 1;
+ output_stride[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ input_stride[i] = input_stride[i - 1] * input_ne[i - 1];
+ scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+ }
+
+ group_size_in_row = input_ne[0] / QK8_0;
+ int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] *
+ input_ne[3] * sizeof(int8_t);
+
+ // n_elements is the total number of rows to gather;
+ // dr is the number of rows this block handles.
+ uint64_t n_elements =
+ indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
+ dr = n_elements / op_block_num;
+
+ uint64_t tails = n_elements % op_block_num;
+ if (op_block_idx < tails) {
+ dr += 1;
+ ir = dr * op_block_idx;
+ } else {
+ ir = dr * op_block_idx + tails;
+ }
+
+ input_gm.SetGlobalBuffer((__gm__ int8_t *)input);
+ scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset));
+ indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
+ output_gm.SetGlobalBuffer((__gm__ float *)output);
+
+ pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
+ pipe.InitBuffer(cast_queue, BUFFER_NUM, QK8_0 * sizeof(half));
+ pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(float));
+ }
+
+ __aicore__ inline void copy_in(uint32_t offset) {
+ LocalTensor<int8_t> input_local = input_queue.AllocTensor<int8_t>();
+ DataCopy(input_local, input_gm[offset], QK8_0);
+ input_queue.EnQue(input_local);
+ }
+
+ __aicore__ inline void copy_out(uint32_t offset) {
+ LocalTensor<float> output_local = output_queue.DeQue<float>();
+ DataCopy(output_gm[offset], output_local, QK8_0);
+ output_queue.FreeTensor(output_local);
+ }
+
+ __aicore__ inline void calculate_group(int64_t idx, int64_t group) {
+ const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
+ const int64_t indices_ne1_idx =
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
+ indices_ne[0];
+ const int64_t indices_ne0_idx =
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
+ indices_ne1_idx * indices_ne[0]);
+
+ const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
+ indices_ne1_idx * indices_stride[1] +
+ indices_ne2_idx * indices_stride[2];
+ const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
+
+ const int64_t input_offset = selected_row_idx * input_stride[1] +
+ indices_ne1_idx * input_stride[2] +
+ indices_ne2_idx * input_stride[3] +
+ group * QK8_0;
+ const int64_t scale_offset = selected_row_idx * scale_stride[1] +
+ indices_ne1_idx * scale_stride[2] +
+ indices_ne2_idx * scale_stride[3] + group;
+ const int64_t output_offset = indices_ne0_idx * output_stride[1] +
+ indices_ne1_idx * output_stride[2] +
+ indices_ne2_idx * output_stride[3] +
+ group * QK8_0;
+
+ copy_in(input_offset);
+ LocalTensor<int8_t> input_local = input_queue.DeQue<int8_t>();
+ LocalTensor<half> cast_local = cast_queue.AllocTensor<half>();
+ LocalTensor<float> output_local = output_queue.AllocTensor<float>();
+
+ // TODO: cast more data at a time to speed this up.
+ Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0);
+ Cast(output_local, cast_local, RoundMode::CAST_NONE, QK8_0);
+
+ // Only the multiply needs to be computed per group.
+ half scale = scale_gm.GetValue(scale_offset);
+ Muls(output_local, output_local, (float)scale, QK8_0);
+
+ input_queue.FreeTensor(input_local);
+ cast_queue.FreeTensor(cast_local);
+ output_queue.EnQue(output_local);
+
+ copy_out(output_offset);
+ }
+
+ __aicore__ inline void calculate() {
+ for (int64_t i = ir; i < ir + dr; i++) {
+ for (int64_t j = 0; j < group_size_in_row; j++) {
+ calculate_group(i, j);
+ }
+ }
+ }
+
+ private:
+ int64_t input_ne[4];
+ size_t input_stride[4];
+
+ int64_t scale_ne[4];
+ size_t scale_stride[4];
+
+ int64_t indices_ne[4];
+ size_t indices_stride[4];
+
+ int64_t output_ne[4];
+ size_t output_stride[4];
+
+ int64_t ir;
+ int64_t dr;
+
+ int64_t group_size_in_row;
+
+ TPipe pipe;
+ GlobalTensor<int8_t> input_gm;
+ GlobalTensor<half> scale_gm;
+ GlobalTensor<int32_t> indices_gm;
+ GlobalTensor<float> output_gm;
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+ TQue<QuePosition::VECIN, BUFFER_NUM> cast_queue;
+ };
+
+ template <typename T>
+ __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+ auto gm_ptr = (__gm__ uint8_t *)gm;
+ auto ub_ptr = (uint8_t *)(ub);
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+ *ub_ptr = *gm_ptr;
+ }
+ }
+
+ extern "C" __global__ __aicore__ void ascendc_get_row_q8_0(
+ GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
+ GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm,
+ GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
+ int64_t input_ne_ub[4];
+ int64_t indices_ne_ub[4];
+ size_t indices_nb_ub[4];
+ int64_t output_ne_ub[4];
+ size_t output_nb_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
+ copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+ GET_ROW_Q8_0 op;
+ op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub,
+ indices_nb_ub, output_ne_ub, output_nb_ub);
+ op.calculate();
+ }
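
Note: every kernel in this series splits its work across AI cores with the same remainder-aware scheme seen in init() above: each core gets n / num_cores items and the first n % num_cores cores take one extra. A small sketch of that partition (illustrative names, not part of the patch):

    #include <cstdint>

    // Compute the [start, start + count) range of work items owned by one core,
    // matching the dr/ir computation in init().
    inline void split_work(int64_t n, int64_t num_cores, int64_t core_idx,
                           int64_t &start, int64_t &count) {
        count = n / num_cores;
        const int64_t tails = n % num_cores;
        if (core_idx < tails) {
            count += 1;                       // this core takes one extra item
            start = count * core_idx;
        } else {
            start = count * core_idx + tails; // skip the extra items of earlier cores
        }
    }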
ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp ADDED
@@ -0,0 +1,208 @@
+ #include "kernel_operator.h"
+
+ using namespace AscendC;
+
+ #define BUFFER_NUM 2
+ #define QK8_0 32
+
+ class QUANTIZE_F16_Q8_0 {
+ public:
+ __aicore__ inline QUANTIZE_F16_Q8_0() {}
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR output,
+ int64_t *input_ne_ub, size_t *input_nb_ub,
+ int64_t *output_ne_ub) {
+ int64_t op_block_num = GetBlockNum();
+ int64_t op_block_idx = GetBlockIdx();
+
+ for (int i = 0; i < 4; i++) {
+ input_ne[i] = input_ne_ub[i];
+ input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+
+ output_ne[i] = output_ne_ub[i];
+ }
+
+ output_stride[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
+ }
+
+ scale_ne = input_ne;
+ scale_stride[0] = 1;
+ scale_stride[1] = input_ne[0] / QK8_0;
+ for (int i = 2; i < 4; i++) {
+ scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+ }
+
+ // split input tensor by rows.
+ uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
+ dr = nr / op_block_num;
+
+ uint64_t tails = nr % op_block_num;
+ if (op_block_idx < tails) {
+ dr += 1;
+ ir = dr * op_block_idx;
+ } else {
+ ir = dr * op_block_idx + tails;
+ }
+
+ group_size_in_row = scale_stride[1];
+ int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] *
+ output_ne[3] * sizeof(uint8_t);
+
+ input_gm.SetGlobalBuffer((__gm__ half *)input);
+ output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
+ scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size + ir *
+ group_size_in_row *
+ sizeof(half)));
+
+ pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(half));
+ pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
+ pipe.InitBuffer(work_queue, 1, 32);
+ pipe.InitBuffer(max_queue, 1, 32);
+ pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float));
+ pipe.InitBuffer(scale_queue, 1, 32);
+ pipe.InitBuffer(cast_queue, 1, QK8_0 * sizeof(float));
+ }
+
+ __aicore__ inline void copy_in(uint32_t offset) {
+ LocalTensor<half> input_local = input_queue.AllocTensor<half>();
+ DataCopy(input_local, input_gm[offset], QK8_0);
+ input_queue.EnQue(input_local);
+ }
+
+ __aicore__ inline void copy_out(uint32_t offset) {
+ LocalTensor<int8_t> output_local = output_queue.DeQue<int8_t>();
+ DataCopy(output_gm[offset], output_local, QK8_0);
+ output_queue.FreeTensor(output_local);
+ }
+
+ __aicore__ inline half calculate_group(int64_t row, int64_t group) {
+ const int64_t i3 = row / (input_ne[1] * input_ne[2]);
+ const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
+ const int64_t i1 =
+ row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
+
+ const int64_t input_offset = i1 * input_stride[1] +
+ i2 * input_stride[2] +
+ i3 * input_stride[3] + QK8_0 * group;
+
+ const int64_t output_offset = i1 * output_stride[1] +
+ i2 * output_stride[2] +
+ i3 * output_stride[3] + QK8_0 * group;
+
+ copy_in(input_offset);
+ LocalTensor<half> input_local = input_queue.DeQue<half>();
+ LocalTensor<int8_t> output_local = output_queue.AllocTensor<int8_t>();
+ LocalTensor<float> work_local = work_queue.AllocTensor<float>();
+ LocalTensor<float> abs_local = abs_queue.AllocTensor<float>();
+ LocalTensor<float> max_local = max_queue.AllocTensor<float>();
+ LocalTensor<float> cast_local = cast_queue.AllocTensor<float>();
+
+ Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0);
+ Abs(abs_local, cast_local, QK8_0);
+ ReduceMax(max_local, abs_local, work_local, QK8_0);
+
+ pipe_barrier(PIPE_ALL);
+ float d = max_local.GetValue(0);
+ d = d / ((1 << 7) - 1);
+ if (d != 0) {
+ Muls(cast_local, cast_local, 1.0f / d, QK8_0);
+ }
+
+ Cast(cast_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
+ Cast(input_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
+ Cast(output_local, input_local, RoundMode::CAST_ROUND, QK8_0);
+ output_queue.EnQue(output_local);
+ copy_out(output_offset);
+
+ input_queue.FreeTensor(input_local);
+ work_queue.FreeTensor(work_local);
+ abs_queue.FreeTensor(abs_local);
+ max_queue.FreeTensor(max_local);
+ cast_queue.FreeTensor(cast_local);
+ return (half)d;
+ }
+
+ __aicore__ inline void calculate() {
+ LocalTensor<half> scale_local = scale_queue.AllocTensor<half>();
+ uint32_t scale_local_offset = 0;
+ uint32_t scale_global_offset = 0;
+ for (int64_t i = ir; i < ir + dr; i++) {
+ for (int64_t j = 0; j < group_size_in_row; j++) {
+ half scale = calculate_group(i, j);
+ scale_local.SetValue(scale_local_offset++, scale);
+ if (scale_local_offset == 16) {
+ scale_local_offset = 0;
+ // TODO: OPTIMIZE ME
+ pipe_barrier(PIPE_ALL);
+ DataCopy(scale_gm[scale_global_offset], scale_local, 16);
+ pipe_barrier(PIPE_ALL);
+ scale_global_offset += 16;
+ }
+ }
+ }
+
+ if (scale_local_offset != 0) {
+ pipe_barrier(PIPE_ALL);
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = scale_local_offset * sizeof(half);
+ DataCopyPad(scale_gm[scale_global_offset], scale_local,
+ dataCopyParams);
+ pipe_barrier(PIPE_ALL);
+ }
+ }
+
+ private:
+ int64_t input_ne[4];
+ size_t input_stride[4];
+
+ int64_t *scale_ne;
+ size_t scale_stride[4];
+
+ int64_t output_ne[4];
+ size_t output_stride[4];
+
+ int64_t group_size_in_row;
+
+ int64_t ir;
+ int64_t dr;
+
+ TPipe pipe;
+ GlobalTensor<half> input_gm;
+ GlobalTensor<half> scale_gm;
+ GlobalTensor<int8_t> output_gm;
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+ TQue<QuePosition::VECIN, 1> work_queue;
+ TQue<QuePosition::VECOUT, 1> max_queue;
+ TQue<QuePosition::VECIN, 1> abs_queue;
+ TQue<QuePosition::VECOUT, 1> scale_queue;
+ TQue<QuePosition::VECOUT, 1> cast_queue;
+
+ };
+
+ template <typename T>
+ __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+ auto gm_ptr = (__gm__ uint8_t *)gm;
+ auto ub_ptr = (uint8_t *)(ub);
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+ *ub_ptr = *gm_ptr;
+ }
+ }
+
+ extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0(
+ GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
+ GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t output_ne_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+
+ QUANTIZE_F16_Q8_0 op;
+ op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
+ op.calculate();
+ }
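
Note: the quantization in calculate_group() above is the usual Q8_0 rule: take the maximum absolute value of the 32-element group, derive the scale d = amax / 127, and round each element of x / d to an 8-bit integer. A minimal scalar reference sketch (illustrative names, no AscendC types, not part of the patch):

    #include <cmath>
    #include <cstdint>

    constexpr int QK8_0 = 32;

    // Quantize one group of 32 floats to int8 and return the group scale.
    float quantize_group_q8_0(const float *x, int8_t *q) {
        float amax = 0.0f;
        for (int i = 0; i < QK8_0; ++i) {
            amax = std::fmax(amax, std::fabs(x[i]));
        }
        const float d  = amax / ((1 << 7) - 1);          // 127
        const float id = (d != 0.0f) ? 1.0f / d : 0.0f;  // guard an all-zero group
        for (int i = 0; i < QK8_0; ++i) {
            q[i] = (int8_t)std::lround(x[i] * id);
        }
        return d; // stored as half after the quantized data
    }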
ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp ADDED
@@ -0,0 +1,206 @@
+ #include "kernel_operator.h"
+
+ using namespace AscendC;
+
+ #define BUFFER_NUM 2
+ #define QK8_0 32
+
+ class QUANTIZE_F32_Q8_0 {
+ public:
+ __aicore__ inline QUANTIZE_F32_Q8_0() {}
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR output,
+ int64_t *input_ne_ub, size_t *input_nb_ub,
+ int64_t *output_ne_ub) {
+ int64_t op_block_num = GetBlockNum();
+ int64_t op_block_idx = GetBlockIdx();
+
+ for (int i = 0; i < 4; i++) {
+ input_ne[i] = input_ne_ub[i];
+ input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+
+ output_ne[i] = output_ne_ub[i];
+ }
+
+ output_stride[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
+ }
+
+ scale_ne = input_ne;
+ scale_stride[0] = 1;
+ scale_stride[1] = input_ne[0] / QK8_0;
+ for (int i = 2; i < 4; i++) {
+ scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+ }
+
+ // split input tensor by rows.
+ uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
+ dr = nr / op_block_num;
+
+ uint64_t tails = nr % op_block_num;
+ if (op_block_idx < tails) {
+ dr += 1;
+ ir = dr * op_block_idx;
+ } else {
+ ir = dr * op_block_idx + tails;
+ }
+
+ group_size_in_row = scale_stride[1];
+ int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] *
+ output_ne[3] * sizeof(uint8_t);
+
+ input_gm.SetGlobalBuffer((__gm__ float *)input);
+ output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
+ scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size +
+ ir * group_size_in_row *
+ sizeof(half)));
+
+ pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(float));
+ pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
+ pipe.InitBuffer(work_queue, 1, 32);
+ pipe.InitBuffer(max_queue, 1, 32);
+ pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float));
+ pipe.InitBuffer(cast_queue, 1, QK8_0 * sizeof(half));
+ pipe.InitBuffer(scale_queue, 1, 32);
+ }
+
+ __aicore__ inline void copy_in(uint32_t offset) {
+ LocalTensor<float> input_local = input_queue.AllocTensor<float>();
+ DataCopy(input_local, input_gm[offset], QK8_0);
+ input_queue.EnQue(input_local);
+ }
+
+ __aicore__ inline void copy_out(uint32_t offset) {
+ LocalTensor<int8_t> output_local = output_queue.DeQue<int8_t>();
+ DataCopy(output_gm[offset], output_local, QK8_0);
+ output_queue.FreeTensor(output_local);
+ }
+
+ __aicore__ inline half calculate_group(int64_t row, int64_t group) {
+ const int64_t i3 = row / (input_ne[1] * input_ne[2]);
+ const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
+ const int64_t i1 =
+ row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
+
+ const int64_t input_offset = i1 * input_stride[1] +
+ i2 * input_stride[2] +
+ i3 * input_stride[3] + QK8_0 * group;
+
+ const int64_t output_offset = i1 * output_stride[1] +
+ i2 * output_stride[2] +
+ i3 * output_stride[3] + QK8_0 * group;
+
+ copy_in(input_offset);
+ LocalTensor<float> input_local = input_queue.DeQue<float>();
+ LocalTensor<int8_t> output_local = output_queue.AllocTensor<int8_t>();
+ LocalTensor<float> work_local = work_queue.AllocTensor<float>();
+ LocalTensor<float> abs_local = abs_queue.AllocTensor<float>();
+ LocalTensor<float> max_local = max_queue.AllocTensor<float>();
+ LocalTensor<half> cast_local = cast_queue.AllocTensor<half>();
+
+ Abs(abs_local, input_local, QK8_0);
+ ReduceMax(max_local, abs_local, work_local, QK8_0);
+ pipe_barrier(PIPE_ALL);
+ float d = max_local.GetValue(0);
+ d = d / ((1 << 7) - 1);
+ if (d != 0) {
+ Muls(input_local, input_local, 1.0f / d, QK8_0);
+ }
+
+ Cast(input_local, input_local, RoundMode::CAST_ROUND, QK8_0);
+ Cast(cast_local, input_local, RoundMode::CAST_ROUND, QK8_0);
+ Cast(output_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
+ output_queue.EnQue(output_local);
+ copy_out(output_offset);
+
+ input_queue.FreeTensor(input_local);
+ work_queue.FreeTensor(work_local);
+ abs_queue.FreeTensor(abs_local);
+ max_queue.FreeTensor(max_local);
+ cast_queue.FreeTensor(cast_local);
+
+ return (half)d;
+ }
+
+ __aicore__ inline void calculate() {
+ LocalTensor<half> scale_local = scale_queue.AllocTensor<half>();
+ uint32_t scale_local_offset = 0;
+ uint32_t scale_global_offset = 0;
+ for (int64_t i = ir; i < ir + dr; i++) {
+ for (int64_t j = 0; j < group_size_in_row; j++) {
+ half scale = calculate_group(i, j);
+ scale_local.SetValue(scale_local_offset++, scale);
+ if (scale_local_offset == 16) {
+ scale_local_offset = 0;
+ // TODO: OPTIMIZE ME
+ pipe_barrier(PIPE_ALL);
+ DataCopy(scale_gm[scale_global_offset], scale_local, 16);
+ pipe_barrier(PIPE_ALL);
+ scale_global_offset += 16;
+ }
+ }
+ }
+
+ if (scale_local_offset != 0) {
+ pipe_barrier(PIPE_ALL);
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = scale_local_offset * sizeof(half);
+ DataCopyPad(scale_gm[scale_global_offset], scale_local,
+ dataCopyParams);
+ pipe_barrier(PIPE_ALL);
+ }
+ }
+
+ private:
+ int64_t input_ne[4];
+ size_t input_stride[4];
+
+ int64_t *scale_ne;
+ size_t scale_stride[4];
+
+ int64_t output_ne[4];
+ size_t output_stride[4];
+
+ int64_t group_size_in_row;
+
+ int64_t ir;
+ int64_t dr;
+
+ TPipe pipe;
+ GlobalTensor<float> input_gm;
+ GlobalTensor<half> scale_gm;
+ GlobalTensor<int8_t> output_gm;
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+ TQue<QuePosition::VECIN, 1> work_queue;
+ TQue<QuePosition::VECOUT, 1> max_queue;
+ TQue<QuePosition::VECIN, 1> abs_queue;
+ TQue<QuePosition::VECIN, 1> cast_queue;
+ TQue<QuePosition::VECOUT, 1> scale_queue;
+ };
+
+ template <typename T>
+ __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+ auto gm_ptr = (__gm__ uint8_t *)gm;
+ auto ub_ptr = (uint8_t *)(ub);
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+ *ub_ptr = *gm_ptr;
+ }
+ }
+
+ extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0(
+ GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
+ GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t output_ne_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+
+ QUANTIZE_F32_Q8_0 op;
+ op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
+ op.calculate();
+ }
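
Note: both Q8_0 quantize kernels write the int8 payload at the front of the output buffer and append one half-precision scale per group after it; each core starts writing scales at its own offset (ir * group_size_in_row). A small address-arithmetic sketch under that layout (illustrative names, uint16_t standing in for half, not part of the patch):

    #include <cstddef>
    #include <cstdint>

    // For n_elements values quantized in groups of 32, the scale of global
    // group g lives right after the int8 data.
    inline size_t q8_0_scale_byte_offset(size_t n_elements, size_t group_idx) {
        const size_t data_bytes = n_elements * sizeof(int8_t); // one byte per value
        return data_bytes + group_idx * sizeof(uint16_t);      // scales stored as half
    }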
ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp ADDED
@@ -0,0 +1,278 @@
+ #include "kernel_operator.h"
+
+ using namespace AscendC;
+
+ #define BUFFER_NUM 2
+ #define Group_Size 32
+
+ template <typename SRC_T>
+ class QUANTIZE_FLOAT_TO_Q4_0 {
+ public:
+ __aicore__ inline QUANTIZE_FLOAT_TO_Q4_0() {}
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR output,
+ int64_t *input_ne_ub, size_t *input_nb_ub,
+ int64_t *output_ne_ub) {
+ // TODO: fix test_case CPY(type_src=f16,type_dst=q4_0,ne=[256,4,4,4],
+ // permute=[0,0,0,0]):
+ // [CPY] NMSE = 0.000008343 > 0.000001000 FAIL
+ int64_t op_block_num = GetBlockNum();
+ int64_t op_block_idx = GetBlockIdx();
+
+ // input stride in data elements
+ for (int i = 0; i < 4; i++) {
+ input_ne[i] = input_ne_ub[i];
+ input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+ output_ne[i] = output_ne_ub[i];
+ }
+
+ // output stride in data elements
+ output_stride[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
+ }
+
+ // scales are stored one by one after the data: [group1_scale, group2_scale, ...]
+ scale_ne = input_ne;
+ scale_stride[0] = 1;
+ scale_stride[1] = input_ne[0] / Group_Size;
+ for (int i = 2; i < 4; i++) {
+ scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+ }
+
+ // split input tensor by rows.
+ uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
+ dr = nr / op_block_num;
+
+ uint64_t tails = nr % op_block_num;
+ if (op_block_idx < tails) {
+ dr += 1;
+ ir = dr * op_block_idx;
+ } else {
+ ir = dr * op_block_idx + tails;
+ }
+
+ group_size_in_row = scale_stride[1];
+ int64_t scale_offset = output_ne[0] * output_ne[1] * output_ne[2] *
+ output_ne[3] * sizeof(uint8_t) / 2;
+
+ input_gm.SetGlobalBuffer((__gm__ SRC_T *)input);
+ output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
+ scale_gm.SetGlobalBuffer((__gm__ half *)(output + scale_offset + ir *
+ group_size_in_row *
+ sizeof(half)));
+
+ pipe.InitBuffer(input_queue, BUFFER_NUM, Group_Size * sizeof(SRC_T));
+ pipe.InitBuffer(output_queue, BUFFER_NUM,
+ Group_Size * sizeof(int8_t) / 2);
+ pipe.InitBuffer(cast_queue, 1, Group_Size * sizeof(float));
+ pipe.InitBuffer(work_queue, 1, Group_Size * sizeof(float));
+ pipe.InitBuffer(max_queue, 1, Group_Size * sizeof(float));
+ pipe.InitBuffer(min_queue, 1, Group_Size * sizeof(float));
+ pipe.InitBuffer(scale_queue, 1, Group_Size / 2 * sizeof(half));
+ pipe.InitBuffer(int8_queue, 1, Group_Size * sizeof(int8_t));
+ pipe.InitBuffer(half_queue, 1, Group_Size * sizeof(half));
+ }
+
+ __aicore__ inline void copy_in(uint32_t offset) {
+ LocalTensor<SRC_T> input_local = input_queue.AllocTensor<SRC_T>();
+ DataCopy(input_local, input_gm[offset], Group_Size);
+ input_queue.EnQue(input_local);
+ }
+
+ __aicore__ inline void copy_out(uint32_t offset) {
+ // Reinterpret Group_Size (32) int4b_t values as Group_Size / 2 int8_t
+ // values, and use DataCopyPad to avoid the 32-byte alignment requirement.
+ LocalTensor<int4b_t> output_local = output_queue.DeQue<int4b_t>();
+ LocalTensor<int8_t> output_int8_local =
+ output_local.ReinterpretCast<int8_t>();
+
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = Group_Size / 2 * sizeof(int8_t);
+ DataCopyPad(output_gm[offset], output_int8_local, dataCopyParams);
+
+ output_queue.FreeTensor(output_local);
+ }
+
+ __aicore__ inline void input_to_cast(LocalTensor<float> cast_local,
+ LocalTensor<float> input_local) {
+ DataCopy(cast_local, input_local, Group_Size);
+ }
+
+ __aicore__ inline void input_to_cast(LocalTensor<float> cast_local,
+ LocalTensor<half> input_local) {
+ Cast(cast_local, input_local, RoundMode::CAST_NONE, Group_Size);
+ }
+
+ __aicore__ inline half calculate_group(int64_t row, int64_t group) {
+ const int64_t i3 = row / (input_ne[1] * input_ne[2]);
+ const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
+ const int64_t i1 =
+ row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
+
+ const int64_t input_offset = i1 * input_stride[1] +
+ i2 * input_stride[2] +
+ i3 * input_stride[3] + Group_Size * group;
+
+ // output_offset is a stride into output_gm, whose datatype is int8_t;
+ // dividing by 2 is needed because the payload is int4b_t.
+ const int64_t output_offset = (i1 * output_stride[1] +
+ i2 * output_stride[2] +
+ i3 * output_stride[3] +
+ Group_Size * group) / 2;
+ copy_in(input_offset);
+
+ LocalTensor<SRC_T> input_local = input_queue.DeQue<SRC_T>();
+ LocalTensor<int4b_t> output_local = output_queue.AllocTensor<int4b_t>();
+ LocalTensor<float> cast_local = cast_queue.AllocTensor<float>();
+ LocalTensor<float> work_local = work_queue.AllocTensor<float>();
+ LocalTensor<float> max_local = max_queue.AllocTensor<float>();
+ LocalTensor<float> min_local = min_queue.AllocTensor<float>();
+ LocalTensor<int8_t> int8_local = int8_queue.AllocTensor<int8_t>();
+ LocalTensor<half> half_local = half_queue.AllocTensor<half>();
+
+ input_to_cast(cast_local, input_local);
+
+ ReduceMax(max_local, cast_local, work_local, Group_Size);
+ ReduceMin(min_local, cast_local, work_local, Group_Size);
+ const float max_value = max_local.GetValue(0);
+ const float min_value = min_local.GetValue(0);
+ float d = max_value;
+ if (min_value < 0 && (-1 * min_value) > max_value) {
+ d = min_value;
+ }
+
+ d = d / (-8);
+ if (d != 0) {
+ Muls(cast_local, cast_local, 1.0f / d, Group_Size);
+ }
+
+ // range: [-8,8] -> [0.5,16.5] -> [0,16] -> [0,15] -> [-8,7]
+ float scalar = 8.5f;
+ Adds(cast_local, cast_local, scalar, Group_Size);
+ Cast(cast_local, cast_local, RoundMode::CAST_FLOOR, Group_Size);
+ scalar = 15.0f;
+ Mins(cast_local, cast_local, scalar, Group_Size);
+ scalar = -8.0f;
+ Adds(cast_local, cast_local, scalar, Group_Size);
+
+ // float -> half -> int4b
+ Cast(half_local, cast_local, RoundMode::CAST_NONE, Group_Size);
+ Cast(output_local, half_local, RoundMode::CAST_NONE, Group_Size);
+
+ output_queue.EnQue(output_local);
+ copy_out(output_offset);
+
+ input_queue.FreeTensor(input_local);
+ work_queue.FreeTensor(work_local);
+ max_queue.FreeTensor(max_local);
+ min_queue.FreeTensor(min_local);
+ int8_queue.FreeTensor(int8_local);
+ half_queue.FreeTensor(half_local);
+ cast_queue.FreeTensor(cast_local);
+ return (half)d;
+ }
+
+ __aicore__ inline void calculate() {
+ LocalTensor<half> scale_local = scale_queue.AllocTensor<half>();
+ uint32_t scale_local_offset = 0;
+ uint32_t scale_global_offset = 0;
+ for (int64_t i = ir; i < ir + dr; i++) {
+ for (int64_t j = 0; j < group_size_in_row; j++) {
+ half scale = calculate_group(i, j);
+ scale_local.SetValue(scale_local_offset++, scale);
+ // Copy Group_Size / 2 scales at a time.
+ if (scale_local_offset == Group_Size / 2) {
+ scale_local_offset = 0;
+ // TODO: OPTIMIZE ME
+ pipe_barrier(PIPE_ALL);
+ DataCopy(scale_gm[scale_global_offset], scale_local,
+ Group_Size / 2);
+ pipe_barrier(PIPE_ALL);
+ scale_global_offset += Group_Size / 2;
+ }
+ }
+ }
+
+ if (scale_local_offset != 0) {
+ pipe_barrier(PIPE_ALL);
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = scale_local_offset * sizeof(half);
+ DataCopyPad(scale_gm[scale_global_offset], scale_local,
+ dataCopyParams);
+ pipe_barrier(PIPE_ALL);
+ }
+ scale_queue.FreeTensor(scale_local);
+ }
+
+ private:
+ int64_t input_ne[4];
+ size_t input_stride[4];
+
+ int64_t *scale_ne;
+ size_t scale_stride[4];
+
+ int64_t output_ne[4];
+ size_t output_stride[4];
+
+ int64_t group_size_in_row;
+
+ int64_t ir;
+ int64_t dr;
+
+ TPipe pipe;
+ GlobalTensor<SRC_T> input_gm;
+ GlobalTensor<half> scale_gm;
+ GlobalTensor<int8_t> output_gm;
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+ TQue<QuePosition::VECIN, BUFFER_NUM> work_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> max_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> min_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> scale_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> cast_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> int8_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> half_queue;
+ };
+
+ template <typename T>
+ __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+ auto gm_ptr = (__gm__ uint8_t *)gm;
+ auto ub_ptr = (uint8_t *)(ub);
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+ *ub_ptr = *gm_ptr;
+ }
+ }
+
+ extern "C" __global__ __aicore__ void ascendc_quantize_f16_to_q4_0(
+ GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
+ GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t output_ne_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+
+ QUANTIZE_FLOAT_TO_Q4_0<half> op;
+ op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
+ op.calculate();
+ }
+
+ extern "C" __global__ __aicore__ void ascendc_quantize_f32_to_q4_0(
+ GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
+ GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t output_ne_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+
+ QUANTIZE_FLOAT_TO_Q4_0<float> op;
+ op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
+ op.calculate();
+ }
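
Note: the Q4_0 path above picks the group value with the largest magnitude, sets d = -that / 8, scales, then maps [-8, 8] into the nibble range via the +8.5 / floor / min(15) / -8 sequence shown in the comment. A scalar reference sketch of that mapping (illustrative names, nibble packing omitted, not part of the patch):

    #include <cmath>
    #include <cstdint>

    constexpr int GROUP_SIZE = 32;

    // Quantize one group of 32 floats to signed 4-bit values in [-8, 7]
    // and return the scale, mirroring calculate_group() above.
    float quantize_group_q4_0(const float *x, int8_t *q) {
        float max_v = x[0], min_v = x[0];
        for (int i = 1; i < GROUP_SIZE; ++i) {
            max_v = std::fmax(max_v, x[i]);
            min_v = std::fmin(min_v, x[i]);
        }
        float d = (min_v < 0 && -min_v > max_v) ? min_v : max_v; // value with largest magnitude
        d = d / -8.0f;
        const float id = (d != 0.0f) ? 1.0f / d : 0.0f;
        for (int i = 0; i < GROUP_SIZE; ++i) {
            // [-8, 8] -> [0.5, 16.5] -> floor -> clamp to [0, 15] -> [-8, 7]
            float v = std::floor(x[i] * id + 8.5f);
            v = std::fmin(v, 15.0f);
            q[i] = (int8_t)(v - 8.0f);
        }
        return d;
    }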