Spaces:
Running
Running
AidanBeltonS
Abhilash Majumder
committed on
Add support for soft_max ALiBi (llama/5639)
Browse files
* Add support for bias
* Update pre-processor
* rm commented code
* fix format
* fix CI
---------
Co-authored-by: Abhilash Majumder <[email protected]>
- ggml-sycl.cpp +165 -81
ggml-sycl.cpp
CHANGED
|
@@ -8126,23 +8126,51 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
|
|
| 8126 |
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
| 8127 |
}
|
| 8128 |
|
| 8129 |
-
|
| 8130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8131 |
const int tid = item_ct1.get_local_id(2);
|
| 8132 |
const int rowx = item_ct1.get_group(2);
|
| 8133 |
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
| 8134 |
|
| 8135 |
-
const int block_size = item_ct1.get_local_range(2);
|
| 8136 |
|
| 8137 |
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
| 8138 |
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
| 8139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8140 |
float max_val = -INFINITY;
|
| 8141 |
|
| 8142 |
-
for (int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8143 |
const int ix = rowx*ncols + col;
|
| 8144 |
const int iy = rowy*ncols + col;
|
| 8145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8146 |
}
|
| 8147 |
|
| 8148 |
// find the max value in the block
|
|
@@ -8151,30 +8179,12 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
|
|
| 8151 |
if (warp_id == 0) {
|
| 8152 |
buf[lane_id] = -INFINITY;
|
| 8153 |
}
|
| 8154 |
-
|
| 8155 |
-
DPCT1118:12: SYCL group functions and algorithms must be encountered in
|
| 8156 |
-
converged control flow. You may need to adjust the code.
|
| 8157 |
-
*/
|
| 8158 |
-
/*
|
| 8159 |
-
DPCT1065:60: Consider replacing sycl::nd_item::barrier() with
|
| 8160 |
-
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
|
| 8161 |
-
better performance if there is no access to global memory.
|
| 8162 |
-
*/
|
| 8163 |
-
item_ct1.barrier();
|
| 8164 |
|
| 8165 |
if (lane_id == 0) {
|
| 8166 |
buf[warp_id] = max_val;
|
| 8167 |
}
|
| 8168 |
-
|
| 8169 |
-
DPCT1118:13: SYCL group functions and algorithms must be encountered in
|
| 8170 |
-
converged control flow. You may need to adjust the code.
|
| 8171 |
-
*/
|
| 8172 |
-
/*
|
| 8173 |
-
DPCT1065:61: Consider replacing sycl::nd_item::barrier() with
|
| 8174 |
-
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
|
| 8175 |
-
better performance if there is no access to global memory.
|
| 8176 |
-
*/
|
| 8177 |
-
item_ct1.barrier();
|
| 8178 |
|
| 8179 |
max_val = buf[lane_id];
|
| 8180 |
max_val = warp_reduce_max(max_val, item_ct1);
|
|
@@ -8182,13 +8192,16 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
|
|
| 8182 |
|
| 8183 |
float tmp = 0.f;
|
| 8184 |
|
| 8185 |
-
|
| 8186 |
-
|
| 8187 |
-
const int
|
| 8188 |
-
|
| 8189 |
-
|
|
|
|
|
|
|
|
|
|
| 8190 |
tmp += val;
|
| 8191 |
-
|
| 8192 |
}
|
| 8193 |
|
| 8194 |
// find the sum of exps in the block
|
|
@@ -8197,40 +8210,29 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
|
|
| 8197 |
if (warp_id == 0) {
|
| 8198 |
buf[lane_id] = 0.f;
|
| 8199 |
}
|
| 8200 |
-
|
| 8201 |
-
DPCT1118:14: SYCL group functions and algorithms must be encountered in
|
| 8202 |
-
converged control flow. You may need to adjust the code.
|
| 8203 |
-
*/
|
| 8204 |
-
/*
|
| 8205 |
-
DPCT1065:62: Consider replacing sycl::nd_item::barrier() with
|
| 8206 |
-
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
|
| 8207 |
-
better performance if there is no access to global memory.
|
| 8208 |
-
*/
|
| 8209 |
-
item_ct1.barrier();
|
| 8210 |
|
| 8211 |
if (lane_id == 0) {
|
| 8212 |
buf[warp_id] = tmp;
|
| 8213 |
}
|
| 8214 |
-
|
| 8215 |
-
DPCT1118:15: SYCL group functions and algorithms must be encountered in
|
| 8216 |
-
converged control flow. You may need to adjust the code.
|
| 8217 |
-
*/
|
| 8218 |
-
/*
|
| 8219 |
-
DPCT1065:63: Consider replacing sycl::nd_item::barrier() with
|
| 8220 |
-
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
|
| 8221 |
-
better performance if there is no access to global memory.
|
| 8222 |
-
*/
|
| 8223 |
-
item_ct1.barrier();
|
| 8224 |
|
| 8225 |
tmp = buf[lane_id];
|
| 8226 |
tmp = warp_reduce_sum(tmp, item_ct1);
|
| 8227 |
}
|
| 8228 |
|
| 8229 |
-
const float
|
| 8230 |
|
| 8231 |
-
|
| 8232 |
-
|
| 8233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8234 |
}
|
| 8235 |
}
|
| 8236 |
|
|
@@ -10867,37 +10869,98 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
|
| 10867 |
});
|
| 10868 |
}
|
| 10869 |
|
| 10870 |
-
|
| 10871 |
-
|
| 10872 |
-
|
| 10873 |
-
|
| 10874 |
-
|
| 10875 |
-
while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
| 10876 |
-
const sycl::range<3> block_dims(1, 1, nth);
|
| 10877 |
-
const sycl::range<3> block_nums(1, 1, nrows_x);
|
| 10878 |
-
/*
|
| 10879 |
-
DPCT1049:46: The work-group size passed to the SYCL kernel may exceed the
|
| 10880 |
-
limit. To get the device limit, query info::device::max_work_group_size.
|
| 10881 |
-
Adjust the work-group size if needed.
|
| 10882 |
-
*/
|
| 10883 |
stream->submit([&](sycl::handler &cgh) {
|
| 10884 |
-
|
| 10885 |
-
DPCT1101:96: 'SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE' expression was
|
| 10886 |
-
replaced with a value. Modify the code to use the original expression,
|
| 10887 |
-
provided in comments, if it is correct.
|
| 10888 |
-
*/
|
| 10889 |
-
sycl::local_accessor<float, 1> buf_acc_ct1(
|
| 10890 |
-
sycl::range<1>(32 /*SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE*/), cgh);
|
| 10891 |
|
| 10892 |
cgh.parallel_for(
|
| 10893 |
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
| 10894 |
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
| 10895 |
-
soft_max_f32
|
| 10896 |
-
|
|
|
|
|
|
|
| 10897 |
});
|
| 10898 |
});
|
| 10899 |
}
|
| 10900 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10901 |
template <typename T>
|
| 10902 |
static void im2col_sycl(const float *x, T *dst, int IW, int IH,
|
| 10903 |
int OW, int OH, int KW, int KH, int IC,
|
|
@@ -12435,14 +12498,35 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
|
|
| 12435 |
|
| 12436 |
const int64_t ne00 = src0->ne[0];
|
| 12437 |
const int64_t nrows_x = ggml_nrows(src0);
|
| 12438 |
-
const int64_t nrows_y =
|
| 12439 |
|
| 12440 |
float scale = 1.0f;
|
| 12441 |
-
|
| 12442 |
|
| 12443 |
-
|
|
|
|
| 12444 |
|
| 12445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12446 |
}
|
| 12447 |
|
| 12448 |
inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1,
|
|
|
|
| 8126 |
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
| 8127 |
}
|
| 8128 |
|
| 8129 |
+
|
| 8130 |
+
template <bool vals_smem, int ncols_template, int block_size_template>
|
| 8131 |
+
static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
|
| 8132 |
+
const int nrows_y, const float scale, const float max_bias, const float m0,
|
| 8133 |
+
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
|
| 8134 |
+
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
| 8135 |
+
|
| 8136 |
const int tid = item_ct1.get_local_id(2);
|
| 8137 |
const int rowx = item_ct1.get_group(2);
|
| 8138 |
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
| 8139 |
|
| 8140 |
+
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
|
| 8141 |
|
| 8142 |
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
| 8143 |
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
| 8144 |
|
| 8145 |
+
float slope = 0.0f;
|
| 8146 |
+
|
| 8147 |
+
// ALiBi
|
| 8148 |
+
if (max_bias > 0.0f) {
|
| 8149 |
+
const uint32_t h = rowx/nrows_y; // head index
|
| 8150 |
+
|
| 8151 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 8152 |
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 8153 |
+
|
| 8154 |
+
slope = sycl::pow(base, float(exp));
|
| 8155 |
+
}
|
| 8156 |
+
|
| 8157 |
+
float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
|
| 8158 |
float max_val = -INFINITY;
|
| 8159 |
|
| 8160 |
+
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
| 8161 |
+
const int col = col0 + tid;
|
| 8162 |
+
|
| 8163 |
+
if (ncols_template == 0 && col >= ncols) {
|
| 8164 |
+
break;
|
| 8165 |
+
}
|
| 8166 |
+
|
| 8167 |
const int ix = rowx*ncols + col;
|
| 8168 |
const int iy = rowy*ncols + col;
|
| 8169 |
+
|
| 8170 |
+
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
|
| 8171 |
+
|
| 8172 |
+
vals[col] = val;
|
| 8173 |
+
max_val = sycl::max(max_val, val);
|
| 8174 |
}
|
| 8175 |
|
| 8176 |
// find the max value in the block
|
|
|
|
| 8179 |
if (warp_id == 0) {
|
| 8180 |
buf[lane_id] = -INFINITY;
|
| 8181 |
}
|
| 8182 |
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8183 |
|
| 8184 |
if (lane_id == 0) {
|
| 8185 |
buf[warp_id] = max_val;
|
| 8186 |
}
|
| 8187 |
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8188 |
|
| 8189 |
max_val = buf[lane_id];
|
| 8190 |
max_val = warp_reduce_max(max_val, item_ct1);
|
|
|
|
| 8192 |
|
| 8193 |
float tmp = 0.f;
|
| 8194 |
|
| 8195 |
+
#pragma unroll
|
| 8196 |
+
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
| 8197 |
+
const int col = col0 + tid;
|
| 8198 |
+
if (ncols_template == 0 && col >= ncols) {
|
| 8199 |
+
break;
|
| 8200 |
+
}
|
| 8201 |
+
|
| 8202 |
+
const float val = sycl::native::exp(vals[col] - max_val);
|
| 8203 |
tmp += val;
|
| 8204 |
+
vals[col] = val;
|
| 8205 |
}
|
| 8206 |
|
| 8207 |
// find the sum of exps in the block
|
|
|
|
| 8210 |
if (warp_id == 0) {
|
| 8211 |
buf[lane_id] = 0.f;
|
| 8212 |
}
|
| 8213 |
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8214 |
|
| 8215 |
if (lane_id == 0) {
|
| 8216 |
buf[warp_id] = tmp;
|
| 8217 |
}
|
| 8218 |
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8219 |
|
| 8220 |
tmp = buf[lane_id];
|
| 8221 |
tmp = warp_reduce_sum(tmp, item_ct1);
|
| 8222 |
}
|
| 8223 |
|
| 8224 |
+
const float inv_sum = 1.f / tmp;
|
| 8225 |
|
| 8226 |
+
#pragma unroll
|
| 8227 |
+
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
| 8228 |
+
const int col = col0 + tid;
|
| 8229 |
+
|
| 8230 |
+
if (ncols_template == 0 && col >= ncols) {
|
| 8231 |
+
return;
|
| 8232 |
+
}
|
| 8233 |
+
|
| 8234 |
+
const int idst = rowx*ncols + col;
|
| 8235 |
+
dst[idst] = vals[col] * inv_sum;
|
| 8236 |
}
|
| 8237 |
}
|
| 8238 |
|
|
|
|
| 10869 |
});
|
| 10870 |
}
|
| 10871 |
|
| 10872 |
+
template <bool vals_smem, int ncols_template, int block_size_template>
|
| 10873 |
+
static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
|
| 10874 |
+
const int nrows_y, const float scale, const float max_bias, const float m0,
|
| 10875 |
+
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
| 10876 |
+
const size_t n_local_scratch, dpct::queue_ptr stream) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10877 |
stream->submit([&](sycl::handler &cgh) {
|
| 10878 |
+
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10879 |
|
| 10880 |
cgh.parallel_for(
|
| 10881 |
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
| 10882 |
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
| 10883 |
+
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
|
| 10884 |
+
nrows_y, scale, max_bias, m0,
|
| 10885 |
+
m1, n_head_log2, item_ct1,
|
| 10886 |
+
local_buf_acc.get_pointer());
|
| 10887 |
});
|
| 10888 |
});
|
| 10889 |
}
|
| 10890 |
|
| 10891 |
+
static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
|
| 10892 |
+
float * dst, const int ncols_x, const int nrows_x,
|
| 10893 |
+
const int nrows_y, const float scale, const float max_bias,
|
| 10894 |
+
dpct::queue_ptr stream) {
|
| 10895 |
+
int nth = WARP_SIZE;
|
| 10896 |
+
while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
| 10897 |
+
const sycl::range<3> block_dims(1, 1, nth);
|
| 10898 |
+
const sycl::range<3> block_nums(1, 1, nrows_x);
|
| 10899 |
+
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
|
| 10900 |
+
static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
| 10901 |
+
|
| 10902 |
+
const uint32_t n_head_kv = nrows_x/nrows_y;
|
| 10903 |
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
| 10904 |
+
|
| 10905 |
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 10906 |
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 10907 |
+
|
| 10908 |
+
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
| 10909 |
+
if (n_local_scratch*sizeof(float) < local_mem_size) {
|
| 10910 |
+
switch (ncols_x) {
|
| 10911 |
+
case 32:
|
| 10912 |
+
soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 10913 |
+
max_bias, m0, m1, n_head_log2, block_nums,
|
| 10914 |
+
block_dims, n_local_scratch, stream);
|
| 10915 |
+
break;
|
| 10916 |
+
case 64:
|
| 10917 |
+
soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 10918 |
+
max_bias, m0, m1, n_head_log2, block_nums,
|
| 10919 |
+
block_dims, n_local_scratch, stream);
|
| 10920 |
+
break;
|
| 10921 |
+
case 128:
|
| 10922 |
+
soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 10923 |
+
max_bias, m0, m1, n_head_log2, block_nums,
|
| 10924 |
+
block_dims, n_local_scratch, stream);
|
| 10925 |
+
break;
|
| 10926 |
+
case 256:
|
| 10927 |
+
soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 10928 |
+
max_bias, m0, m1, n_head_log2, block_nums,
|
| 10929 |
+
block_dims, n_local_scratch, stream);
|
| 10930 |
+
break;
|
| 10931 |
+
case 512:
|
| 10932 |
+
soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 10933 |
+
max_bias, m0, m1, n_head_log2, block_nums,
|
| 10934 |
+
block_dims, n_local_scratch, stream);
|
| 10935 |
+
break;
|
| 10936 |
+
case 1024:
|
| 10937 |
+
soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 10938 |
+
max_bias, m0, m1, n_head_log2, block_nums,
|
| 10939 |
+
block_dims, n_local_scratch, stream);
|
| 10940 |
+
break;
|
| 10941 |
+
case 2048:
|
| 10942 |
+
soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 10943 |
+
max_bias, m0, m1, n_head_log2, block_nums,
|
| 10944 |
+
block_dims, n_local_scratch, stream);
|
| 10945 |
+
break;
|
| 10946 |
+
case 4096:
|
| 10947 |
+
soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 10948 |
+
max_bias, m0, m1, n_head_log2, block_nums,
|
| 10949 |
+
block_dims, n_local_scratch, stream);
|
| 10950 |
+
break;
|
| 10951 |
+
default:
|
| 10952 |
+
soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 10953 |
+
max_bias, m0, m1, n_head_log2, block_nums,
|
| 10954 |
+
block_dims, n_local_scratch, stream);
|
| 10955 |
+
break;
|
| 10956 |
+
}
|
| 10957 |
+
} else {
|
| 10958 |
+
soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 10959 |
+
max_bias, m0, m1, n_head_log2, block_nums,
|
| 10960 |
+
block_dims, WARP_SIZE, stream);
|
| 10961 |
+
}
|
| 10962 |
+
}
|
| 10963 |
+
|
| 10964 |
template <typename T>
|
| 10965 |
static void im2col_sycl(const float *x, T *dst, int IW, int IH,
|
| 10966 |
int OW, int OH, int KW, int KH, int IC,
|
|
|
|
| 12498 |
|
| 12499 |
const int64_t ne00 = src0->ne[0];
|
| 12500 |
const int64_t nrows_x = ggml_nrows(src0);
|
| 12501 |
+
const int64_t nrows_y = src0->ne[1];
|
| 12502 |
|
| 12503 |
float scale = 1.0f;
|
| 12504 |
+
float max_bias = 0.0f;
|
| 12505 |
|
| 12506 |
+
memcpy(&scale, dst->op_params + 0, sizeof(float));
|
| 12507 |
+
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
| 12508 |
|
| 12509 |
+
// positions tensor
|
| 12510 |
+
float * src2_dd = nullptr;
|
| 12511 |
+
sycl_pool_alloc<float> src2_f;
|
| 12512 |
+
|
| 12513 |
+
ggml_tensor * src2 = dst->src[2];
|
| 12514 |
+
const bool use_src2 = src2 != nullptr;
|
| 12515 |
+
|
| 12516 |
+
if (use_src2) {
|
| 12517 |
+
const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
|
| 12518 |
+
|
| 12519 |
+
if (src2_on_device) {
|
| 12520 |
+
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
|
| 12521 |
+
src2_dd = (float *) src2_extra->data_device[g_main_device];
|
| 12522 |
+
} else {
|
| 12523 |
+
src2_dd = src2_f.alloc(ggml_nelements(src2));
|
| 12524 |
+
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
|
| 12525 |
+
}
|
| 12526 |
+
}
|
| 12527 |
+
|
| 12528 |
+
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
|
| 12529 |
+
nrows_x, nrows_y, scale, max_bias, main_stream);
|
| 12530 |
}
|
| 12531 |
|
| 12532 |
inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1,
|