Commit 69ae50d · Parent(s): 0a14325
cuda : optimize argmax (llama/10441)
* cuda : optimize argmax
* remove unused parameter
ggml-ci
* fixup : use full warps
ggml-ci
* Apply suggestions from code review
Co-authored-by: Johannes Gäßler <[email protected]>
* fix ub
* ggml : check ne00 <= INT32_MAX in argmax and argsort
---------
Co-authored-by: Johannes Gäßler <[email protected]>
- ggml/src/ggml-cuda/argmax.cu +54 -42
- ggml/src/ggml-cuda/common.cuh +15 -15
- ggml/src/ggml-cuda/quantize.cu +4 -4
- ggml/src/ggml.c +2 -0
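
The core of the change is the launch geometry of the new argmax_f32 kernel below: one CUDA block per row, with the block size rounded up to whole warps ("use full warps") and capped at 1024 threads. A minimal host-side sketch of that rounding, assuming WARP_SIZE is 32; the helper name and the sample column counts are illustrative, not part of the commit:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Illustrative helper: round the per-row thread count up to a whole warp,
// capped at 1024 (the CUDA block-size limit), mirroring the num_threads
// computation in ggml_cuda_argmax below.
static int64_t argmax_block_size(int64_t ne00, int64_t warp_size = 32) {
    return std::min<int64_t>(1024, (ne00 + warp_size - 1) / warp_size * warp_size);
}

int main() {
    const int64_t samples[] = {1, 31, 32, 33, 1000, 4096};
    for (int64_t ne00 : samples) {
        // e.g. ne00 = 33 -> 64 threads (2 full warps); ne00 = 4096 -> capped at 1024
        std::printf("ne00 = %4lld -> block size %4lld\n",
                    (long long) ne00, (long long) argmax_block_size(ne00));
    }
    return 0;
}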
ggml/src/ggml-cuda/argmax.cu
CHANGED
@@ -1,57 +1,69 @@
#include <algorithm>
#include <cstdint>

#include "argmax.cuh"
#include "common.cuh"
#include "sum.cuh"

static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
    const int64_t row = blockIdx.x;

    float maxval = -FLT_MAX;
    int   argmax = -1;
    const float * rowx = x + row * ncols;

    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
        const float val = rowx[col];
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }

#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
        if (val > maxval) {
            maxval = val;
            argmax = col;
        }
    }

    const int n_warps = blockDim.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;
    if (n_warps > 1) {
        constexpr int max_warps = 1024 / WARP_SIZE;
        __shared__ float shared_maxval[max_warps];
        __shared__ int   shared_argmax[max_warps];
        if (lane_id == 0) {
            shared_maxval[warp_id] = maxval;
            shared_argmax[warp_id] = argmax;
        }

        __syncthreads();

        if (warp_id == 0) {
            if (lane_id < n_warps) {
                maxval = shared_maxval[lane_id];
                argmax = shared_argmax[lane_id];
            }
#pragma unroll
            for (int offset = 16; offset > 0; offset >>= 1) {
                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
                const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
                if (val > maxval) {
                    maxval = val;
                    argmax = col;
                }
            }
        }
    }

    if (warp_id == 0 && lane_id == 0) {
        dst[row] = argmax;
    }
}

void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

@@ -70,10 +82,10 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

    cudaStream_t stream = ctx.stream();

    const int64_t num_blocks = nrows;
    const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
    const dim3 blocks_dim(num_threads, 1, 1);
    const dim3 blocks_num(num_blocks, 1, 1);

    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
}
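For reference, the kernel computes the usual row-wise argmax: for each row of a contiguous float matrix it writes the int32 column index of the largest value. A minimal CPU sketch of that contract, assuming row-major contiguous data; the helper name and the lowest-index tie-breaking are illustrative assumptions rather than guarantees of the CUDA kernel:

#include <cstdint>
#include <cfloat>

// Hypothetical host-side reference: argmax of each row of a row-major nrows x ncols float matrix.
// Ties resolve to the lowest column index here; the CUDA kernel's tie-breaking may differ.
static void argmax_rows_ref(const float * x, int32_t * dst, int64_t ncols, int64_t nrows) {
    for (int64_t row = 0; row < nrows; ++row) {
        const float * rowx = x + row * ncols;
        float   maxval = -FLT_MAX;
        int32_t argmax = -1;
        for (int64_t col = 0; col < ncols; ++col) {
            if (rowx[col] > maxval) {
                maxval = rowx[col];
                argmax = (int32_t) col;
            }
        }
        dst[row] = argmax;
    }
}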
ggml/src/ggml-cuda/common.cuh
CHANGED
@@ -180,8 +180,8 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
    return __reduce_add_sync(0xffffffff, x);
#else
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE

@@ -189,17 +189,17 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}

static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
    }
    return a;
}

@@ -209,16 +209,16 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
        reinterpret_cast<half&>(a.x) += __low2half(a_other);
        reinterpret_cast<half&>(a.y) += __high2half(a_other);
    }
    return a;
#else
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
    }
    return a;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)

@@ -231,8 +231,8 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
}

@@ -275,8 +275,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
#else
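All of the warp_reduce_* helpers above use the same XOR-butterfly pattern: at offsets 16, 8, 4, 2, 1 each lane combines its value with the lane whose index differs in exactly one bit, so after log2(32) = 5 steps every lane holds the reduction of the whole warp. A small host-side simulation of that pattern, with plain C++ arrays standing in for __shfl_xor_sync and 32 lanes assumed:

#include <array>
#include <cassert>

// Simulate the shuffle-XOR butterfly used by warp_reduce_sum:
// after the loop, every "lane" holds the sum of all 32 inputs.
static std::array<float, 32> butterfly_sum(std::array<float, 32> lanes) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        std::array<float, 32> next = lanes;
        for (int lane = 0; lane < 32; ++lane) {
            next[lane] = lanes[lane] + lanes[lane ^ offset];  // partner lane differs in one bit
        }
        lanes = next;
    }
    return lanes;
}

int main() {
    std::array<float, 32> v{};
    for (int i = 0; i < 32; ++i) v[i] = (float) i;   // 0 + 1 + ... + 31 = 496
    const std::array<float, 32> r = butterfly_sum(v);
    for (float x : r) assert(x == 496.0f);           // every lane ends up with the full sum
    return 0;
}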
ggml/src/ggml-cuda/quantize.cu
CHANGED
@@ -69,8 +69,8 @@ static __global__ void quantize_mmq_q8_1(

    // Exchange max. abs. value between vals_per_scale/4 threads.
#pragma unroll
    for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
    }

    float sum;

@@ -79,8 +79,8 @@ static __global__ void quantize_mmq_q8_1(

        // Exchange calculate sum across vals_per_sum/4 threads.
#pragma unroll
        for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
            sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
        }
    }

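The starting offsets here follow from the comments above the loops: exchanging a value among a group of vals_per_scale/4 (or vals_per_sum/4) lanes takes a butterfly with offsets vals_per_scale/8, ..., 1, i.e. log2 of the group size steps. A quick check of that arithmetic; the concrete vals_per_scale values are illustrative and not taken from the kernel's template parameters:

#include <cstdio>

int main() {
    // For a group of vals_per_scale/4 lanes, the butterfly offsets are vals_per_scale/8, ..., 1.
    const int samples[] = {16, 32, 64};  // illustrative values only
    for (int vals_per_scale : samples) {
        std::printf("vals_per_scale = %2d -> group of %2d lanes, offsets:",
                    vals_per_scale, vals_per_scale / 4);
        for (int offset = vals_per_scale / 8; offset > 0; offset >>= 1) {
            std::printf(" %d", offset);
        }
        std::printf("\n");
    }
    return 0;
}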
ggml/src/ggml.c
CHANGED
@@ -2255,6 +2255,7 @@ struct ggml_tensor * ggml_argmax(
        struct ggml_context * ctx,
        struct ggml_tensor * a) {
    GGML_ASSERT(ggml_is_matrix(a));
    GGML_ASSERT(a->ne[0] <= INT32_MAX);

    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);

@@ -4138,6 +4139,7 @@ struct ggml_tensor * ggml_argsort(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        enum ggml_sort_order order) {
    GGML_ASSERT(a->ne[0] <= INT32_MAX);
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);

    ggml_set_op_params_i32(result, 0, (int32_t) order);