uvos authored, JohannesGaessler committed
Commit 1f75790 · 1 Parent(s): 66d5b20

CUDA/HIP: refactor mmvq to unify the calculation of nwarps and rows per block between host and device code. (llama/12177)


refactor mmvq to unify the calculation of nwarps and rows per block between host and device code.

---------

Co-authored-by: Johannes Gäßler <[email protected]>
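
In short, the duplicated launch-geometry logic (one copy in preprocessor branches inside the kernel, another in the host-side switch) is replaced by constexpr __host__ __device__ helpers (calc_nwarps, calc_rows_per_block) that feed both __launch_bounds__ and the host launch configuration. A minimal sketch of that pattern, with simplified names and a hard-coded warp size purely for illustration (example_nwarps, example_kernel, and example_launch are not part of the commit):

    // Illustrative sketch only; the real helpers in mmvq.cu are calc_nwarps()
    // and calc_rows_per_block(), keyed by an mmvq_parameter_table_id.
    #define EXAMPLE_WARP_SIZE 32

    // Single source of truth: usable as a compile-time constant on the device
    // and as a runtime value on the host.
    static constexpr __host__ __device__ int example_nwarps(int ncols_y) {
        return ncols_y <= 4 ? 4 : 2;
    }

    template <int ncols_y>
    __launch_bounds__(example_nwarps(ncols_y)*EXAMPLE_WARP_SIZE, 1) // device side: compile-time
    static __global__ void example_kernel(float * dst) {
        if (threadIdx.x == 0 && threadIdx.y == 0) {
            dst[blockIdx.x] = 0.0f;
        }
    }

    static void example_launch(float * dst, cudaStream_t stream) {
        const dim3 block_dims(EXAMPLE_WARP_SIZE, example_nwarps(4), 1); // host side: same table
        example_kernel<4><<<1, block_dims, 0, stream>>>(dst);
    }

Because host and device evaluate the same function, the block shape passed at launch can no longer drift out of sync with the bound declared on the kernel.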

ggml/src/ggml-cuda/common.cuh CHANGED
@@ -395,11 +395,11 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half

 static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
+#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
 #elif defined(RDNA3)
     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
-#elif defined(__gfx1010__) || defined(__gfx900__)
+#elif defined(RDNA1) || defined(__gfx900__)
     int tmp1;
     int tmp2;
     asm("\n \
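
For readers of the #if ladder above: ggml_cuda_dp4a(a, b, c) accumulates into c the dot product of the four signed 8-bit lanes packed into a and b; the per-architecture branches only select the fastest instruction for that operation, and this hunk reorganizes them around the CDNA/RDNA1/RDNA2 family macros. A plain scalar equivalent (illustration only, not the fallback actually used in common.cuh; int8_t from <cstdint>):

    // Scalar reference for what the dp4a intrinsics above compute.
    static __host__ __device__ int dp4a_reference(const int a, const int b, int c) {
        const int8_t * a8 = (const int8_t *) &a; // four packed signed 8-bit lanes
        const int8_t * b8 = (const int8_t *) &b;
        for (int i = 0; i < 4; ++i) {
            c += a8[i] * b8[i];
        }
        return c;
    }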
ggml/src/ggml-cuda/mmvq.cu CHANGED
@@ -47,11 +47,89 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         1;
 }

+enum mmvq_parameter_table_id {
+    MMVQ_PARAMETERS_GENERIC = 0,
+    MMVQ_PARAMETERS_GCN,
+    MMVQ_PARAMETERS_RDNA2
+};
+
+static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
+#if defined(RDNA2) || defined(RDNA3)
+    return MMVQ_PARAMETERS_RDNA2;
+#elif defined(GCN) || defined(CDNA)
+    return MMVQ_PARAMETERS_GCN;
+#else
+    return MMVQ_PARAMETERS_GENERIC;
+#endif
+}
+
+static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return MMVQ_PARAMETERS_RDNA2;
+    }
+    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+        return MMVQ_PARAMETERS_GCN;
+    }
+    return MMVQ_PARAMETERS_GENERIC;
+}
+
+static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 4;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else if (table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 2;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1:
+                return 1;
+            case 2:
+            case 3:
+            case 4:
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -59,24 +137,20 @@ static __global__ void mul_mat_vec_q(
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
+    constexpr mmvq_parameter_table_id table_id = get_device_table_id();
+    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);

-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
-    constexpr int nwarps              = 1;
-    constexpr int rows_per_cuda_block = 1;
-#else
-    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
-    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
-
-    const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const     int tid = warp_size*threadIdx.y + threadIdx.x;
     const     int row0 = rows_per_cuda_block*blockIdx.x;
     const     int blocks_per_row_x = ncols_x / qk;
     const     int blocks_per_col_y = nrows_y / QK8_1;
-    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;

-    // partial sum for each thread
+    // partial sum for each thread
     float tmp[ncols_y][rows_per_cuda_block] = {0.0f};

     const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q(
         }
     }

-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -120,7 +194,7 @@ static __global__ void mul_mat_vec_q(
             for (int l = 0; l < nwarps-1; ++l) {
                 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
             }
-            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
         }

         if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
@@ -129,6 +203,13 @@ static __global__ void mul_mat_vec_q(
     }
 }

+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+    return {block_nums, block_dims};
+}
+
 template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
@@ -137,65 +218,67 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);

-    int id = ggml_cuda_get_device();
-
-    int64_t nwarps = 1;
-    int64_t rows_per_cuda_block = 1;
-
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
-        switch(ncols_y) {
-            case 1:
-                nwarps = 4;
-                rows_per_cuda_block = 1;
-                break;
-            case 2:
-            case 3:
-            case 4:
-                nwarps = 4;
-                rows_per_cuda_block = 2;
-                break;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                nwarps = 2;
-                rows_per_cuda_block = 2;
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-    }
-
-    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);

     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 2:
-            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
+        }
         case 3:
-            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 4:
-            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 5:
-            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 6:
-            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 7:
-            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 8:
-            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         default:
             GGML_ABORT("fatal error");
             break;
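
Because calc_nwarps() and calc_rows_per_block() are constexpr and shared by host and device code, the chosen parameters could additionally be pinned down at compile time. A hypothetical check (not part of this commit) might look like:

    // Hypothetical sanity checks: host and device now evaluate the same table.
    static_assert(calc_nwarps(1, MMVQ_PARAMETERS_GENERIC)         == 4, "generic: 4 warps for ncols_y == 1");
    static_assert(calc_rows_per_block(1, MMVQ_PARAMETERS_GENERIC) == 1, "generic: 1 row per block for ncols_y == 1");
    static_assert(calc_nwarps(8, MMVQ_PARAMETERS_RDNA2)           == 1, "RDNA2/RDNA3: single warp");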