Anton Mitkov committed on
Commit
eed049f
·
1 Parent(s): 8076017

sycl: Remove unneeded f16->f32 copy for dnnl mul mat (llama/14125)

Browse files
ggml/src/ggml-sycl/gemm.hpp CHANGED
@@ -65,6 +65,9 @@ public:
65
 
66
  dnnl::primitive_attr primitive_attr;
67
  primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
 
 
 
68
 
69
  auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
70
  auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
 
65
 
66
  dnnl::primitive_attr primitive_attr;
67
  primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
68
+ #ifdef GGML_SYCL_F16
69
+ primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
70
+ #endif
71
 
72
  auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
73
  auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -2127,21 +2127,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
2127
  const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
2128
  ? (const sycl::half *)src1->data + src1_padded_row_size
2129
  : src1_as_f16.get();
2130
- ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
2131
 
2132
  #if GGML_SYCL_DNNL
2133
  if (!g_ggml_sycl_disable_dnn) {
2134
  DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
2135
  DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2136
- dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
2137
- scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
2138
- " : converting dst to fp32");
2139
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
2140
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
2141
  }
2142
  else
2143
  #endif
2144
  {
 
 
2145
  const sycl::half alpha_f16 = 1.0f;
2146
  const sycl::half beta_f16 = 0.0f;
2147
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
 
2127
  const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
2128
  ? (const sycl::half *)src1->data + src1_padded_row_size
2129
  : src1_as_f16.get();
 
2130
 
2131
  #if GGML_SYCL_DNNL
2132
  if (!g_ggml_sycl_disable_dnn) {
2133
  DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
2134
  DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2135
+ dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
 
 
 
 
2136
  }
2137
  else
2138
  #endif
2139
  {
2140
+ ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
2141
+
2142
  const sycl::half alpha_f16 = 1.0f;
2143
  const sycl::half beta_f16 = 0.0f;
2144
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(