qnixsynapse committed
Commit 8aaf0c8 · 1 Parent(s): 04b01d8

SYCL : SOFTMAX F16 mask support and other fixes (llama/11261)


Implemented ggml_sycl_op_soft_max() F16 src1 (mask) support, for which a #pragma message TODO had been left in place since #5021; a minimal host-side sketch of the idea follows the change list below.
To do this, the op had to be decoupled from ggml_sycl_op_flatten, which always treats src1 as fp32 (many OP functions depend on that assumption).

* SYCL: SOFTMAX F16 mask support and other fixes

* test-backend-ops: Add F16 mask test cases
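
A minimal, self-contained host-side sketch of the idea (illustration only: it uses a hand-rolled f16 stand-in instead of sycl::half and a plain reference soft max instead of the SYCL kernel; soft_max_row_ref and the f16 struct are hypothetical names, not ggml symbols). The mask element type is a template parameter and each mask value is widened to float before the slope is applied, which is what lets a single implementation serve both F32 and F16 masks:

// Illustration only (hypothetical helpers, not ggml-sycl code): templated mask type,
// converted to float per element, same shape as the updated kernel's
//     x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f)
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hand-rolled stand-in for sycl::half; subnormals are flushed to zero for brevity.
struct f16 {
    uint16_t bits;
    operator float() const {
        const uint32_t sign = uint32_t(bits & 0x8000u) << 16;
        const uint32_t exp  = (bits >> 10) & 0x1fu;
        const uint32_t mant =  bits        & 0x3ffu;
        uint32_t out;
        if (exp == 0) {
            out = sign;                                  // +-0 (and flushed subnormals)
        } else if (exp == 0x1fu) {
            out = sign | 0x7f800000u | (mant << 13);     // inf / NaN
        } else {
            out = sign | ((exp + 112u) << 23) | (mant << 13);
        }
        float f;
        std::memcpy(&f, &out, sizeof(f));
        return f;
    }
};

// Reference soft max over one row with an optional, possibly lower-precision mask.
template <typename T>
static void soft_max_row_ref(const float * x, const T * mask, float * dst,
                             int ncols, float scale, float slope) {
    float max_val = -INFINITY;
    for (int i = 0; i < ncols; ++i) {
        const float val = x[i]*scale + (mask ? slope*static_cast<float>(mask[i]) : 0.0f);
        dst[i]  = val;
        max_val = std::max(max_val, val);
    }
    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        dst[i] = std::exp(dst[i] - max_val);
        sum   += dst[i];
    }
    for (int i = 0; i < ncols; ++i) {
        dst[i] /= sum;
    }
}

int main() {
    const float x[4] = {0.1f, 0.2f, 0.3f, 0.4f};
    const f16   mask_f16[4] = {{0x0000}, {0x0000}, {0xfc00}, {0xfc00}};   // 0, 0, -inf, -inf
    const float mask_f32[4] = {0.0f, 0.0f, -INFINITY, -INFINITY};

    float out_f16[4], out_f32[4];
    soft_max_row_ref(x, mask_f16, out_f16, 4, 1.0f, 1.0f);   // T = f16
    soft_max_row_ref(x, mask_f32, out_f32, 4, 1.0f, 1.0f);   // T = float
    for (int i = 0; i < 4; ++i) {
        std::printf("%d: f16-mask %.4f  f32-mask %.4f\n", i, out_f16[i], out_f32[i]);
    }
    return 0;
}
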

ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
 }
 
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
-}
-
 static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
@@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
             ggml_sycl_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            ggml_sycl_soft_max(ctx, dst);
+            ggml_sycl_op_soft_max(ctx, dst);
             break;
         case GGML_OP_ROPE:
             ggml_sycl_rope(ctx, dst);
ggml/src/ggml-sycl/softmax.cpp CHANGED
@@ -1,7 +1,7 @@
-#include "norm.hpp"
+#include "softmax.hpp"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         slope = sycl::pow(base, float(exp));
     }
 
-    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
     float max_val = -INFINITY;
 
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
 
         vals[col] = val;
         max_val = sycl::max(max_val, val);
@@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         item_ct1.barrier(sycl::access::fence_space::local_space);
         max_val = buf[lane_id];
         for (size_t i = 1; i < nreduce; i += 1) {
-            max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+            max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
        }
        max_val = warp_reduce_max(max_val, item_ct1);
    }
@@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
     }
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
@@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
     });
 }
 
-static void soft_max_f32_sycl(const float * x, const float * mask,
+template<typename T>
+static void soft_max_f32_sycl(const float * x, const T * mask,
                               float * dst, const int ncols_x, const int nrows_x,
                               const int nrows_y, const float scale, const float max_bias,
                               queue_ptr stream, int device) {
@@ -223,22 +224,16 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
     }
 }
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream) {
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
+    const int64_t ne00 = dst->src[0]->ne[0];
+    const int64_t nrows_x = ggml_nrows(dst->src[0]);
+    const int64_t nrows_y = dst->src[0]->ne[1];
 
     float scale = 1.0f;
     float max_bias = 0.0f;
@@ -246,6 +241,21 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s
     memcpy(&scale, dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
-                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
+
+    ggml_sycl_set_device(ctx.device);
+    dpct::queue_ptr main_stream = ctx.stream();
+
+    if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
+        const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
+        soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
+                                      main_stream, ctx.device);
+    } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
+        const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
+        soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    } else {
+        /* mask unavailable */
+        soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    }
 }
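
The dispatch pattern used by the new ggml_sycl_op_soft_max() above, compressed into a toy, compilable example (hypothetical names throughout: mask_kind, launch_soft_max and soft_max_dispatch are not ggml symbols, and uint16_t stands in for sycl::half). A runtime type tag selects the template instantiation that matches the mask tensor's storage type, with a nullptr fallback when no mask is present:

// Toy dispatcher (hypothetical names, not the ggml API): a runtime type tag selects
// the template instantiation, mirroring how ggml_sycl_op_soft_max() branches on
// dst->src[1]->type before calling soft_max_f32_sycl<sycl::half> or <float>.
#include <cstdint>
#include <cstdio>

enum class mask_kind { none, f32, f16 };

// Stand-in for the real launch helper (which would enqueue the SYCL kernel).
template <typename T>
static void launch_soft_max(const float * /*x*/, const T * mask, int /*ncols*/) {
    std::printf("launched with %s mask\n",
                mask == nullptr ? "no" : (sizeof(T) == 2 ? "f16" : "f32"));
}

static void soft_max_dispatch(const float * x, const void * mask_data,
                              mask_kind kind, int ncols) {
    switch (kind) {
        case mask_kind::f16:
            // uint16_t stands in for sycl::half in this sketch
            launch_soft_max(x, static_cast<const uint16_t *>(mask_data), ncols);
            break;
        case mask_kind::f32:
            launch_soft_max(x, static_cast<const float *>(mask_data), ncols);
            break;
        default:  // mask unavailable
            launch_soft_max<float>(x, nullptr, ncols);
            break;
    }
}

int main() {
    const float    x[2]         = {0.0f, 1.0f};
    const uint16_t half_mask[2] = {0, 0};
    soft_max_dispatch(x, half_mask, mask_kind::f16, 2);
    soft_max_dispatch(x, nullptr,   mask_kind::none, 2);
    return 0;
}

Because the kernel guards the mask term with mask ? ... : 0.0f, the nullptr branch simply skips the mask addition, matching the /* mask unavailable */ case in the real code.
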
ggml/src/ggml-sycl/softmax.hpp CHANGED
@@ -15,10 +15,6 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream);
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
 
 #endif // GGML_SYCL_SOFTMAX_HPP