Romain Biessy committed on
Commit
ce0dc30
·
1 Parent(s): 3541ee8

sycl: Use syclcompat::dp4a (llama/10267)

Browse files

* sycl: Use syclcompat::dp4a

* Using the syclcompat version allows the compiler to optimize the
operation with a native function

* Update news section

* Update CI Windows oneAPI version to 2025.0

* Reword doc

* Call syclcompat::dp4a inside dpct::dp4a

This reverts commit 90cb61d692d61360b46954a1c7f780bd2e569b73.

ggml/src/ggml-sycl/dpct/helper.hpp CHANGED
@@ -15,6 +15,7 @@
15
 
16
  #include <sycl/sycl.hpp>
17
  #include <sycl/half_type.hpp>
 
18
  #include <oneapi/mkl.hpp>
19
  #include <map>
20
 
@@ -1830,31 +1831,10 @@ namespace dpct
1830
  : id);
1831
  }
1832
 
1833
- template <typename T>
1834
- sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val)
1835
- {
1836
- return sycl::vec<T, 1>(val)
1837
- .template as<sycl::vec<
1838
- std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>, 4>>()
1839
- .template convert<T>();
1840
- }
1841
-
1842
- template <typename T1, typename T2>
1843
- using dot_product_acc_t =
1844
- std::conditional_t<std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
1845
- uint32_t, int32_t>;
1846
-
1847
  template <typename T1, typename T2, typename T3>
1848
  inline auto dp4a(T1 a, T2 b, T3 c)
1849
  {
1850
- dot_product_acc_t<T1, T2> res = c;
1851
- auto va = extract_and_sign_or_zero_extend4(a);
1852
- auto vb = extract_and_sign_or_zero_extend4(b);
1853
- res += va[0] * vb[0];
1854
- res += va[1] * vb[1];
1855
- res += va[2] * vb[2];
1856
- res += va[3] * vb[3];
1857
- return res;
1858
  }
1859
 
1860
  struct sub_sat
 
15
 
16
  #include <sycl/sycl.hpp>
17
  #include <sycl/half_type.hpp>
18
+ #include <syclcompat/math.hpp>
19
  #include <oneapi/mkl.hpp>
20
  #include <map>
21
 
 
1831
  : id);
1832
  }
1833
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1834
  template <typename T1, typename T2, typename T3>
1835
  inline auto dp4a(T1 a, T2 b, T3 c)
1836
  {
1837
+ return syclcompat::dp4a(a, b, c);
 
 
 
 
 
 
 
1838
  }
1839
 
1840
  struct sub_sat
ggml/src/ggml-sycl/vecdotq.hpp CHANGED
@@ -968,8 +968,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
968
  grid1[0] ^ signs[0], signs[0], std::minus<>());
969
  const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
970
  grid2[0] ^ signs[1], signs[1], std::minus<>());
971
- sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
972
- sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
973
  q8 += 8;
974
  aux32 >>= 7;
975
  }
@@ -1009,8 +1009,8 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
1009
  grid1[0] ^ signs0, signs0, std::minus<>());
1010
  const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
1011
  grid2[0] ^ signs1, signs1, std::minus<>());
1012
- sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
1013
- sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
1014
  q8 += 8;
1015
  }
1016
  const float d =
 
968
  grid1[0] ^ signs[0], signs[0], std::minus<>());
969
  const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
970
  grid2[0] ^ signs[1], signs[1], std::minus<>());
971
+ sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
972
+ sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
973
  q8 += 8;
974
  aux32 >>= 7;
975
  }
 
1009
  grid1[0] ^ signs0, signs0, std::minus<>());
1010
  const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
1011
  grid2[0] ^ signs1, signs1, std::minus<>());
1012
+ sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
1013
+ sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
1014
  q8 += 8;
1015
  }
1016
  const float d =