KevinLy commited on
Commit
6519fd2
·
1 Parent(s): 4eec44b

fix scratch size of softmax (llama/8642)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-sycl/softmax.cpp +2 -1
ggml/src/ggml-sycl/softmax.cpp CHANGED
@@ -152,7 +152,8 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
152
 
153
  const sycl::range<3> block_dims(1, 1, nth);
154
  const sycl::range<3> block_nums(1, 1, nrows_x);
155
- const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
 
156
 
157
  const uint32_t n_head_kv = nrows_x/nrows_y;
158
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
 
152
 
153
  const sycl::range<3> block_dims(1, 1, nth);
154
  const sycl::range<3> block_nums(1, 1, nrows_x);
155
+ const size_t n_val_tmp = nth / WARP_SIZE;
156
+ const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
157
 
158
  const uint32_t n_head_kv = nrows_x/nrows_y;
159
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));