Chen Xi hengyu committed on
Commit
94a6436
·
1 Parent(s): efcca56

fix multi-gpu issue on sycl (llama/8554)

Browse files

---------

Signed-off-by: Chen Xi <[email protected]>
Co-authored-by: Meng, Hengyu <[email protected]>

ggml/src/ggml-sycl/common.hpp CHANGED
@@ -267,7 +267,7 @@ struct ggml_backend_sycl_context {
267
 
268
  queue_ptr stream(int device, int stream) {
269
  if (qptrs[device][stream] == nullptr) {
270
- qptrs[device][stream] = &(dpct::get_current_device().default_queue());
271
  }
272
  return qptrs[device][stream];
273
  }
 
267
 
268
  queue_ptr stream(int device, int stream) {
269
  if (qptrs[device][stream] == nullptr) {
270
+ qptrs[device][stream] = &(dpct::get_device(device).default_queue());
271
  }
272
  return qptrs[device][stream];
273
  }
ggml/src/ggml-sycl/dpct/helper.hpp CHANGED
@@ -588,7 +588,7 @@ namespace dpct
588
  out = prop;
589
  }
590
 
591
- /// dpct device extension
592
  class device_ext : public sycl::device {
593
  typedef std::mutex mutex_type;
594
 
@@ -697,7 +697,7 @@ namespace dpct
697
  std::unique_lock<mutex_type> lock(m_mutex);
698
  lock.unlock();
699
  for (auto &q : _queues) {
700
- q.wait_and_throw();
701
  }
702
  // Guard the destruct of current_queues to make sure the ref count is
703
  // safe.
@@ -734,7 +734,12 @@ namespace dpct
734
 
735
  void destroy_queue(sycl::queue queue) {
736
  std::lock_guard<mutex_type> lock(m_mutex);
737
- _queues.clear();
 
 
 
 
 
738
  }
739
  void set_saved_queue(sycl::queue q) {
740
  std::lock_guard<mutex_type> lock(m_mutex);
@@ -764,13 +769,13 @@ namespace dpct
764
  if (enable_exception_handler) {
765
  eh = exception_handler;
766
  }
767
- auto q = sycl::queue(*this, eh,
768
- sycl::property_list(
 
769
  #ifdef DPCT_PROFILING_ENABLED
770
- sycl::property::queue::enable_profiling(),
771
  #endif
772
- properties...));
773
- _queues.push_back(q);
774
 
775
  return _queues.back();
776
  }
@@ -783,8 +788,8 @@ namespace dpct
783
  if (enable_exception_handler) {
784
  eh = exception_handler;
785
  }
786
- _queues.push_back(
787
- sycl::queue(device, eh,
788
  sycl::property_list(
789
  #ifdef DPCT_PROFILING_ENABLED
790
  sycl::property::queue::enable_profiling(),
@@ -855,15 +860,75 @@ namespace dpct
855
  unsigned int get_device_id(const sycl::device &dev)
856
  {
857
  unsigned int id = 0;
858
- for (auto dev_item : _devs)
859
  {
860
  if (*dev_item == dev)
861
  {
862
- break;
863
  }
864
  id++;
865
  }
866
- return id;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
867
  }
868
 
869
  template <class DeviceSelector>
@@ -930,10 +995,15 @@ namespace dpct
930
  // Keep track of the number of devices per backend
931
  std::map<sycl::backend, size_t> DeviceNums;
932
  std::map<std::string, std::vector<sycl::device>> backend_devices;
 
933
 
934
  while (!Platforms.empty()) {
935
  auto Platform = Platforms.back();
936
  Platforms.pop_back();
 
 
 
 
937
  auto devices = Platform.get_devices();
938
  std::string backend_type = get_device_backend_and_type(devices[0]);
939
  for (const auto &device : devices) {
@@ -1989,6 +2059,11 @@ namespace dpct
1989
  return dev_mgr::instance().current_device();
1990
  }
1991
 
 
 
 
 
 
1992
  static inline sycl::queue &get_in_order_queue()
1993
  {
1994
  return dev_mgr::instance().current_device().in_order_queue();
 
588
  out = prop;
589
  }
590
 
591
+ /// dpct device extension
592
  class device_ext : public sycl::device {
593
  typedef std::mutex mutex_type;
594
 
 
697
  std::unique_lock<mutex_type> lock(m_mutex);
698
  lock.unlock();
699
  for (auto &q : _queues) {
700
+ q.wait_and_throw();
701
  }
702
  // Guard the destruct of current_queues to make sure the ref count is
703
  // safe.
 
734
 
735
  void destroy_queue(sycl::queue queue) {
736
  std::lock_guard<mutex_type> lock(m_mutex);
737
+ _queues.erase(std::remove_if(_queues.begin(), _queues.end(),
738
+ [=](const sycl::queue &q) -> bool
739
+ {
740
+ return q == queue;
741
+ }),
742
+ _queues.end());
743
  }
744
  void set_saved_queue(sycl::queue q) {
745
  std::lock_guard<mutex_type> lock(m_mutex);
 
769
  if (enable_exception_handler) {
770
  eh = exception_handler;
771
  }
772
+ _queues.push_back(sycl::queue(
773
+ *this, eh,
774
+ sycl::property_list(
775
  #ifdef DPCT_PROFILING_ENABLED
776
+ sycl::property::queue::enable_profiling(),
777
  #endif
778
+ properties...)));
 
779
 
780
  return _queues.back();
781
  }
 
788
  if (enable_exception_handler) {
789
  eh = exception_handler;
790
  }
791
+ _queues.push_back(sycl::queue(
792
+ device, eh,
793
  sycl::property_list(
794
  #ifdef DPCT_PROFILING_ENABLED
795
  sycl::property::queue::enable_profiling(),
 
860
  unsigned int get_device_id(const sycl::device &dev)
861
  {
862
  unsigned int id = 0;
863
+ for (auto &dev_item : _devs)
864
  {
865
  if (*dev_item == dev)
866
  {
867
+ return id;
868
  }
869
  id++;
870
  }
871
+ return -1;
872
+ }
873
+
874
+ inline std::string get_preferred_gpu_platform_name() {
875
+ std::string result;
876
+
877
+ std::string filter = "level-zero";
878
+ char* env = getenv("ONEAPI_DEVICE_SELECTOR");
879
+ if (env) {
880
+ if (std::strstr(env, "level_zero")) {
881
+ filter = "level-zero";
882
+ }
883
+ else if (std::strstr(env, "opencl")) {
884
+ filter = "opencl";
885
+ }
886
+ else if (std::strstr(env, "cuda")) {
887
+ filter = "cuda";
888
+ }
889
+ else if (std::strstr(env, "hip")) {
890
+ filter = "hip";
891
+ }
892
+ else {
893
+ throw std::runtime_error("invalid device filter: " + std::string(env));
894
+ }
895
+ }
896
+
897
+ auto plaform_list = sycl::platform::get_platforms();
898
+
899
+ for (const auto& platform : plaform_list) {
900
+ auto devices = platform.get_devices();
901
+ auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
902
+ return d.is_gpu();
903
+ });
904
+
905
+ if (gpu_dev == devices.end()) {
906
+ // cout << "platform [" << platform_name
907
+ // << "] does not contain GPU devices, skipping\n";
908
+ continue;
909
+ }
910
+
911
+ auto platform_name = platform.get_info<sycl::info::platform::name>();
912
+ std::string platform_name_low_case;
913
+ platform_name_low_case.resize(platform_name.size());
914
+
915
+ std::transform(
916
+ platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);
917
+
918
+ if (platform_name_low_case.find(filter) == std::string::npos) {
919
+ // cout << "platform [" << platform_name
920
+ // << "] does not match with requested "
921
+ // << filter << ", skipping\n";
922
+ continue;
923
+ }
924
+
925
+ result = platform_name;
926
+ }
927
+
928
+ if (result.empty())
929
+ throw std::runtime_error("can not find preferred GPU platform");
930
+
931
+ return result;
932
  }
933
 
934
  template <class DeviceSelector>
 
995
  // Keep track of the number of devices per backend
996
  std::map<sycl::backend, size_t> DeviceNums;
997
  std::map<std::string, std::vector<sycl::device>> backend_devices;
998
+ auto preferred_platform_name = get_preferred_gpu_platform_name();
999
 
1000
  while (!Platforms.empty()) {
1001
  auto Platform = Platforms.back();
1002
  Platforms.pop_back();
1003
+ auto platform_name = Platform.get_info<sycl::info::platform::name>();
1004
+ if (platform_name.compare(preferred_platform_name) != 0) {
1005
+ continue;
1006
+ }
1007
  auto devices = Platform.get_devices();
1008
  std::string backend_type = get_device_backend_and_type(devices[0]);
1009
  for (const auto &device : devices) {
 
2059
  return dev_mgr::instance().current_device();
2060
  }
2061
 
2062
+ static inline device_ext &get_device(unsigned int id)
2063
+ {
2064
+ return dev_mgr::instance().get_device(id);
2065
+ }
2066
+
2067
  static inline sycl::queue &get_in_order_queue()
2068
  {
2069
  return dev_mgr::instance().current_device().in_order_queue();