Engininja2 committed (unverified)
Commit 753b30d · 1 Parent(s): 72e8610

cuda : replace remaining shfl_xor with calls to warp_reduce functions (llama/5744)
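The change consolidates the remaining open-coded warp shuffles onto the existing warp_reduce_sum / warp_reduce_max helpers. For reference, the float overload already defined in ggml-cuda.cu follows this butterfly pattern (a minimal sketch; see the file for the exact definition):

    // each XOR step halves the lane distance, so after log2(32) = 5
    // exchanges every lane of the warp holds the full sum
    static __device__ __forceinline__ float warp_reduce_sum(float x) {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            x += __shfl_xor_sync(0xffffffff, x, mask, 32);
        }
        return x;
    }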

Files changed (1):
  1. ggml-cuda.cu (+24 -49)
ggml-cuda.cu CHANGED
@@ -696,18 +696,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
-//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//#pragma unroll
-//    for (int mask = 16; mask > 0; mask >>= 1) {
-//        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-//    }
-//    return a;
-//#else
-//    (void) a;
-//    NO_DEVICE_CODE;
-//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//}
+#ifdef GGML_CUDA_F16
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    (void) a;
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+#endif // GGML_CUDA_F16
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
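Note: this hunk also revives the previously commented-out half2 overload, gating it on GGML_CUDA_F16 and on the CC_PASCAL arch check, since __hadd2 needs hardware fp16 arithmetic. A hypothetical usage sketch (finalize_half2_sum is an invented name for illustration, not part of the patch):

    #ifdef GGML_CUDA_F16
    // reduce per-thread half2 partials across the warp, then fold the
    // low and high halves into a single float
    static __device__ float finalize_half2_sum(half2 partial) {
        partial = warp_reduce_sum(partial);
        return __low2float(partial) + __high2float(partial);
    }
    #endif // GGML_CUDA_F16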
@@ -2521,10 +2523,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2625,10 +2624,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2761,10 +2757,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
         dst[row] = tmp;
@@ -2877,10 +2870,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2987,10 +2977,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
         dst[row] = tmp;
@@ -3025,11 +3012,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     float amax = fabsf(xi);
     float sum = xi;
 
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
-        sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
-    }
+    amax = warp_reduce_max(amax);
+    sum = warp_reduce_sum(sum);
 
     const float d = amax / 127;
     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
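quantize_q8_1 needs two warp-wide reductions: the maximum of |x|, which sets the scale d, and the sum of the block's values. The max helper mirrors warp_reduce_sum with fmaxf in place of addition, matching the loop removed above (a sketch of the float warp_reduce_max pattern, whose signature appears in the context of the first hunk):

    static __device__ __forceinline__ float warp_reduce_max(float x) {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
        }
        return x;
    }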
@@ -6222,10 +6206,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
 #ifdef GGML_CUDA_F16
@@ -6275,10 +6256,7 @@ static __global__ void mul_mat_p021_f16_f32(
     const int idst = channel*nrows_dst + row_dst;
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[idst] = tmp;
@@ -6321,10 +6299,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[idst] = tmp;
 