KevinLy committed on
Commit
b4d8c3e
·
1 Parent(s): f1abcb4

Add oneDNN primitive support (llama/9091)

Browse files

* add onednn

* add sycl_f16

* add dnnl stream

* add engine map

* use dnnl for intel only

* use fp16fp16fp16

* update doc

ggml/src/CMakeLists.txt CHANGED
@@ -549,6 +549,13 @@ if (GGML_SYCL)
549
  file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
550
  list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
551
 
 
 
 
 
 
 
 
552
  if (WIN32)
553
  find_package(IntelSYCL REQUIRED)
554
  find_package(MKL REQUIRED)
@@ -561,6 +568,9 @@ if (GGML_SYCL)
561
  set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl)
562
  endif()
563
  endif()
 
 
 
564
  endif()
565
 
566
  if (GGML_RPC)
 
549
  file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
550
  list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
551
 
552
# Detect oneDNN (DNNL). It is only wired up for the Intel SYCL target.
find_package(DNNL)
message(STATUS "DNNL found: ${DNNL_FOUND}")
# Normalize the flag to a literal 0/1 so C++ can test it with
# `#if GGML_SYCL_DNNL` (DNNL_FOUND may expand to TRUE/FALSE, which
# would not be a valid preprocessor expression).
if (DNNL_FOUND AND GGML_SYCL_TARGET STREQUAL "INTEL")
    add_compile_definitions(GGML_SYCL_DNNL=1)
else()
    add_compile_definitions(GGML_SYCL_DNNL=0)
endif()
559
  if (WIN32)
560
  find_package(IntelSYCL REQUIRED)
561
  find_package(MKL REQUIRED)
 
568
  set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl)
569
  endif()
570
  endif()
571
# Link against oneDNN only when it was found and we target Intel.
# NOTE: test the variable by name — `if (${DNNL_FOUND} ...)` dereferences
# it first and misbehaves when the variable is unset or empty.
if (DNNL_FOUND AND GGML_SYCL_TARGET STREQUAL "INTEL")
    list(APPEND GGML_EXTRA_LIBS DNNL::dnnl)
endif()
574
  endif()
575
 
576
  if (GGML_RPC)
ggml/src/ggml-sycl.cpp CHANGED
@@ -38,6 +38,7 @@
38
 
39
  #include "ggml-sycl/backend.hpp"
40
  #include "ggml-sycl/presets.hpp"
 
41
 
42
  bool ggml_sycl_loaded(void);
43
  void ggml_sycl_free_data(struct ggml_tensor * tensor);
@@ -2482,6 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
2482
 
2483
  const sycl::half alpha_f16 = 1.0f;
2484
  const sycl::half beta_f16 = 0.0f;
 
2485
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
2486
  *stream, oneapi::mkl::transpose::trans,
2487
  oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
@@ -2491,6 +2493,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
2491
  dpct::library_data_t::real_half)));
2492
  const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
2493
  to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
 
 
 
 
 
 
 
2494
  }
2495
  else {
2496
  // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2513,13 +2522,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
2513
 
2514
  const float alpha = 1.0f;
2515
  const float beta = 0.0f;
2516
-
2517
  SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2518
  *stream, oneapi::mkl::transpose::trans,
2519
  oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2520
  dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
2521
  src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
2522
  dst_dd_i, ldc)));
 
 
 
 
 
2523
  }
2524
  (void) dst;
2525
  (void) src1_ddq_i;
 
38
 
39
  #include "ggml-sycl/backend.hpp"
40
  #include "ggml-sycl/presets.hpp"
41
+ #include "ggml-sycl/gemm.hpp"
42
 
43
  bool ggml_sycl_loaded(void);
44
  void ggml_sycl_free_data(struct ggml_tensor * tensor);
 
2483
 
2484
  const sycl::half alpha_f16 = 1.0f;
2485
  const sycl::half beta_f16 = 0.0f;
2486
+ #if !GGML_SYCL_DNNL
2487
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
2488
  *stream, oneapi::mkl::transpose::trans,
2489
  oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
 
2493
  dpct::library_data_t::real_half)));
2494
  const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
2495
  to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
2496
+ #else
2497
+ auto dnnl_stream = ctx.stream_dnnl(stream);
2498
+ DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2499
+ src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
2500
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
2501
+ to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
2502
+ #endif
2503
  }
2504
  else {
2505
  // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
 
2522
 
2523
  const float alpha = 1.0f;
2524
  const float beta = 0.0f;
2525
+ #if !GGML_SYCL_DNNL
2526
  SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
2527
  *stream, oneapi::mkl::transpose::trans,
2528
  oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
2529
  dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
2530
  src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
2531
  dst_dd_i, ldc)));
2532
+ #else
2533
+ auto dnnl_stream = ctx.stream_dnnl(stream);
2534
+ DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
2535
+ src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
2536
+ #endif
2537
  }
2538
  (void) dst;
2539
  (void) src1_ddq_i;
ggml/src/ggml-sycl/common.hpp CHANGED
@@ -19,6 +19,10 @@
19
  #include "dpct/helper.hpp"
20
  #include "ggml-sycl.h"
21
  #include "presets.hpp"
 
 
 
 
22
 
23
  #define GGML_COMMON_DECL_SYCL
24
  #define GGML_COMMON_IMPL_SYCL
@@ -277,6 +281,52 @@ struct ggml_backend_sycl_context {
277
  return stream(device, 0);
278
  }
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  // pool
281
  std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
282
 
 
19
  #include "dpct/helper.hpp"
20
  #include "ggml-sycl.h"
21
  #include "presets.hpp"
22
+ #if GGML_SYCL_DNNL
23
+ #include "dnnl.hpp"
24
+ #include "dnnl_sycl.hpp"
25
+ #endif
26
 
27
  #define GGML_COMMON_DECL_SYCL
28
  #define GGML_COMMON_IMPL_SYCL
 
281
  return stream(device, 0);
282
  }
283
 
284
+ #if GGML_SYCL_DNNL
285
+ dnnl::engine make_engine(sycl::queue* q) {
286
+ // Get the device associated with the queue
287
+ sycl::device dev = q->get_device();
288
+ // Get the context associated with the queue
289
+ sycl::context ctx = q->get_context();
290
+ const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
291
+ return eng;
292
+ }
293
+
294
// Per-queue caches of oneDNN interop objects, keyed by the raw SYCL
// queue pointer, so each queue pays the engine/stream construction
// cost only once.
std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
296
+ dnnl::stream stream_dnnl(int device, int _stream) {
297
+ auto q = stream(device, _stream);
298
+ return stream_dnnl(q);
299
+ }
300
+ dnnl::engine engine_dnnl(sycl::queue* qptr) {
301
+ auto it = engine_map.find(qptr);
302
+ if (it == engine_map.end()) {
303
+ auto eng = make_engine(qptr);
304
+ engine_map[qptr] = eng;
305
+ return eng;
306
+ }
307
+ else
308
+ {
309
+ return it->second;
310
+ }
311
+ }
312
+ dnnl::stream stream_dnnl(sycl::queue* qptr) {
313
+ auto it = stream_map.find(qptr);
314
+ if (it == stream_map.end()) {
315
+ auto eng = engine_dnnl(qptr);
316
+ auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
317
+ stream_map[qptr] = stream;
318
+ return stream;
319
+ }
320
+ else
321
+ {
322
+ return it->second;
323
+ }
324
+ }
325
// Convenience overload: oneDNN stream for stream 0 of the context's
// current device.
dnnl::stream stream_dnnl() {
    return stream_dnnl(device, 0);
}
328
+ #endif
329
+
330
  // pool
331
  std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
332
 
ggml/src/ggml-sycl/gemm.hpp ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #ifndef GGML_SYCL_GEMM_HPP
14
+ #define GGML_SYCL_GEMM_HPP
15
+
16
#include <fstream>
#include <iostream>
#include <type_traits>
#include <unordered_map>

#include "ggml-sycl.h"
20
+
21
+ #if GGML_SYCL_DNNL
22
+
23
+ #include "dnnl.hpp"
24
+ #include "dnnl_sycl.hpp"
25
+
26
+ class DnnlGemmWrapper {
27
+ public:
28
+ using dt = dnnl::memory::data_type;
29
+ using tag = dnnl::memory::format_tag;
30
+
31
+ template<typename T>
32
+ static constexpr dt to_dt() {
33
+ if constexpr (std::is_same_v<T, float>) return dt::f32;
34
+ else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
35
+ else static_assert(0);
36
+ }
37
+
38
+ static inline void row_gemm(sycl::queue& q, bool a_trans,
39
+ bool b_trans, int m, int n, int k,
40
+ const void* a, dt at, const void* b, dt bt, void* c, dt ct)
41
+ {
42
+ // Get the device associated with the queue
43
+ sycl::device dev = q.get_device();
44
+ // Get the context associated with the queue
45
+ sycl::context ctx = q.get_context();
46
+ const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
47
+ const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
48
+ dnnl::memory::dims a_dims = { m, k };
49
+ dnnl::memory::dims b_dims = { k, n };
50
+ dnnl::memory::dims c_dims = { m, n };
51
+ const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
52
+ const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
53
+ const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
54
+ auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
55
+ auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
56
+ auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
57
+ auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
58
+
59
+ // Create the primitive.
60
+ auto matmul_prim = dnnl::matmul(matmul_pd);
61
+ // Primitive arguments.
62
+ std::unordered_map<int, dnnl::memory> matmul_args;
63
+ matmul_args.insert({ DNNL_ARG_SRC, a_mem });
64
+ matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
65
+ matmul_args.insert({ DNNL_ARG_DST, c_mem });
66
+
67
+ matmul_prim.execute(stream, matmul_args);
68
+ }
69
+
70
+
71
+ static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
72
+ bool b_trans, int m, int n, int k,
73
+ const void* a, dt at, const void* b, dt bt, void* c, dt ct)
74
+ {
75
+ auto const eng = stream.get_engine();
76
+ dnnl::memory::dims a_dims = { m, k };
77
+ dnnl::memory::dims b_dims = { k, n };
78
+ dnnl::memory::dims c_dims = { m, n };
79
+ const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
80
+ const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
81
+ const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
82
+ auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
83
+ auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
84
+ auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
85
+ auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
86
+
87
+ // Create the primitive.
88
+ auto matmul_prim = dnnl::matmul(matmul_pd);
89
+ // Primitive arguments.
90
+ std::unordered_map<int, dnnl::memory> matmul_args;
91
+ matmul_args.insert({ DNNL_ARG_SRC, a_mem });
92
+ matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
93
+ matmul_args.insert({ DNNL_ARG_DST, c_mem });
94
+
95
+ matmul_prim.execute(stream, matmul_args);
96
+ }
97
+ };
98
+
99
+ #endif
100
+
101
+ #endif // GGML_SYCL_GEMM_HPP