AidanBeltonS and Abhilash Majumder committed
Commit 86d6a5e · unverified · 1 Parent(s): 430efc6

Add support for soft_max ALiBi (llama/5639)


* Add support for bias

* Update pre-processor

* rm commented code

* fix format

* fix CI

---------

Co-authored-by: Abhilash Majumder <[email protected]>
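For reference, the per-head ALiBi slope that the new kernel derives from max_bias follows the usual ALiBi recipe: the first n_head_log2 heads use base m0 = 2^(-max_bias/n_head_log2), the remaining heads use base m1 = 2^(-max_bias/2/n_head_log2) with odd exponents. A minimal host-side sketch of the same computation (the head count and max_bias value below are illustrative examples, not values from this commit):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Per-head ALiBi slope, matching the m0/m1/n_head_log2 math in the diff.
    static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        const float base = h < n_head_log2 ? m0 : m1;
        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

        return powf(base, (float) exp);
    }

    int main() {
        // Example: 8 heads, max_bias = 8.0f gives slopes 1/2, 1/4, ..., 1/256.
        for (uint32_t h = 0; h < 8; ++h) {
            printf("head %u: slope %g\n", h, alibi_slope(h, 8, 8.0f));
        }
    }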

Files changed (1)
  1. ggml-sycl.cpp +165 -81
ggml-sycl.cpp CHANGED
@@ -8126,23 +8126,51 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-static void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale,
-                         const sycl::nd_item<3> &item_ct1, float *buf) {
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+                         const int nrows_y, const float scale, const float max_bias, const float m0,
+                         const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
     const int tid = item_ct1.get_local_id(2);
     const int rowx = item_ct1.get_group(2);
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
 
-    const int block_size = item_ct1.get_local_range(2);
+    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
 
     const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
     const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
 
+    float slope = 0.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const uint32_t h = rowx/nrows_y; // head index
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = sycl::pow(base, float(exp));
+    }
+
+    float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
     float max_val = -INFINITY;
 
-    for (int col = tid; col < ncols; col += block_size) {
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        max_val = sycl::max(max_val, x[ix] * scale + (y ? y[iy] : 0.0f));
+
+        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+
+        vals[col] = val;
+        max_val = sycl::max(max_val, val);
     }
 
     // find the max value in the block
@@ -8151,30 +8179,12 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
         if (warp_id == 0) {
             buf[lane_id] = -INFINITY;
         }
-        /*
-        DPCT1118:12: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        /*
-        DPCT1065:60: Consider replacing sycl::nd_item::barrier() with
-        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-        better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
         if (lane_id == 0) {
             buf[warp_id] = max_val;
         }
-        /*
-        DPCT1118:13: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        /*
-        DPCT1065:61: Consider replacing sycl::nd_item::barrier() with
-        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-        better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
         max_val = buf[lane_id];
         max_val = warp_reduce_max(max_val, item_ct1);
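The warp_reduce_max/warp_reduce_sum calls above are ggml-sycl helpers defined elsewhere in this file. As a rough idea of what such a sub-group reduction does (a generic SYCL 2020 butterfly reduction, not ggml's exact implementation):

    #include <sycl/sycl.hpp>

    // Generic sub-group max reduction via XOR shuffles; assumes a 32-lane
    // sub-group, matching the [[intel::reqd_sub_group_size(32)]] launch below.
    static float sub_group_reduce_max(float v, const sycl::nd_item<3> & item) {
        auto sg = item.get_sub_group();
        for (int mask = 16; mask > 0; mask >>= 1) {
            v = sycl::max(v, sycl::permute_group_by_xor(sg, v, mask));
        }
        return v; // every lane now holds the sub-group maximum
    }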
@@ -8182,13 +8192,16 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
 
     float tmp = 0.f;
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-        const float val =
-            sycl::native::exp((x[ix] * scale + (y ? y[iy] : 0.0f)) - max_val);
+    #pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = sycl::native::exp(vals[col] - max_val);
         tmp += val;
-        dst[ix] = val;
+        vals[col] = val;
     }
 
     // find the sum of exps in the block
@@ -8197,40 +8210,29 @@ static void soft_max_f32(const float * x, const float * y, float * dst, const in
         if (warp_id == 0) {
             buf[lane_id] = 0.f;
         }
-        /*
-        DPCT1118:14: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        /*
-        DPCT1065:62: Consider replacing sycl::nd_item::barrier() with
-        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-        better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
         if (lane_id == 0) {
             buf[warp_id] = tmp;
         }
-        /*
-        DPCT1118:15: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        /*
-        DPCT1065:63: Consider replacing sycl::nd_item::barrier() with
-        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-        better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
         tmp = buf[lane_id];
         tmp = warp_reduce_sum(tmp, item_ct1);
     }
 
-    const float inv_tmp = 1.f / tmp;
+    const float inv_sum = 1.f / tmp;
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const int i = rowx*ncols + col;
-        dst[i] *= inv_tmp;
+    #pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
     }
 }
 
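Taken together, the three passes of soft_max_f32 compute, per row, softmax(x*scale + mask + slope*pos) with the usual max-subtraction for numerical stability. A scalar reference of what one work-group produces (illustrative only, not part of the commit):

    #include <algorithm>
    #include <cmath>

    // One row of the fused kernel: dst[c] = softmax over c of
    // x[c]*scale + mask[c] + slope*pos[c], with null mask/pos treated as zero.
    static void soft_max_row_ref(const float * x, const float * mask, const float * pos,
                                 float * dst, int ncols, float scale, float slope) {
        float max_val = -INFINITY;
        for (int c = 0; c < ncols; ++c) {
            const float v = x[c]*scale + (mask ? mask[c] : 0.0f) + (pos ? slope*pos[c] : 0.0f);
            dst[c] = v;                      // pass 1: fused bias, track the max
            max_val = std::max(max_val, v);
        }
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) {
            dst[c] = expf(dst[c] - max_val); // pass 2: shifted exponentials
            sum += dst[c];
        }
        const float inv_sum = 1.0f / sum;
        for (int c = 0; c < ncols; ++c) {
            dst[c] *= inv_sum;               // pass 3: normalize
        }
    }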
@@ -10867,37 +10869,98 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
         });
 }
 
-static void soft_max_f32_sycl(const float *x, const float *y, float *dst,
-                              const int ncols_x, const int nrows_x,
-                              const int nrows_y, const float scale,
-                              dpct::queue_ptr stream) {
-    int nth = WARP_SIZE;
-    while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
-    const sycl::range<3> block_dims(1, 1, nth);
-    const sycl::range<3> block_nums(1, 1, nrows_x);
-    /*
-    DPCT1049:46: The work-group size passed to the SYCL kernel may exceed the
-    limit. To get the device limit, query info::device::max_work_group_size.
-    Adjust the work-group size if needed.
-    */
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+                                   const int nrows_y, const float scale, const float max_bias, const float m0,
+                                   const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
+                                   const size_t n_local_scratch, dpct::queue_ptr stream) {
     stream->submit([&](sycl::handler &cgh) {
-        /*
-        DPCT1101:96: 'SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE' expression was
-        replaced with a value. Modify the code to use the original expression,
-        provided in comments, if it is correct.
-        */
-        sycl::local_accessor<float, 1> buf_acc_ct1(
-            sycl::range<1>(32 /*SYCL_SOFT_MAX_BLOCK_SIZE/WARP_SIZE*/), cgh);
+        sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
 
         cgh.parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                soft_max_f32(x, y, dst, ncols_x, nrows_y, scale, item_ct1,
-                             buf_acc_ct1.get_pointer());
+                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
+                                                                             nrows_y, scale, max_bias, m0,
+                                                                             m1, n_head_log2, item_ct1,
+                                                                             local_buf_acc.get_pointer());
             });
     });
 }
 
+static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
+                              float * dst, const int ncols_x, const int nrows_x,
+                              const int nrows_y, const float scale, const float max_bias,
+                              dpct::queue_ptr stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const sycl::range<3> block_dims(1, 1, nth);
+    const sycl::range<3> block_nums(1, 1, nrows_x);
+    const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
+    static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+    const uint32_t n_head_kv   = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
+    if (n_local_scratch*sizeof(float) < local_mem_size) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                     max_bias, m0, m1, n_head_log2, block_nums,
+                                                     block_dims, n_local_scratch, stream);
+                break;
+            case 64:
+                soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                     max_bias, m0, m1, n_head_log2, block_nums,
+                                                     block_dims, n_local_scratch, stream);
+                break;
+            case 128:
+                soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 256:
+                soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 512:
+                soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 1024:
+                soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            case 2048:
+                soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            case 4096:
+                soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            default:
+                soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                   max_bias, m0, m1, n_head_log2, block_nums,
+                                                   block_dims, n_local_scratch, stream);
+                break;
+        }
+    } else {
+        soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                            max_bias, m0, m1, n_head_log2, block_nums,
+                                            block_dims, WARP_SIZE, stream);
+    }
+}
+
 template <typename T>
 static void im2col_sycl(const float *x, T *dst, int IW, int IH,
                         int OW, int OH, int KW, int KH, int IC,
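A note on the scratch sizing in soft_max_f32_sycl above: the local buffer holds one WARP_SIZE-float reduction buffer followed by the padded row, which is why the kernel sets vals = buf + WARP_SIZE. A worked example of the arithmetic (WARP_SIZE == 32 is assumed, matching the sub-group size used here):

    #include <cstddef>

    // GGML_PAD(4096, 32) + 32 = 4128 floats = 16512 bytes of local memory.
    // If this exceeds the device's local_mem_size, the <false, 0, 0> fallback
    // keeps the row values in global memory (dst) instead.
    constexpr size_t warp_size       = 32;
    constexpr size_t ncols_x         = 4096; // example column count
    constexpr size_t padded          = (ncols_x + warp_size - 1) / warp_size * warp_size;
    constexpr size_t n_local_scratch = padded + warp_size;
    static_assert(n_local_scratch * sizeof(float) == 16512, "about 16.1 KiB");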
@@ -12435,14 +12498,35 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
+    const int64_t nrows_y = src0->ne[1];
 
     float scale = 1.0f;
-    memcpy(&scale, dst->op_params, sizeof(float));
+    float max_bias = 0.0f;
 
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    memcpy(&scale, dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    (void) dst;
+    // positions tensor
+    float * src2_dd = nullptr;
+    sycl_pool_alloc<float> src2_f;
+
+    ggml_tensor * src2 = dst->src[2];
+    const bool use_src2 = src2 != nullptr;
+
+    if (use_src2) {
+        const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
+
+        if (src2_on_device) {
+            ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
+            src2_dd = (float *) src2_extra->data_device[g_main_device];
+        } else {
+            src2_dd = src2_f.alloc(ggml_nelements(src2));
+            SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
+        }
+    }
+
+    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
+                      nrows_x, nrows_y, scale, max_bias, main_stream);
 }
 
 inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1,
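The positions tensor staged above (src2) is what turns the per-head slope into ALiBi's linear distance penalty: entry col holds the position of KV cell col, so slope*pos[col] in the kernel grows linearly with distance. A hypothetical host-side sketch of such a buffer (names are illustrative, not from this commit):

    #include <vector>

    // Positions for n_kv cells; the kernel adds slope * pos[col] per element.
    static std::vector<float> make_positions(int n_kv) {
        std::vector<float> pos(n_kv);
        for (int c = 0; c < n_kv; ++c) {
            pos[c] = (float) c; // monotonically increasing -> linear ALiBi bias
        }
        return pos;
    }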
 