Diego Devesa, JohannesGaessler committed on
Commit 69ae50d · 1 Parent(s): 0a14325

cuda : optimize argmax (llama/10441)


* cuda : optimize argmax

* remove unused parameter

ggml-ci

* fixup : use full warps

ggml-ci

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <[email protected]>

* fix ub

* ggml : check ne00 <= INT32_MAX in argmax and argsort

---------

Co-authored-by: Johannes Gäßler <[email protected]>

ggml/src/ggml-cuda/argmax.cu CHANGED
@@ -1,57 +1,69 @@
-#include "common.cuh"
-#include "argmax.cuh"
-#include "sum.cuh"
-
-#include <cstdint>
-
-static __global__ void argmax_f32(
-    const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
-
-    int argmax_thread = 0;
-    const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;
-
-#pragma unroll
-    for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {
-        const int64_t row = row0 + row1;
-
-        if (row >= nrows) {
-            break;
-        }
-
-        float maxval = -FLT_MAX;
-        int   argmax = -1;
-
-        for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {
-            const float val        = x[row*ncols + col];
-            const int   bigger     = val > maxval;
-            const int   not_bigger = bigger ^ 0x00000001;
-
-            maxval = maxval*not_bigger + val*bigger;
-            argmax = argmax*not_bigger + col*bigger;
-        }
-
-#pragma unroll
-        for (int mask = 16; mask > 0; mask >>= 1) {
-            const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
-            const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
-            const int   bigger     = val > maxval;
-            const int   not_bigger = bigger ^ 0x00000001;
-
-            maxval = maxval*not_bigger + val*bigger;
-            argmax = argmax*not_bigger + col*bigger;
-        }
-
-        const int store = row1 == threadIdx.x;
-        argmax_thread += store*argmax;
-    }
-
-    const int row = row0 + threadIdx.x;
-
-    if (row >= nrows) {
-        return;
-    }
-
-    dst[row] = argmax_thread;
+#include <algorithm>
+#include <cstdint>
+
+#include "argmax.cuh"
+#include "common.cuh"
+#include "sum.cuh"
+
+static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
+    const int64_t row = blockIdx.x;
+
+    float maxval = -FLT_MAX;
+    int   argmax = -1;
+    const float * rowx = x + row * ncols;
+
+    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
+        const float val = rowx[col];
+        if (val > maxval) {
+            maxval = val;
+            argmax = col;
+        }
+    }
+
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+        if (val > maxval) {
+            maxval = val;
+            argmax = col;
+        }
+    }
+
+    const int n_warps = blockDim.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    if (n_warps > 1) {
+        constexpr int max_warps = 1024 / WARP_SIZE;
+        __shared__ float shared_maxval[max_warps];
+        __shared__ int   shared_argmax[max_warps];
+        if (lane_id == 0) {
+            shared_maxval[warp_id] = maxval;
+            shared_argmax[warp_id] = argmax;
+        }
+
+        __syncthreads();
+
+        if (warp_id == 0) {
+            if (lane_id < n_warps) {
+                maxval = shared_maxval[lane_id];
+                argmax = shared_argmax[lane_id];
+            }
+#pragma unroll
+            for (int offset = 16; offset > 0; offset >>= 1) {
+                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+                const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+                if (val > maxval) {
+                    maxval = val;
+                    argmax = col;
+                }
+            }
+        }
+    }
+
+    if (warp_id == 0 && lane_id == 0) {
+        dst[row] = argmax;
+    }
 }
 
 void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -70,10 +82,10 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     cudaStream_t stream = ctx.stream();
 
-    const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;
-
-    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const int64_t num_blocks = nrows;
+    const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
+    const dim3 blocks_dim(num_threads, 1, 1);
     const dim3 blocks_num(num_blocks, 1, 1);
 
-    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
+    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
 }
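
Note: the new launch configuration uses one block per row; the thread count is rounded up to a full warp and capped at 1024, and the column loop inside the kernel strides by blockDim.x. A minimal host-side sketch of that arithmetic, for illustration only (the launch_dims_for helper and the example sizes are not part of this commit):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Illustrative helper mirroring the launch math in ggml_cuda_argmax:
// one block per row, threads rounded up to a full warp, at most 1024.
static void launch_dims_for(int64_t ne00, int64_t nrows) {
    const int64_t WARP_SIZE   = 32;
    const int64_t num_blocks  = nrows;
    const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
    std::printf("ne00=%lld nrows=%lld -> blocks=%lld threads=%lld\n",
                (long long) ne00, (long long) nrows, (long long) num_blocks, (long long) num_threads);
}

int main() {
    launch_dims_for( 100, 4); // 128 threads: 100 columns round up to 4 full warps
    launch_dims_for(4096, 2); // capped at 1024 threads; each thread then visits 4 columns
    return 0;
}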
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -180,8 +180,8 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
     }
     return x;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
@@ -189,17 +189,17 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
 
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
     }
     return x;
 }
 
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
     }
     return a;
 }
@@ -209,16 +209,16 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
         reinterpret_cast<half&>(a.x) += __low2half(a_other);
         reinterpret_cast<half&>(a.y) += __high2half(a_other);
     }
     return a;
 #else
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
     }
     return a;
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
@@ -231,8 +231,8 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
     }
     return x;
 }
@@ -275,8 +275,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
     }
     return x;
 #else
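
Note: these hunks only rename the shuffle loop variable from mask to offset; the value is the XOR lane offset of a butterfly reduction and is halved each iteration, so behavior is unchanged. A minimal standalone sketch of the same pattern, for illustration only (the warp_sum_demo kernel is not part of ggml):

#include <cstdio>

// Illustrative kernel: each lane starts with its lane id; the butterfly
// reduction with halving XOR offsets leaves the warp-wide sum in every lane.
static __global__ void warp_sum_demo(int * out) {
    int x = threadIdx.x; // lane values 0..31
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    if (threadIdx.x == 0) {
        *out = x; // 0 + 1 + ... + 31 = 496
    }
}

int main() {
    int * out_d = nullptr;
    cudaMalloc(&out_d, sizeof(int));
    warp_sum_demo<<<1, 32>>>(out_d);
    int out_h = 0;
    cudaMemcpy(&out_h, out_d, sizeof(int), cudaMemcpyDeviceToHost);
    std::printf("warp sum = %d\n", out_h); // expected: 496
    cudaFree(out_d);
    return 0;
}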
ggml/src/ggml-cuda/quantize.cu CHANGED
@@ -69,8 +69,8 @@ static __global__ void quantize_mmq_q8_1(
 
     // Exchange max. abs. value between vals_per_scale/4 threads.
 #pragma unroll
-    for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+    for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
     }
 
     float sum;
@@ -79,8 +79,8 @@ static __global__ void quantize_mmq_q8_1(
 
         // Exchange calculate sum across vals_per_sum/4 threads.
 #pragma unroll
-        for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
-            sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+        for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
         }
     }
 
ggml/src/ggml.c CHANGED
@@ -2255,6 +2255,7 @@ struct ggml_tensor * ggml_argmax(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
     GGML_ASSERT(ggml_is_matrix(a));
+    GGML_ASSERT(a->ne[0] <= INT32_MAX);
 
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
 
@@ -4138,6 +4139,7 @@ struct ggml_tensor * ggml_argsort(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         enum ggml_sort_order order) {
+    GGML_ASSERT(a->ne[0] <= INT32_MAX);
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
 
     ggml_set_op_params_i32(result, 0, (int32_t) order);
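
Note: the new asserts guard the int32_t column indices used by the CUDA argmax and argsort kernels, and they fire when the op is built, before any backend runs. A minimal sketch of where the check applies, assuming only the core ggml.h API; the buffer size and tensor shape are arbitrary and graph evaluation is omitted:

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 8 rows of 1024 columns; ne[0] must satisfy the new ne[0] <= INT32_MAX check
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 8);

    // building the op is enough to run the shape checks added above
    struct ggml_tensor * idx = ggml_argmax(ctx, a); // I32 result with 8 elements, one index per row
    GGML_ASSERT(idx->ne[0] == 8);

    ggml_free(ctx);
    return 0;
}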