uvos committed on
Commit
bf49bbe
·
1 Parent(s): 20ee62d

Add some minimal optimizations for CDNA (llama/10498)

Browse files

* Add some minimal optimizations for CDNA

* ggml_cuda: set launch bounds also for GCN as it helps there too

ggml/src/ggml-cuda/common.cuh CHANGED
@@ -47,9 +47,20 @@
47
  #define CC_TURING 750
48
  #define CC_AMPERE 800
49
  #define CC_OFFSET_AMD 1000000
50
- #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
51
- #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
52
- #define CC_RDNA3 (CC_OFFSET_AMD + 1100)
 
 
 
 
 
 
 
 
 
 
 
53
  #define CC_QY1 210
54
  #define CC_QY2 220
55
 
 
47
  #define CC_TURING 750
48
  #define CC_AMPERE 800
49
  #define CC_OFFSET_AMD 1000000
50
+
51
+ // GCN/CDNA, wave size is 64
52
+ #define CC_GCN4 (CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
53
+ #define CC_VEGA (CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
54
+ #define CC_VEGA20 (CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
55
+ #define CC_CDNA (CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
56
+ #define CC_CDNA2 (CC_OFFSET_AMD + 910) // MI210, minimum acc register renaming
57
+ #define CC_CDNA3 (CC_OFFSET_AMD + 942) // MI300
58
+
59
+ // RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
60
+ #define CC_RDNA1 (CC_OFFSET_AMD + 1010) // RX 5000
61
+ #define CC_RDNA2 (CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
62
+ #define CC_RDNA3 (CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
63
+
64
  #define CC_QY1 210
65
  #define CC_QY2 220
66
 
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -1107,6 +1107,11 @@ static void ggml_cuda_op_mul_mat_cublas(
1107
  const half alpha_f16 = 1.0f;
1108
  const half beta_f16 = 0.0f;
1109
 
 
 
 
 
 
1110
  CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
1111
  CUBLAS_CHECK(
1112
  cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
@@ -1114,7 +1119,7 @@ static void ggml_cuda_op_mul_mat_cublas(
1114
  &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
1115
  src1_ptr, CUDA_R_16F, ne10,
1116
  &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
1117
- CUBLAS_COMPUTE_16F,
1118
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1119
 
1120
  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
@@ -1607,6 +1612,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1607
  cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
1608
  cudaDataType_t cu_data_type = CUDA_R_16F;
1609
 
 
 
 
 
1610
  // dst strides
1611
  size_t nbd2 = dst->nb[2];
1612
  size_t nbd3 = dst->nb[3];
 
1107
  const half alpha_f16 = 1.0f;
1108
  const half beta_f16 = 0.0f;
1109
 
1110
+ cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
1111
+ if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
1112
+ cu_compute_type = CUBLAS_COMPUTE_32F;
1113
+ }
1114
+
1115
  CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
1116
  CUBLAS_CHECK(
1117
  cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
 
1119
  &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
1120
  src1_ptr, CUDA_R_16F, ne10,
1121
  &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
1122
+ cu_compute_type,
1123
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1124
 
1125
  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
 
1612
  cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
1613
  cudaDataType_t cu_data_type = CUDA_R_16F;
1614
 
1615
+ if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
1616
+ cu_compute_type = CUBLAS_COMPUTE_32F;
1617
+ }
1618
+
1619
  // dst strides
1620
  size_t nbd2 = dst->nb[2];
1621
  size_t nbd3 = dst->nb[3];
ggml/src/ggml-cuda/mmq.cu CHANGED
@@ -148,5 +148,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
148
  return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
149
  }
150
 
151
- return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
152
  }
 
148
  return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
149
  }
150
 
151
+ return (cc < CC_RDNA3 && cc != CC_CDNA && cc != CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
152
  }
ggml/src/ggml-cuda/mmq.cuh CHANGED
@@ -2570,9 +2570,9 @@ static __device__ void mul_mat_q_process_tile(
2570
 
2571
  template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2572
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
2573
- #if defined(RDNA3) || defined(RDNA2)
2574
  __launch_bounds__(WARP_SIZE*nwarps, 2)
2575
- #endif // defined(RDNA3) || defined(RDNA2)
2576
  #else
2577
  #if __CUDA_ARCH__ >= CC_VOLTA
2578
  __launch_bounds__(WARP_SIZE*nwarps, 1)
 
2570
 
2571
  template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2572
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
2573
+ #if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
2574
  __launch_bounds__(WARP_SIZE*nwarps, 2)
2575
+ #endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
2576
  #else
2577
  #if __CUDA_ARCH__ >= CC_VOLTA
2578
  __launch_bounds__(WARP_SIZE*nwarps, 1)
ggml/src/ggml-cuda/mmvq.cu CHANGED
@@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
142
  int64_t nwarps = 1;
143
  int64_t rows_per_cuda_block = 1;
144
 
145
- if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
146
  switch(ncols_y) {
147
  case 1:
148
  nwarps = 4;
 
142
  int64_t nwarps = 1;
143
  int64_t rows_per_cuda_block = 1;
144
 
145
+ if (ggml_cuda_info().devices[id].cc < CC_CDNA || ggml_cuda_info().devices[id].cc == CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
146
  switch(ncols_y) {
147
  case 1:
148
  nwarps = 4;
ggml/src/ggml-cuda/vendors/hip.h CHANGED
@@ -95,6 +95,14 @@
95
 
96
  #define __CUDA_ARCH__ 1300
97
 
 
 
 
 
 
 
 
 
98
  #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
99
  defined(__gfx1150__) || defined(__gfx1151__)
100
  #define RDNA3
 
95
 
96
  #define __CUDA_ARCH__ 1300
97
 
98
+ #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
99
+ #define GCN
100
+ #endif
101
+
102
+ #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
103
+ #define CDNA
104
+ #endif
105
+
106
  #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
107
  defined(__gfx1150__) || defined(__gfx1151__)
108
  #define RDNA3