Engininja2 committed
cuda : replace remaining shfl_xor with calls to warp_reduce functions (llama/5744)
ggml-cuda.cu +24 -49
ggml-cuda.cu
CHANGED
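The change below replaces the remaining hand-rolled __shfl_xor_sync reduction loops in ggml-cuda.cu with calls to the existing warp_reduce_sum / warp_reduce_max helpers, and re-enables the previously commented-out half2 overload of warp_reduce_sum behind GGML_CUDA_F16. For reference, the float overload that the kernels below now call follows this pattern (a sketch reconstructed from the loops being removed, not part of this diff; the in-tree definition may differ in minor details):

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        // XOR butterfly: each lane adds the value held by the lane
        // whose ID differs from its own in exactly one bit.
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}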
@@ -696,18 +696,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
-//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//#pragma unroll
-//    for (int mask = 16; mask > 0; mask >>= 1) {
-//        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-//    }
-//    return a;
-//#else
-//    (void) a;
-//    NO_DEVICE_CODE;
-//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//}
+#ifdef GGML_CUDA_F16
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    (void) a;
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+#endif // GGML_CUDA_F16
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
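Note on the half2 overload above: __hadd2 adds both fp16 halves of a half2 in one instruction, so the same butterfly reduces two fp16 sums at once. The XOR pattern (masks 16, 8, 4, 2, 1) leaves every lane of the warp holding the full sum, not just lane 0. A 4-lane miniature (masks 2, 1) with initial values a, b, c, d in lanes 0-3:

    after mask=2: lanes hold a+c, b+d, a+c, b+d
    after mask=1: every lane holds a+b+c+d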
@@ -2521,10 +2523,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2625,10 +2624,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2761,10 +2757,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
         dst[row] = tmp;
@@ -2877,10 +2870,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2987,10 +2977,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
         dst[row] = tmp;
@@ -3025,11 +3012,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     float amax = fabsf(xi);
     float sum = xi;
 
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
-        sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
-    }
+    amax = warp_reduce_max(amax);
+    sum = warp_reduce_sum(sum);
 
     const float d = amax / 127;
     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
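quantize_q8_1 needs two warp-wide results: the maximum of |x| to derive the scale d = amax / 127, and the sum of the block's values, which later fused dot products use. The fused loop is split into the two dedicated helpers. warp_reduce_max follows the same butterfly, combining with fmaxf instead of + (sketch reconstructed from the loop removed above, not part of this diff):

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        // same XOR butterfly exchange, combining with max instead of sum
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}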
@@ -6222,10 +6206,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
 #ifdef GGML_CUDA_F16
@@ -6275,10 +6256,7 @@ static __global__ void mul_mat_p021_f16_f32(
     const int idst = channel*nrows_dst + row_dst;
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[idst] = tmp;
@@ -6321,10 +6299,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[idst] = tmp;
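Taken together, every kernel touched here now ends its per-row reduction the same way. A minimal usage sketch (hypothetical kernel, not from this diff) showing the shape the dequantize_mul_mat_vec kernels share: one warp per row, strided partial sums per lane, one warp_reduce_sum call, and a single store from lane 0. Since the butterfly leaves the result in every lane, the tid == 0 guard only avoids redundant stores.

// Hypothetical example kernel, assuming the warp_reduce_sum helper
// from ggml-cuda.cu is in scope. One block of 32 threads per row.
static __global__ void row_dot(const float * __restrict__ x, const float * __restrict__ y,
                               float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;   // one warp per row
    const int tid = threadIdx.x;  // lane ID, 0..31

    // each lane accumulates a strided slice of the row's dot product
    float tmp = 0.0f;
    for (int col = tid; col < ncols; col += 32) {
        tmp += x[row*ncols + col] * y[col];
    }

    // sum up partial sums and write back result
    tmp = warp_reduce_sum(tmp);

    if (tid == 0) {
        dst[row] = tmp;
    }
}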