Add oneDNN primitive support (llama/9091)
* add onednn
* add sycl_f16
* add dnnl stream
* add engine map
* use dnnl for intel only
* use fp16fp16fp16
* update doc
- ggml/src/CMakeLists.txt +10 -0
- ggml/src/ggml-sycl.cpp +15 -1
- ggml/src/ggml-sycl/common.hpp +50 -0
- ggml/src/ggml-sycl/gemm.hpp +101 -0
ggml/src/CMakeLists.txt
CHANGED
@@ -549,6 +549,13 @@ if (GGML_SYCL)
     file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
     list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
 
+    find_package(DNNL)
+    message("-- DNNL found:"${DNNL_FOUND})
+    if (GGML_SYCL_TARGET STREQUAL "INTEL")
+        add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
+    else()
+        add_compile_definitions(GGML_SYCL_DNNL=0)
+    endif()
     if (WIN32)
         find_package(IntelSYCL REQUIRED)
         find_package(MKL REQUIRED)
@@ -561,6 +568,9 @@ if (GGML_SYCL)
            set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl)
        endif()
    endif()
+    if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
+        list(APPEND GGML_EXTRA_LIBS DNNL::dnnl)
+    endif()
 endif()
 
 if (GGML_RPC)
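The CMake block above resolves oneDNN via find_package(DNNL) and injects a GGML_SYCL_DNNL compile definition: it takes the value of ${DNNL_FOUND} when GGML_SYCL_TARGET is INTEL and is forced to 0 otherwise, and DNNL::dnnl is linked only in that INTEL-with-oneDNN case. Below is a minimal, hypothetical C++ sketch (not part of the PR) of how a translation unit can branch on that definition, mirroring the #if !GGML_SYCL_DNNL guards used in ggml-sycl.cpp further down:

// Hypothetical sketch: report which GEMM path this translation unit was built with,
// based on the GGML_SYCL_DNNL definition injected by the CMake block above.
#include <cstdio>

#ifndef GGML_SYCL_DNNL
#define GGML_SYCL_DNNL 0   // assumption: fall back to the oneMKL path when undefined
#endif

int main() {
#if GGML_SYCL_DNNL
    std::printf("GEMM path: oneDNN (dnnl::matmul primitive)\n");
#else
    std::printf("GEMM path: oneMKL / dpct::gemm\n");
#endif
    return 0;
}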
ggml/src/ggml-sycl.cpp
CHANGED
@@ -38,6 +38,7 @@
 
 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/presets.hpp"
+#include "ggml-sycl/gemm.hpp"
 
 bool ggml_sycl_loaded(void);
 void ggml_sycl_free_data(struct ggml_tensor * tensor);
@@ -2482,6 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
         const sycl::half alpha_f16 = 1.0f;
         const sycl::half beta_f16 = 0.0f;
+#if !GGML_SYCL_DNNL
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
             *stream, oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
@@ -2491,6 +2493,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
             dpct::library_data_t::real_half)));
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+#else
+        auto dnnl_stream = ctx.stream_dnnl(stream);
+        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#endif
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2513,13 +2522,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
         const float alpha = 1.0f;
         const float beta = 0.0f;
-
+#if !GGML_SYCL_DNNL
         SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
             *stream, oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
             dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
             src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
             dst_dd_i, ldc)));
+#else
+        auto dnnl_stream = ctx.stream_dnnl(stream);
+        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
+            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+#endif
     }
     (void) dst;
     (void) src1_ddq_i;
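Both the fp16 and fp32 branches keep the original oneMKL/dpct calls when GGML_SYCL_DNNL is 0 and otherwise route the same GEMM through DnnlGemmWrapper::row_gemm on a dnnl::stream obtained from the backend context. The operand order is swapped relative to the column-major oneMKL call because row_gemm describes its matrices in row-major form. The sketch below (plain CPU loops, hypothetical small sizes, no SYCL or oneDNN required, and assuming ldc == row_diff for simplicity) checks that the two formulations write identical values into the same dst layout:

// CPU sketch of the operand swap: column-major C(row_diff x src1_ncols) = src0^T * src1
// (the oneMKL call) occupies the same memory as row-major C'(src1_ncols x row_diff) =
// src1 * src0^T (the row_gemm call with a_trans=false, b_trans=true).
#include <cassert>
#include <vector>

int main() {
    const int row_diff = 3, src1_ncols = 2, ne10 = 4;          // stand-in sizes
    std::vector<float> src0(row_diff * ne10), src1(src1_ncols * ne10);
    for (int i = 0; i < (int) src0.size(); ++i) src0[i] = 0.5f * i;
    for (int i = 0; i < (int) src1.size(); ++i) src1[i] = 1.0f - 0.25f * i;

    // oneMKL-style result: column-major, C[i + j*row_diff] = dot(row i of src0, row j of src1)
    std::vector<float> c_mkl(row_diff * src1_ncols, 0.0f);
    for (int j = 0; j < src1_ncols; ++j)
        for (int i = 0; i < row_diff; ++i)
            for (int k = 0; k < ne10; ++k)
                c_mkl[i + j * row_diff] += src0[i * ne10 + k] * src1[j * ne10 + k];

    // row_gemm-style result: row-major, C'[p*row_diff + q] = dot(row p of src1, row q of src0)
    std::vector<float> c_dnnl(src1_ncols * row_diff, 0.0f);
    for (int p = 0; p < src1_ncols; ++p)
        for (int q = 0; q < row_diff; ++q)
            for (int k = 0; k < ne10; ++k)
                c_dnnl[p * row_diff + q] += src1[p * ne10 + k] * src0[q * ne10 + k];

    // Same bytes, same layout.
    for (int i = 0; i < (int) c_mkl.size(); ++i) assert(c_mkl[i] == c_dnnl[i]);
    return 0;
}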
ggml/src/ggml-sycl/common.hpp
CHANGED
@@ -19,6 +19,10 @@
 #include "dpct/helper.hpp"
 #include "ggml-sycl.h"
 #include "presets.hpp"
+#if GGML_SYCL_DNNL
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+#endif
 
 #define GGML_COMMON_DECL_SYCL
 #define GGML_COMMON_IMPL_SYCL
@@ -277,6 +281,52 @@ struct ggml_backend_sycl_context {
         return stream(device, 0);
     }
 
+#if GGML_SYCL_DNNL
+    dnnl::engine make_engine(sycl::queue* q) {
+        // Get the device associated with the queue
+        sycl::device dev = q->get_device();
+        // Get the context associated with the queue
+        sycl::context ctx = q->get_context();
+        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+        return eng;
+    }
+
+    std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
+    std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
+    dnnl::stream stream_dnnl(int device, int _stream) {
+        auto q = stream(device, _stream);
+        return stream_dnnl(q);
+    }
+    dnnl::engine engine_dnnl(sycl::queue* qptr) {
+        auto it = engine_map.find(qptr);
+        if (it == engine_map.end()) {
+            auto eng = make_engine(qptr);
+            engine_map[qptr] = eng;
+            return eng;
+        }
+        else
+        {
+            return it->second;
+        }
+    }
+    dnnl::stream stream_dnnl(sycl::queue* qptr) {
+        auto it = stream_map.find(qptr);
+        if (it == stream_map.end()) {
+            auto eng = engine_dnnl(qptr);
+            auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
+            stream_map[qptr] = stream;
+            return stream;
+        }
+        else
+        {
+            return it->second;
+        }
+    }
+    dnnl::stream stream_dnnl() {
+        return stream_dnnl(device, 0);
+    }
+#endif
+
     // pool
     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
 
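stream_dnnl() and engine_dnnl() memoize one dnnl::engine and one dnnl::stream per sycl::queue* inside the backend context, so the oneDNN interop objects are created once per queue rather than on every matmul. Below is a hedged, generic sketch of that lazy-cache pattern with stand-in types (not the actual ggml_backend_sycl_context):

// Generic lazy cache: build an expensive handle on first use of a key, reuse it afterwards.
#include <iostream>
#include <string>
#include <unordered_map>

struct expensive_handle {
    std::string name;   // stands in for dnnl::engine / dnnl::stream
};

class handle_cache {
    std::unordered_map<int, expensive_handle> cache_;   // key plays the role of sycl::queue*
public:
    const expensive_handle & get(int key) {
        auto it = cache_.find(key);
        if (it == cache_.end()) {
            // first use of this key: create the handle and remember it
            it = cache_.emplace(key, expensive_handle{"handle#" + std::to_string(key)}).first;
        }
        return it->second;
    }
};

int main() {
    handle_cache c;
    std::cout << c.get(7).name << "\n";   // created here
    std::cout << c.get(7).name << "\n";   // reused, not re-created
    return 0;
}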
ggml/src/ggml-sycl/gemm.hpp
ADDED
@@ -0,0 +1,101 @@
+//
+//  MIT license
+//  Copyright (C) 2024 Intel Corporation
+//  SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_GEMM_HPP
+#define GGML_SYCL_GEMM_HPP
+
+#include <fstream>
+#include <iostream>
+
+#include "ggml-sycl.h"
+
+#if GGML_SYCL_DNNL
+
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+
+class DnnlGemmWrapper {
+public:
+    using dt = dnnl::memory::data_type;
+    using tag = dnnl::memory::format_tag;
+
+    template<typename T>
+    static constexpr dt to_dt() {
+        if constexpr (std::is_same_v<T, float>) return dt::f32;
+        else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
+        else static_assert(0);
+    }
+
+    static inline void row_gemm(sycl::queue& q, bool a_trans,
+        bool b_trans, int m, int n, int k,
+        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
+    {
+        // Get the device associated with the queue
+        sycl::device dev = q.get_device();
+        // Get the context associated with the queue
+        sycl::context ctx = q.get_context();
+        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+        const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
+        dnnl::memory::dims a_dims = { m, k };
+        dnnl::memory::dims b_dims = { k, n };
+        dnnl::memory::dims c_dims = { m, n };
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+        auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
+        auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+        // Create the primitive.
+        auto matmul_prim = dnnl::matmul(matmul_pd);
+        // Primitive arguments.
+        std::unordered_map<int, dnnl::memory> matmul_args;
+        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+
+        matmul_prim.execute(stream, matmul_args);
+    }
+
+
+    static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
+        bool b_trans, int m, int n, int k,
+        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
+    {
+        auto const eng = stream.get_engine();
+        dnnl::memory::dims a_dims = { m, k };
+        dnnl::memory::dims b_dims = { k, n };
+        dnnl::memory::dims c_dims = { m, n };
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+        auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
+        auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+        // Create the primitive.
+        auto matmul_prim = dnnl::matmul(matmul_pd);
+        // Primitive arguments.
+        std::unordered_map<int, dnnl::memory> matmul_args;
+        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+
+        matmul_prim.execute(stream, matmul_args);
+    }
+};
+
+#endif
+
+#endif // GGML_SYCL_GEMM_HPP
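DnnlGemmWrapper describes A, B and C as 2-D row-major oneDNN memory objects (tag::ab, or tag::ba when an operand is transposed), builds a dnnl::matmul primitive and executes it on either an ad-hoc stream made from the given sycl::queue or a caller-supplied dnnl::stream. Below is a hedged usage sketch of the queue-based overload; the sizes, buffer contents and include path are hypothetical, and it assumes a build with GGML_SYCL_DNNL enabled and a USM-capable SYCL device:

// Hypothetical usage: C (m x n, row-major) = A (m x k, row-major) * B^T,
// where B is supplied as an n x k row-major buffer and b_trans = true.
#include <sycl/sycl.hpp>
#include "ggml-sycl/gemm.hpp"   // assumed include path for the header added above

int main() {
    sycl::queue q;   // default device

    const int m = 2, n = 4, k = 3;
    float *a = sycl::malloc_shared<float>(m * k, q);   // row-major m x k
    float *b = sycl::malloc_shared<float>(n * k, q);   // row-major n x k, used transposed
    float *c = sycl::malloc_shared<float>(m * n, q);   // row-major m x n result
    for (int i = 0; i < m * k; ++i) a[i] = 1.0f;
    for (int i = 0; i < n * k; ++i) b[i] = 2.0f;

    // All operands described as f32; each c element should come out as k * 1 * 2 = 6.
    DnnlGemmWrapper::row_gemm(q, /*a_trans=*/false, /*b_trans=*/true, m, n, k,
                              a, DnnlGemmWrapper::to_dt<float>(),
                              b, DnnlGemmWrapper::to_dt<float>(),
                              c, DnnlGemmWrapper::to_dt<float>());
    q.wait();   // the oneDNN stream is built on this queue, so waiting on q synchronizes

    sycl::free(a, q); sycl::free(b, q); sycl::free(c, q);
    return 0;
}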