bashbaug commited on
Commit
9f87c2f
·
1 Parent(s): fcfd59e

use the correct SYCL context for host USM allocations (llama/7777)

Browse files
Files changed (1) hide show
  1. ggml-sycl.cpp +7 -4
ggml-sycl.cpp CHANGED
@@ -13089,10 +13089,12 @@ void *ggml_sycl_host_malloc(size_t size) try {
13089
  return nullptr;
13090
  }
13091
 
 
 
 
13092
  void * ptr = nullptr;
13093
- //allow to use dpct::get_in_order_queue() for host malloc
13094
  dpct::err0 err = CHECK_TRY_ERROR(
13095
- ptr = (void *)sycl::malloc_host(size, dpct::get_in_order_queue()));
13096
 
13097
  if (err != 0) {
13098
  // clear the error
@@ -13113,8 +13115,9 @@ catch (sycl::exception const &exc) {
13113
  }
13114
 
13115
  void ggml_sycl_host_free(void *ptr) try {
13116
- //allow to use dpct::get_in_order_queue() for host malloc
13117
- SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue())));
 
13118
  }
13119
  catch (sycl::exception const &exc) {
13120
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
 
13089
  return nullptr;
13090
  }
13091
 
13092
+ ggml_sycl_set_device(g_main_device);
13093
+ dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
13094
+
13095
  void * ptr = nullptr;
 
13096
  dpct::err0 err = CHECK_TRY_ERROR(
13097
+ ptr = (void *)sycl::malloc_host(size, *main_stream));
13098
 
13099
  if (err != 0) {
13100
  // clear the error
 
13115
  }
13116
 
13117
  void ggml_sycl_host_free(void *ptr) try {
13118
+ ggml_sycl_set_device(g_main_device);
13119
+ dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
13120
+ SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *main_stream)));
13121
  }
13122
  catch (sycl::exception const &exc) {
13123
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__