uvos committed on
Commit
b6dc6a1
·
1 Parent(s): 7cee55b

HIP: enable vec fattn on RDNA4 (llama/14323)

Browse files
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) {
241
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
242
  return false;
243
  #else
244
- return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
245
- GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 
 
 
 
 
 
 
 
 
 
246
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
247
  }
248
 
 
241
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
242
  return false;
243
  #else
244
+ if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
245
+ GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
246
+ return true;
247
+ } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
248
+ #if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
249
+ return true;
250
+ #else
251
+ return false;
252
+ #endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
253
+ } else {
254
+ return false;
255
+ }
256
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
257
  }
258
 
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -100,8 +100,7 @@ int ggml_cuda_get_device() {
100
  static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
101
  ggml_cuda_set_device(device);
102
  cudaError_t err;
103
- if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
104
- {
105
  err = cudaMallocManaged(ptr, size);
106
  #if defined(GGML_USE_HIP)
107
  if (err == hipSuccess) {
@@ -119,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
119
  err = cudaMalloc(ptr, size);
120
  }
121
  #endif // defined(GGML_USE_HIP)
122
- }
123
- else
124
- {
125
  err = cudaMalloc(ptr, size);
126
  }
127
  return err;
 
100
  static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
101
  ggml_cuda_set_device(device);
102
  cudaError_t err;
103
+ if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
 
104
  err = cudaMallocManaged(ptr, size);
105
  #if defined(GGML_USE_HIP)
106
  if (err == hipSuccess) {
 
118
  err = cudaMalloc(ptr, size);
119
  }
120
  #endif // defined(GGML_USE_HIP)
121
+ } else {
 
 
122
  err = cudaMalloc(ptr, size);
123
  }
124
  return err;