Spaces:
Sleeping
Sleeping
update HIP_UMA #7399 (llama/7414)
Browse files* update HIP_UMA #7399
add use of hipMemAdviseSetCoarseGrain when LLAMA_HIP_UMA is enable.
- get x2 on prompte eval and x1.5 on token gen with rocm6.0 on ryzen 7940HX iGPU (780M/gfx1103)
* simplify code, more consistent style
---------
Co-authored-by: slaren <[email protected]>
- ggml-cuda.cu +17 -3
- ggml-cuda/common.cuh +0 -5
ggml-cuda.cu
CHANGED
|
@@ -119,6 +119,20 @@ int ggml_cuda_get_device() {
|
|
| 119 |
return id;
|
| 120 |
}
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
static ggml_cuda_device_info ggml_cuda_init() {
|
| 123 |
#ifdef __HIP_PLATFORM_AMD__
|
| 124 |
// Workaround for a rocBLAS bug when using multiple graphics cards:
|
|
@@ -271,7 +285,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
|
| 271 |
size_t look_ahead_size = (size_t) (1.05 * size);
|
| 272 |
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
|
| 273 |
ggml_cuda_set_device(device);
|
| 274 |
-
CUDA_CHECK(
|
| 275 |
*actual_size = look_ahead_size;
|
| 276 |
pool_size += look_ahead_size;
|
| 277 |
#ifdef DEBUG_CUDA_MALLOC
|
|
@@ -537,7 +551,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
|
|
| 537 |
size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
|
| 538 |
|
| 539 |
void * dev_ptr;
|
| 540 |
-
cudaError_t err =
|
| 541 |
if (err != cudaSuccess) {
|
| 542 |
// clear the error
|
| 543 |
cudaGetLastError();
|
|
@@ -798,7 +812,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
|
|
| 798 |
// currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
|
| 799 |
ggml_cuda_set_device(id);
|
| 800 |
char * buf;
|
| 801 |
-
CUDA_CHECK(
|
| 802 |
|
| 803 |
// set padding to 0 to avoid possible NaN values
|
| 804 |
if (size > original_size) {
|
|
|
|
| 119 |
return id;
|
| 120 |
}
|
| 121 |
|
| 122 |
+
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
| 123 |
+
ggml_cuda_set_device(device);
|
| 124 |
+
#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
|
| 125 |
+
auto res = hipMallocManaged(ptr, size);
|
| 126 |
+
if (res == hipSuccess) {
|
| 127 |
+
// if error we "need" to know why...
|
| 128 |
+
CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
|
| 129 |
+
}
|
| 130 |
+
return res;
|
| 131 |
+
#else
|
| 132 |
+
return cudaMalloc(ptr, size);
|
| 133 |
+
#endif
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
static ggml_cuda_device_info ggml_cuda_init() {
|
| 137 |
#ifdef __HIP_PLATFORM_AMD__
|
| 138 |
// Workaround for a rocBLAS bug when using multiple graphics cards:
|
|
|
|
| 285 |
size_t look_ahead_size = (size_t) (1.05 * size);
|
| 286 |
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
|
| 287 |
ggml_cuda_set_device(device);
|
| 288 |
+
CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
|
| 289 |
*actual_size = look_ahead_size;
|
| 290 |
pool_size += look_ahead_size;
|
| 291 |
#ifdef DEBUG_CUDA_MALLOC
|
|
|
|
| 551 |
size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
|
| 552 |
|
| 553 |
void * dev_ptr;
|
| 554 |
+
cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
|
| 555 |
if (err != cudaSuccess) {
|
| 556 |
// clear the error
|
| 557 |
cudaGetLastError();
|
|
|
|
| 812 |
// currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
|
| 813 |
ggml_cuda_set_device(id);
|
| 814 |
char * buf;
|
| 815 |
+
CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
|
| 816 |
|
| 817 |
// set padding to 0 to avoid possible NaN values
|
| 818 |
if (size > original_size) {
|
ggml-cuda/common.cuh
CHANGED
|
@@ -79,13 +79,8 @@
|
|
| 79 |
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
|
| 80 |
#define cudaHostUnregister hipHostUnregister
|
| 81 |
#define cudaLaunchHostFunc hipLaunchHostFunc
|
| 82 |
-
#ifdef GGML_HIP_UMA
|
| 83 |
-
#define cudaMalloc hipMallocManaged
|
| 84 |
-
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
|
| 85 |
-
#else
|
| 86 |
#define cudaMalloc hipMalloc
|
| 87 |
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
| 88 |
-
#endif
|
| 89 |
#define cudaMemcpy hipMemcpy
|
| 90 |
#define cudaMemcpyAsync hipMemcpyAsync
|
| 91 |
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
|
|
|
|
| 79 |
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
|
| 80 |
#define cudaHostUnregister hipHostUnregister
|
| 81 |
#define cudaLaunchHostFunc hipLaunchHostFunc
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
#define cudaMalloc hipMalloc
|
| 83 |
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
|
|
|
| 84 |
#define cudaMemcpy hipMemcpy
|
| 85 |
#define cudaMemcpyAsync hipMemcpyAsync
|
| 86 |
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
|