JohannesGaessler committed on
Commit 1136116 · Parent: c3e51a2

CUDA: fix --split-mode row for MMQ (llama/13323)

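In brief, the diffs below split what used to be a single ncols_y value in mmq_args into two fields, ncols_dst (the number of destination columns computed by a launch) and ncols_y, and derive the output tile counts and column bounds (ntx, col_high, col_diff, j_max) from ncols_dst, so that with --split-mode row, where the two counts can differ, the kernels launch the right grid. The following is a standalone sketch of that bookkeeping, not the actual ggml code; all names are illustrative analogues of the fields and formulas visible in the diff.

#include <cstdint>
#include <cstdio>

struct mmq_args_sketch {                 // hypothetical reduced analogue of mmq_args
    int64_t ncols_x;       // columns of the quantized matrix x
    int64_t nrows_x;       // rows of x handled by this device
    int64_t ncols_dst;     // dst columns computed by this launch (new field)
    int64_t stride_row_x;  // row stride of x
    int64_t ncols_y;       // columns of the quantized y data (may exceed ncols_dst)
    int64_t nrows_dst;     // rows of dst, used as the column stride of dst
};

// Number of output tiles along x: now computed from ncols_dst (previously ncols_y).
static int64_t num_tiles_x(const mmq_args_sketch & args, const int64_t mmq_x) {
    return (args.ncols_dst + mmq_x - 1) / mmq_x;
}

int main() {
    // Example: a launch that writes 5 dst columns while the y buffer holds 8 columns.
    const mmq_args_sketch args = {4096, 1024, 5, 4096, 8, 1024};
    std::printf("x tiles: %lld\n", (long long) num_tiles_x(args, /*mmq_x=*/8));
    return 0;
}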
ggml/src/ggml-cuda/mmq.cu CHANGED
@@ -128,7 +128,7 @@ void ggml_cuda_mul_mat_q(
 
     const mmq_args args = {
         src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
-        ne00, ne01, ne1, s01, s1,
+        ne00, ne01, ne1, s01, ne11, s1,
         ne02, ne12, s02, s12, s2,
         ne03, ne13, s03, s13, s3,
         use_stream_k};
@@ -212,7 +212,7 @@ void ggml_cuda_mul_mat_q(
     // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
     const mmq_args args = {
         src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
-        ne00, ne01, ne_get_rows, s01, s1,
+        ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
         ne02, ne02, s02, s12, s2,
         ne03, ne13, s03, s13, s3,
         use_stream_k};
@@ -251,7 +251,7 @@ void ggml_cuda_op_mul_mat_q(
         ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
     const mmq_args args = {
         src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
-        ne00, row_diff, src1_ncols, stride01, nrows_dst,
+        ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
         1, 1, 0, 0, 0,
         1, 1, 0, 0, 0,
         use_stream_k};
ggml/src/ggml-cuda/mmq.cuh CHANGED
@@ -2522,7 +2522,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
 static __device__ __forceinline__ void mul_mat_q_process_tile(
         const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
         const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
-        const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
+        const int nrows_x, const int stride_row_x, const int ncols_y, const int stride_col_dst,
         const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
 
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -2606,7 +2606,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 static __global__ void mul_mat_q(
         const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
         const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
-        const int ncols_x, const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
+        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
         const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 
@@ -2619,8 +2619,8 @@ static __global__ void mul_mat_q(
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
     constexpr int mmq_y = get_mmq_y_device();
 
-    const int ntx = (ncols_y + mmq_x - 1) / mmq_x; // Number of tiles x
-    const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
+    const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (nrows_x   + mmq_y - 1) / mmq_y; // Number of tiles y
 
     // Initialize the ids for writing back data with just the index.
     // For regular matrix multiplications this is never changed.
@@ -2648,8 +2648,8 @@ static __global__ void mul_mat_q(
 
         // Defaults for regular matrix multiplication:
         int col_low = 0;
-        int col_high = ncols_y;
-        int col_diff = ncols_y;
+        int col_high = ncols_dst;
+        int col_diff = ncols_dst;
         int offset_y = wt*stride_sample_y + zt*stride_channel_y;
         int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 
@@ -2689,7 +2689,7 @@ static __global__ void mul_mat_q(
 
         constexpr bool fixup = false;
         mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
             tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
         return;
     }
@@ -2720,8 +2720,8 @@ static __global__ void mul_mat_q(
 
         // Defaults for regular matrix multiplication:
         int col_low = 0;
-        int col_high = ncols_y;
-        int col_diff = ncols_y;
+        int col_high = ncols_dst;
+        int col_diff = ncols_dst;
         int offset_y = wt*stride_sample_y + zt*stride_channel_y;
        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 
@@ -2767,7 +2767,7 @@ static __global__ void mul_mat_q(
 
        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
             tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 
        kbc += blocks_per_ne00;
@@ -2792,8 +2792,8 @@ static __global__ void mul_mat_q(
 
    // Defaults for regular matrix multiplication:
    int col_low = 0;
-    int col_high = ncols_y;
-    int col_diff = ncols_y;
+    int col_high = ncols_dst;
+    int col_diff = ncols_dst;
    int offset_y = wt*stride_sample_y + zt*stride_channel_y;
    int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 
@@ -2834,7 +2834,7 @@ static __global__ void mul_mat_q(
 
    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-        (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+        (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
         tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 }
 
@@ -2842,7 +2842,7 @@ static __global__ void mul_mat_q(
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 static __global__ void mul_mat_q_stream_k_fixup(
        const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
-        const int ncols_x, const int nrows_x, const int ncols_y, const int stride_col_dst,
+        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
        const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
    constexpr int mmq_y = get_mmq_y_device();
    constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -2851,8 +2851,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
 
    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
-    const int ntx = (ncols_y + mmq_x - 1) / mmq_x;
-    const int nty = (nrows_x + mmq_y - 1) / mmq_y;
+    const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
+    const int nty = (nrows_x   + mmq_y - 1) / mmq_y;
 
    const int bidx0 = blockIdx.x;
 
@@ -2925,8 +2925,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
    const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
    dst += offset_dst;
 
-    const int i_max = nrows_x - it*mmq_y - 1;
-    const int j_max = ncols_y - jt*mmq_x - 1;
+    const int i_max = nrows_x   - it*mmq_y - 1;
+    const int j_max = ncols_dst - jt*mmq_x - 1;
 
 #pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -2989,7 +2989,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
 
 struct mmq_args {
    const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
-    int64_t ncols_x; int64_t nrows_x; int64_t ncols_y; int64_t stride_row_x; int64_t nrows_dst;
+    int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
    int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
    int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
    bool use_stream_k;
@@ -3025,8 +3025,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
    }
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
 
-    const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
-    const int ntx = (args.ncols_y + mmq_x - 1) / mmq_x;
+    const int nty = (args.nrows_x   + mmq_y - 1) / mmq_y;
+    const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
    const int ntzw = args.nchannels_y * args.nsamples_y;
    const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
 
@@ -3040,14 +3040,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
        constexpr bool need_check = false;
        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
-             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
    } else {
        constexpr bool need_check = true;
        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
-             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
    }
@@ -3068,7 +3068,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
-             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
 
@@ -3077,14 +3077,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
        }
 
        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
-            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
+            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
    } else {
        constexpr bool need_check = true;
 
        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
-             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
 
@@ -3093,7 +3093,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
        }
 
        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
-            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
+            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
    }
 }