Spaces:
Sleeping
Sleeping
fix scratch size of softmax (llama/8642)
Browse files
ggml/src/ggml-sycl/softmax.cpp
CHANGED
|
@@ -152,7 +152,8 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
|
|
| 152 |
|
| 153 |
const sycl::range<3> block_dims(1, 1, nth);
|
| 154 |
const sycl::range<3> block_nums(1, 1, nrows_x);
|
| 155 |
-
const size_t
|
|
|
|
| 156 |
|
| 157 |
const uint32_t n_head_kv = nrows_x/nrows_y;
|
| 158 |
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
|
|
|
| 152 |
|
| 153 |
const sycl::range<3> block_dims(1, 1, nth);
|
| 154 |
const sycl::range<3> block_nums(1, 1, nrows_x);
|
| 155 |
+
const size_t n_val_tmp = nth / WARP_SIZE;
|
| 156 |
+
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
|
| 157 |
|
| 158 |
const uint32_t n_head_kv = nrows_x/nrows_y;
|
| 159 |
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|