Spaces:
Running
Running
Anton Mitkov
commited on
Commit
·
eed049f
1
Parent(s):
8076017
sycl: Remove not needed copy f16->f32 for dnnl mul mat (llama/14125)
Browse files
ggml/src/ggml-sycl/gemm.hpp
CHANGED
|
@@ -65,6 +65,9 @@ public:
|
|
| 65 |
|
| 66 |
dnnl::primitive_attr primitive_attr;
|
| 67 |
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
|
| 70 |
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
|
|
|
|
| 65 |
|
| 66 |
dnnl::primitive_attr primitive_attr;
|
| 67 |
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
| 68 |
+
#ifdef GGML_SYCL_F16
|
| 69 |
+
primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
|
| 70 |
+
#endif
|
| 71 |
|
| 72 |
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
|
| 73 |
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
|
ggml/src/ggml-sycl/ggml-sycl.cpp
CHANGED
|
@@ -2127,21 +2127,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
|
| 2127 |
const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
|
| 2128 |
? (const sycl::half *)src1->data + src1_padded_row_size
|
| 2129 |
: src1_as_f16.get();
|
| 2130 |
-
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
|
| 2131 |
|
| 2132 |
#if GGML_SYCL_DNNL
|
| 2133 |
if (!g_ggml_sycl_disable_dnn) {
|
| 2134 |
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
|
| 2135 |
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
| 2136 |
-
|
| 2137 |
-
scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
|
| 2138 |
-
" : converting dst to fp32");
|
| 2139 |
-
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
| 2140 |
-
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
|
| 2141 |
}
|
| 2142 |
else
|
| 2143 |
#endif
|
| 2144 |
{
|
|
|
|
|
|
|
| 2145 |
const sycl::half alpha_f16 = 1.0f;
|
| 2146 |
const sycl::half beta_f16 = 0.0f;
|
| 2147 |
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
|
|
|
| 2127 |
const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
|
| 2128 |
? (const sycl::half *)src1->data + src1_padded_row_size
|
| 2129 |
: src1_as_f16.get();
|
|
|
|
| 2130 |
|
| 2131 |
#if GGML_SYCL_DNNL
|
| 2132 |
if (!g_ggml_sycl_disable_dnn) {
|
| 2133 |
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
|
| 2134 |
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
| 2135 |
+
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2136 |
}
|
| 2137 |
else
|
| 2138 |
#endif
|
| 2139 |
{
|
| 2140 |
+
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
|
| 2141 |
+
|
| 2142 |
const sycl::half alpha_f16 = 1.0f;
|
| 2143 |
const sycl::half beta_f16 = 0.0f;
|
| 2144 |
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|