Ouadie EL FAROUKI AidanBeltonS commited on
Commit
f84edd5
·
unverified ·
1 Parent(s): 26fdc9f

Fixed minor bug when enabling FP16 for non intel targets (llama/6464)

Browse files

* moved INTEL_MKL guard from gemm_impl to gemm (wrapper)

* Update ggml-sycl.cpp

Co-authored-by: AidanBeltonS <[email protected]>

---------

Co-authored-by: AidanBeltonS <[email protected]>

Files changed (1) hide show
  1. ggml-sycl.cpp +2 -19
ggml-sycl.cpp CHANGED
@@ -1664,24 +1664,6 @@ namespace dpct
1664
  const void *alpha, const void *a, int lda, const void *b,
1665
  int ldb, const void *beta, void *c, int ldc)
1666
  {
1667
- #ifndef __INTEL_MKL__
1668
- GGML_UNUSED(q);
1669
- GGML_UNUSED(a_trans);
1670
- GGML_UNUSED(b_trans);
1671
- GGML_UNUSED(m);
1672
- GGML_UNUSED(n);
1673
- GGML_UNUSED(k);
1674
- GGML_UNUSED(alpha);
1675
- GGML_UNUSED(a);
1676
- GGML_UNUSED(lda);
1677
- GGML_UNUSED(b);
1678
- GGML_UNUSED(ldb);
1679
- GGML_UNUSED(beta);
1680
- GGML_UNUSED(c);
1681
- GGML_UNUSED(ldc);
1682
- throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
1683
- "Project does not support this API.");
1684
- #else
1685
  Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1686
  Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1687
  auto data_a = get_memory<const Ta>(a);
@@ -1690,7 +1672,6 @@ namespace dpct
1690
  oneapi::mkl::blas::column_major::gemm(
1691
  q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1692
  data_b, ldb, beta_value, data_c, ldc);
1693
- #endif
1694
  }
1695
 
1696
  template <typename VecT, class BinaryOperation, class = void>
@@ -2330,6 +2311,7 @@ namespace dpct
2330
  lda, b, ldb, beta, c, ldc);
2331
  break;
2332
  }
 
2333
  case detail::get_type_combination_id(
2334
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2335
  library_data_t::real_float, library_data_t::real_float):
@@ -2391,6 +2373,7 @@ namespace dpct
2391
  q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
2392
  break;
2393
  }
 
2394
  default:
2395
  throw std::runtime_error("the combination of data type is unsupported");
2396
  }
 
1664
  const void *alpha, const void *a, int lda, const void *b,
1665
  int ldb, const void *beta, void *c, int ldc)
1666
  {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1667
  Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1668
  Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1669
  auto data_a = get_memory<const Ta>(a);
 
1672
  oneapi::mkl::blas::column_major::gemm(
1673
  q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1674
  data_b, ldb, beta_value, data_c, ldc);
 
1675
  }
1676
 
1677
  template <typename VecT, class BinaryOperation, class = void>
 
2311
  lda, b, ldb, beta, c, ldc);
2312
  break;
2313
  }
2314
+ #ifdef __INTEL_MKL__
2315
  case detail::get_type_combination_id(
2316
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2317
  library_data_t::real_float, library_data_t::real_float):
 
2373
  q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
2374
  break;
2375
  }
2376
+ #endif // __INTEL_MKL__
2377
  default:
2378
  throw std::runtime_error("the combination of data type is unsupported");
2379
  }