Spaces:
Running
Running
fix multi-gpu issue on sycl (llama/8554)
Browse files---------
Signed-off-by: Chen Xi <[email protected]>
Co-authored-by: Meng, Hengyu <[email protected]>
ggml/src/ggml-sycl/common.hpp
CHANGED
|
@@ -267,7 +267,7 @@ struct ggml_backend_sycl_context {
|
|
| 267 |
|
| 268 |
queue_ptr stream(int device, int stream) {
|
| 269 |
if (qptrs[device][stream] == nullptr) {
|
| 270 |
-
qptrs[device][stream] = &(dpct::
|
| 271 |
}
|
| 272 |
return qptrs[device][stream];
|
| 273 |
}
|
|
|
|
| 267 |
|
| 268 |
queue_ptr stream(int device, int stream) {
|
| 269 |
if (qptrs[device][stream] == nullptr) {
|
| 270 |
+
qptrs[device][stream] = &(dpct::get_device(device).default_queue());
|
| 271 |
}
|
| 272 |
return qptrs[device][stream];
|
| 273 |
}
|
ggml/src/ggml-sycl/dpct/helper.hpp
CHANGED
|
@@ -588,7 +588,7 @@ namespace dpct
|
|
| 588 |
out = prop;
|
| 589 |
}
|
| 590 |
|
| 591 |
-
|
| 592 |
class device_ext : public sycl::device {
|
| 593 |
typedef std::mutex mutex_type;
|
| 594 |
|
|
@@ -697,7 +697,7 @@ namespace dpct
|
|
| 697 |
std::unique_lock<mutex_type> lock(m_mutex);
|
| 698 |
lock.unlock();
|
| 699 |
for (auto &q : _queues) {
|
| 700 |
-
|
| 701 |
}
|
| 702 |
// Guard the destruct of current_queues to make sure the ref count is
|
| 703 |
// safe.
|
|
@@ -734,7 +734,12 @@ namespace dpct
|
|
| 734 |
|
| 735 |
void destroy_queue(sycl::queue queue) {
|
| 736 |
std::lock_guard<mutex_type> lock(m_mutex);
|
| 737 |
-
_queues.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 738 |
}
|
| 739 |
void set_saved_queue(sycl::queue q) {
|
| 740 |
std::lock_guard<mutex_type> lock(m_mutex);
|
|
@@ -764,13 +769,13 @@ namespace dpct
|
|
| 764 |
if (enable_exception_handler) {
|
| 765 |
eh = exception_handler;
|
| 766 |
}
|
| 767 |
-
|
| 768 |
-
|
|
|
|
| 769 |
#ifdef DPCT_PROFILING_ENABLED
|
| 770 |
-
|
| 771 |
#endif
|
| 772 |
-
|
| 773 |
-
_queues.push_back(q);
|
| 774 |
|
| 775 |
return _queues.back();
|
| 776 |
}
|
|
@@ -783,8 +788,8 @@ namespace dpct
|
|
| 783 |
if (enable_exception_handler) {
|
| 784 |
eh = exception_handler;
|
| 785 |
}
|
| 786 |
-
_queues.push_back(
|
| 787 |
-
|
| 788 |
sycl::property_list(
|
| 789 |
#ifdef DPCT_PROFILING_ENABLED
|
| 790 |
sycl::property::queue::enable_profiling(),
|
|
@@ -855,15 +860,75 @@ namespace dpct
|
|
| 855 |
unsigned int get_device_id(const sycl::device &dev)
|
| 856 |
{
|
| 857 |
unsigned int id = 0;
|
| 858 |
-
for (auto dev_item : _devs)
|
| 859 |
{
|
| 860 |
if (*dev_item == dev)
|
| 861 |
{
|
| 862 |
-
|
| 863 |
}
|
| 864 |
id++;
|
| 865 |
}
|
| 866 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 867 |
}
|
| 868 |
|
| 869 |
template <class DeviceSelector>
|
|
@@ -930,10 +995,15 @@ namespace dpct
|
|
| 930 |
// Keep track of the number of devices per backend
|
| 931 |
std::map<sycl::backend, size_t> DeviceNums;
|
| 932 |
std::map<std::string, std::vector<sycl::device>> backend_devices;
|
|
|
|
| 933 |
|
| 934 |
while (!Platforms.empty()) {
|
| 935 |
auto Platform = Platforms.back();
|
| 936 |
Platforms.pop_back();
|
|
|
|
|
|
|
|
|
|
|
|
|
| 937 |
auto devices = Platform.get_devices();
|
| 938 |
std::string backend_type = get_device_backend_and_type(devices[0]);
|
| 939 |
for (const auto &device : devices) {
|
|
@@ -1989,6 +2059,11 @@ namespace dpct
|
|
| 1989 |
return dev_mgr::instance().current_device();
|
| 1990 |
}
|
| 1991 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1992 |
static inline sycl::queue &get_in_order_queue()
|
| 1993 |
{
|
| 1994 |
return dev_mgr::instance().current_device().in_order_queue();
|
|
|
|
| 588 |
out = prop;
|
| 589 |
}
|
| 590 |
|
| 591 |
+
/// dpct device extension
|
| 592 |
class device_ext : public sycl::device {
|
| 593 |
typedef std::mutex mutex_type;
|
| 594 |
|
|
|
|
| 697 |
std::unique_lock<mutex_type> lock(m_mutex);
|
| 698 |
lock.unlock();
|
| 699 |
for (auto &q : _queues) {
|
| 700 |
+
q.wait_and_throw();
|
| 701 |
}
|
| 702 |
// Guard the destruct of current_queues to make sure the ref count is
|
| 703 |
// safe.
|
|
|
|
| 734 |
|
| 735 |
void destroy_queue(sycl::queue queue) {
|
| 736 |
std::lock_guard<mutex_type> lock(m_mutex);
|
| 737 |
+
_queues.erase(std::remove_if(_queues.begin(), _queues.end(),
|
| 738 |
+
[=](const sycl::queue &q) -> bool
|
| 739 |
+
{
|
| 740 |
+
return q == queue;
|
| 741 |
+
}),
|
| 742 |
+
_queues.end());
|
| 743 |
}
|
| 744 |
void set_saved_queue(sycl::queue q) {
|
| 745 |
std::lock_guard<mutex_type> lock(m_mutex);
|
|
|
|
| 769 |
if (enable_exception_handler) {
|
| 770 |
eh = exception_handler;
|
| 771 |
}
|
| 772 |
+
_queues.push_back(sycl::queue(
|
| 773 |
+
*this, eh,
|
| 774 |
+
sycl::property_list(
|
| 775 |
#ifdef DPCT_PROFILING_ENABLED
|
| 776 |
+
sycl::property::queue::enable_profiling(),
|
| 777 |
#endif
|
| 778 |
+
properties...)));
|
|
|
|
| 779 |
|
| 780 |
return _queues.back();
|
| 781 |
}
|
|
|
|
| 788 |
if (enable_exception_handler) {
|
| 789 |
eh = exception_handler;
|
| 790 |
}
|
| 791 |
+
_queues.push_back(sycl::queue(
|
| 792 |
+
device, eh,
|
| 793 |
sycl::property_list(
|
| 794 |
#ifdef DPCT_PROFILING_ENABLED
|
| 795 |
sycl::property::queue::enable_profiling(),
|
|
|
|
| 860 |
unsigned int get_device_id(const sycl::device &dev)
|
| 861 |
{
|
| 862 |
unsigned int id = 0;
|
| 863 |
+
for (auto &dev_item : _devs)
|
| 864 |
{
|
| 865 |
if (*dev_item == dev)
|
| 866 |
{
|
| 867 |
+
return id;
|
| 868 |
}
|
| 869 |
id++;
|
| 870 |
}
|
| 871 |
+
return -1;
|
| 872 |
+
}
|
| 873 |
+
|
| 874 |
+
inline std::string get_preferred_gpu_platform_name() {
|
| 875 |
+
std::string result;
|
| 876 |
+
|
| 877 |
+
std::string filter = "level-zero";
|
| 878 |
+
char* env = getenv("ONEAPI_DEVICE_SELECTOR");
|
| 879 |
+
if (env) {
|
| 880 |
+
if (std::strstr(env, "level_zero")) {
|
| 881 |
+
filter = "level-zero";
|
| 882 |
+
}
|
| 883 |
+
else if (std::strstr(env, "opencl")) {
|
| 884 |
+
filter = "opencl";
|
| 885 |
+
}
|
| 886 |
+
else if (std::strstr(env, "cuda")) {
|
| 887 |
+
filter = "cuda";
|
| 888 |
+
}
|
| 889 |
+
else if (std::strstr(env, "hip")) {
|
| 890 |
+
filter = "hip";
|
| 891 |
+
}
|
| 892 |
+
else {
|
| 893 |
+
throw std::runtime_error("invalid device filter: " + std::string(env));
|
| 894 |
+
}
|
| 895 |
+
}
|
| 896 |
+
|
| 897 |
+
auto plaform_list = sycl::platform::get_platforms();
|
| 898 |
+
|
| 899 |
+
for (const auto& platform : plaform_list) {
|
| 900 |
+
auto devices = platform.get_devices();
|
| 901 |
+
auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
|
| 902 |
+
return d.is_gpu();
|
| 903 |
+
});
|
| 904 |
+
|
| 905 |
+
if (gpu_dev == devices.end()) {
|
| 906 |
+
// cout << "platform [" << platform_name
|
| 907 |
+
// << "] does not contain GPU devices, skipping\n";
|
| 908 |
+
continue;
|
| 909 |
+
}
|
| 910 |
+
|
| 911 |
+
auto platform_name = platform.get_info<sycl::info::platform::name>();
|
| 912 |
+
std::string platform_name_low_case;
|
| 913 |
+
platform_name_low_case.resize(platform_name.size());
|
| 914 |
+
|
| 915 |
+
std::transform(
|
| 916 |
+
platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);
|
| 917 |
+
|
| 918 |
+
if (platform_name_low_case.find(filter) == std::string::npos) {
|
| 919 |
+
// cout << "platform [" << platform_name
|
| 920 |
+
// << "] does not match with requested "
|
| 921 |
+
// << filter << ", skipping\n";
|
| 922 |
+
continue;
|
| 923 |
+
}
|
| 924 |
+
|
| 925 |
+
result = platform_name;
|
| 926 |
+
}
|
| 927 |
+
|
| 928 |
+
if (result.empty())
|
| 929 |
+
throw std::runtime_error("can not find preferred GPU platform");
|
| 930 |
+
|
| 931 |
+
return result;
|
| 932 |
}
|
| 933 |
|
| 934 |
template <class DeviceSelector>
|
|
|
|
| 995 |
// Keep track of the number of devices per backend
|
| 996 |
std::map<sycl::backend, size_t> DeviceNums;
|
| 997 |
std::map<std::string, std::vector<sycl::device>> backend_devices;
|
| 998 |
+
auto preferred_platform_name = get_preferred_gpu_platform_name();
|
| 999 |
|
| 1000 |
while (!Platforms.empty()) {
|
| 1001 |
auto Platform = Platforms.back();
|
| 1002 |
Platforms.pop_back();
|
| 1003 |
+
auto platform_name = Platform.get_info<sycl::info::platform::name>();
|
| 1004 |
+
if (platform_name.compare(preferred_platform_name) != 0) {
|
| 1005 |
+
continue;
|
| 1006 |
+
}
|
| 1007 |
auto devices = Platform.get_devices();
|
| 1008 |
std::string backend_type = get_device_backend_and_type(devices[0]);
|
| 1009 |
for (const auto &device : devices) {
|
|
|
|
| 2059 |
return dev_mgr::instance().current_device();
|
| 2060 |
}
|
| 2061 |
|
| 2062 |
+
static inline device_ext &get_device(unsigned int id)
|
| 2063 |
+
{
|
| 2064 |
+
return dev_mgr::instance().get_device(id);
|
| 2065 |
+
}
|
| 2066 |
+
|
| 2067 |
static inline sycl::queue &get_in_order_queue()
|
| 2068 |
{
|
| 2069 |
return dev_mgr::instance().current_device().in_order_queue();
|