JohannesGaessler committed
Commit 654d245 · 1 Parent(s): b10cbfd

CUDA: fixed tensor cores not being used on RDNA3 (llama/4697)

Files changed (1)
  1. ggml-cuda.cu +24 -23
ggml-cuda.cu CHANGED
@@ -119,10 +119,29 @@
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA 700
 #define CC_OFFSET_AMD 1000000
+#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
+#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
 
 #define GGML_CUDA_MAX_NODES 8192
 
+// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
+// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
+// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
+// - 7B quantum model: +100-200 MB
+// - 13B quantum model: +200-400 MB
+//
+//#define GGML_CUDA_FORCE_MMQ
+
+// TODO: improve this to be correct for more hardware
+// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
+#if !defined(GGML_CUDA_FORCE_MMQ)
+#define CUDA_USE_TENSOR_CORES
+#endif
+
+// max batch size to use MMQ kernels when tensor cores are available
+#define MMQ_MAX_BATCH_SIZE 32
+
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
@@ -189,23 +208,6 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 }
 #endif // defined(GGML_USE_HIPBLAS)
 
-// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
-// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
-// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
-// - 7B quantum model: +100-200 MB
-// - 13B quantum model: +200-400 MB
-//
-//#define GGML_CUDA_FORCE_MMQ
-
-// TODO: improve this to be correct for more hardware
-// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
-#if !defined(GGML_CUDA_FORCE_MMQ) && (!defined(GGML_USE_HIPBLAS) || defined(RDNA3))
-#define CUDA_USE_TENSOR_CORES
-#endif
-
-// max batch size to use MMQ kernels when tensor cores are available
-#define MMQ_MAX_BATCH_SIZE 32
-
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
@@ -8661,13 +8663,12 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     }
 
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    const bool fp16_performance_good = true;
 
-#ifdef RDNA3
-    const bool use_mul_mat_q = false;
-#else
-    const bool use_mul_mat_q = true;
-#endif // RDNA3
+    const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
+    bool use_mul_mat_q = ggml_is_quantized(src0->type);
+#ifdef CUDA_USE_TENSOR_CORES
+    use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
+#endif // CUDA_USE_TENSOR_CORES
 
 #else
 
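For readers tracing the logic of the fix: below is a small, self-contained C++ sketch (not part of the commit) of how the new checks behave. It assumes AMD compute capabilities are encoded as CC_OFFSET_AMD + 100*major + 10*minor, which is how the surrounding ggml-cuda.cu appears to store them; encode_amd_cc() and the sample device values are hypothetical, purely for illustration.

// Hypothetical, standalone illustration of the dispatch decision changed by this commit.
// CUDA_USE_TENSOR_CORES is defined here to mirror the new default (it is only left
// undefined when GGML_CUDA_FORCE_MMQ is set).
#define CUDA_USE_TENSOR_CORES

#include <cstdio>

constexpr int CC_OFFSET_AMD = 1000000;
constexpr int CC_RDNA1      = CC_OFFSET_AMD + 1010;
constexpr int CC_RDNA3      = CC_OFFSET_AMD + 1100;

// Assumed encoding of an AMD device's compute capability (illustrative helper).
constexpr int encode_amd_cc(int major, int minor) {
    return CC_OFFSET_AMD + 100 * major + 10 * minor;
}

int main() {
    // e.g. an RDNA3 part (gfx11xx reports 11.0) and an RDNA2 part (gfx103x reports 10.3)
    const int devices[] = { encode_amd_cc(11, 0), encode_amd_cc(10, 3) };

    for (const int cc : devices) {
        // Same shape as the new code in ggml_cuda_mul_mat: FP16 is treated as fast
        // from RDNA1 onward, and on RDNA3 MMQ is skipped so the hipBLAS/rocBLAS
        // path (which can use the tensor cores this commit is about) is taken.
        const bool fp16_performance_good = cc >= CC_RDNA1;
        bool use_mul_mat_q = true; // stands in for ggml_is_quantized(src0->type)
#ifdef CUDA_USE_TENSOR_CORES
        use_mul_mat_q = use_mul_mat_q && cc < CC_RDNA3;
#endif // CUDA_USE_TENSOR_CORES
        std::printf("cc=%d fp16_performance_good=%d use_mul_mat_q=%d\n",
                    cc, fp16_performance_good, use_mul_mat_q);
    }
    return 0;
}

With these assumed values, the RDNA3 device reports cc = 1001100, so use_mul_mat_q ends up false and matrix multiplication is routed through the FP16 BLAS path, which is the intent of the fix; the RDNA2 device (cc = 1001030) keeps use_mul_mat_q true, so the MMQ kernels remain eligible.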