Commit 1136116 · Parent: c3e51a2
CUDA: fix --split-mode row for MMQ (llama/13323)
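This change stops the MMQ code from conflating the number of dst columns with the number of y columns. mmq_args, mul_mat_q, and mul_mat_q_stream_k_fixup now carry a separate ncols_dst next to ncols_y, and the CUDA grid size (ntx) and the column bounds (col_high, col_diff, j_max) are derived from ncols_dst, while ncols_y keeps only its role as the width of the quantized y tensor. The two values can differ when only a slice of dst is computed, which is the --split-mode row case named in the title. Two illustrative sketches follow the per-file diffs below.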
Files changed:
- ggml/src/ggml-cuda/mmq.cu (+3 −3)
- ggml/src/ggml-cuda/mmq.cuh (+27 −27)
ggml/src/ggml-cuda/mmq.cu

@@ -128,7 +128,7 @@ void ggml_cuda_mul_mat_q(
 
         const mmq_args args = {
             src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
-            ne00, ne01, ne1, s01, s1,
+            ne00, ne01, ne1, s01, ne11, s1,
             ne02, ne12, s02, s12, s2,
             ne03, ne13, s03, s13, s3,
             use_stream_k};
@@ -212,7 +212,7 @@ void ggml_cuda_mul_mat_q(
         // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
         const mmq_args args = {
             src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
-            ne00, ne01, ne_get_rows, s01, s1,
+            ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
             ne02, ne02, s02, s12, s2,
             ne03, ne13, s03, s13, s3,
             use_stream_k};
@@ -251,7 +251,7 @@ void ggml_cuda_op_mul_mat_q(
         ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
     const mmq_args args = {
         src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
-        ne00, row_diff, src1_ncols, stride01, nrows_dst,
+        ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
         1, 1, 0, 0, 0,
         1, 1, 0, 0, 0,
         use_stream_k};
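As a minimal sketch of what the three call sites above are adapting to (hypothetical stand-in structs, not the real ggml definitions): the first row of size fields in mmq_args gains ncols_dst, and ncols_y becomes a real field, filled with ne11, ne_get_rows, and ne11 respectively.

```cpp
#include <cstdint>

// Old layout: ne1 / ne_get_rows / src1_ncols landed in "ncols_y", which
// doubled as the dst width. (Stand-in struct, first size row only.)
struct mmq_args_before {
    int64_t ncols_x, nrows_x, ncols_y, stride_row_x, nrows_dst;
};

// New layout: ncols_dst is the dst width, ncols_y is the y width.
struct mmq_args_after {
    int64_t ncols_x, nrows_x, ncols_dst, stride_row_x, ncols_y, nrows_dst;
};

int main() {
    // Illustrative values for the regular mul_mat path: ne00 = K,
    // ne01 = rows of x, ne1 = dst columns, ne11 = y columns.
    const int64_t ne00 = 4096, ne01 = 4096, ne1 = 32, ne11 = 32;
    const int64_t s01 = 4096, s1 = 4096;

    const mmq_args_after args = {
        ne00, ne01, /*ncols_dst=*/ne1, s01, /*ncols_y=*/ne11, s1};
    (void) args;
    return 0;
}
```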
ggml/src/ggml-cuda/mmq.cuh

@@ -2522,7 +2522,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
 static __device__ __forceinline__ void mul_mat_q_process_tile(
         const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
         const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
-        const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
+        const int nrows_x, const int stride_row_x, const int ncols_y, const int stride_col_dst,
         const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
 
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -2606,7 +2606,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 static __global__ void mul_mat_q(
         const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
         const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
-        const int ncols_x, const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
+        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
         const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 
@@ -2619,8 +2619,8 @@ static __global__ void mul_mat_q(
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
     constexpr int mmq_y = get_mmq_y_device();
 
-    const int ntx = (ncols_y + mmq_x - 1) / mmq_x; // Number of tiles x
-    const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
+    const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (nrows_x   + mmq_y - 1) / mmq_y; // Number of tiles y
 
     // Initialize the ids for writing back data with just the index.
     // For regular matrix multiplications this is never changed.
@@ -2648,8 +2648,8 @@ static __global__ void mul_mat_q(
 
     // Defaults for regular matrix multiplication:
     int col_low = 0;
-    int col_high = ncols_y;
-    int col_diff = ncols_y;
+    int col_high = ncols_dst;
+    int col_diff = ncols_dst;
     int offset_y = wt*stride_sample_y + zt*stride_channel_y;
     int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 
@@ -2689,7 +2689,7 @@ static __global__ void mul_mat_q(
 
         constexpr bool fixup = false;
         mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
             tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
         return;
     }
@@ -2720,8 +2720,8 @@ static __global__ void mul_mat_q(
 
     // Defaults for regular matrix multiplication:
     int col_low = 0;
-    int col_high = ncols_y;
-    int col_diff = ncols_y;
+    int col_high = ncols_dst;
+    int col_diff = ncols_dst;
     int offset_y = wt*stride_sample_y + zt*stride_channel_y;
     int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 
@@ -2767,7 +2767,7 @@ static __global__ void mul_mat_q(
 
         constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
             tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 
         kbc += blocks_per_ne00;
@@ -2792,8 +2792,8 @@ static __global__ void mul_mat_q(
 
     // Defaults for regular matrix multiplication:
     int col_low = 0;
-    int col_high = ncols_y;
-    int col_diff = ncols_y;
+    int col_high = ncols_dst;
+    int col_diff = ncols_dst;
     int offset_y = wt*stride_sample_y + zt*stride_channel_y;
     int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 
@@ -2834,7 +2834,7 @@ static __global__ void mul_mat_q(
 
         constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
         mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
             tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
     }
 
@@ -2842,7 +2842,7 @@ static __global__ void mul_mat_q(
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 static __global__ void mul_mat_q_stream_k_fixup(
         const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
-        const int ncols_x, const int nrows_x, const int ncols_y, const int stride_col_dst,
+        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
         const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
     constexpr int mmq_y = get_mmq_y_device();
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -2851,8 +2851,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
 
     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
-    const int ntx = (ncols_y + mmq_x - 1) / mmq_x;
-    const int nty = (nrows_x + mmq_y - 1) / mmq_y;
+    const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
+    const int nty = (nrows_x   + mmq_y - 1) / mmq_y;
 
     const int bidx0 = blockIdx.x;
 
@@ -2925,8 +2925,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
     const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
     dst += offset_dst;
 
-    const int i_max = nrows_x - it*mmq_y - 1;
-    const int j_max = ncols_y - jt*mmq_x - 1;
+    const int i_max = nrows_x   - it*mmq_y - 1;
+    const int j_max = ncols_dst - jt*mmq_x - 1;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -2989,7 +2989,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
 
 struct mmq_args {
     const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
-    int64_t ncols_x; int64_t nrows_x; int64_t ncols_y; int64_t stride_row_x; int64_t nrows_dst;
+    int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
     int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
     int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
     bool use_stream_k;
@@ -3025,8 +3025,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     }
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
 
-    const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
-    const int ntx = (args.ncols_y + mmq_x - 1) / mmq_x;
+    const int nty = (args.nrows_x   + mmq_y - 1) / mmq_y;
+    const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
     const int ntzw = args.nchannels_y * args.nsamples_y;
     const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
 
@@ -3040,14 +3040,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
         constexpr bool need_check = false;
         mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
-             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
     } else {
         constexpr bool need_check = true;
         mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
-             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
     }
@@ -3068,7 +3068,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
         mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
-             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
 
@@ -3077,14 +3077,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
         }
 
         mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
-            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
+            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
     } else {
         constexpr bool need_check = true;
 
         mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
-             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
            channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
 
@@ -3093,7 +3093,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
         }
 
         mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
-            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
+            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
     }
 }
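For the kernel side, a standalone sketch of the tile-count arithmetic that now runs on ncols_dst (illustrative values; ceil_div and the numbers are not from the commit):

```cpp
#include <cstdio>

// Ceiling division, written inline in the kernels as (n + d - 1) / d.
constexpr int ceil_div(int n, int d) { return (n + d - 1) / d; }

int main() {
    const int mmq_x = 8, mmq_y = 64; // illustrative tile sizes
    const int nrows_x   = 4096;      // rows of the (possibly row-split) x slice
    const int ncols_dst = 32;        // dst columns actually being computed
    const int ncols_y   = 512;       // total y columns; can be larger

    // Grid size as in mul_mat_q / launch_mul_mat_q after the fix:
    const int ntx = ceil_div(ncols_dst, mmq_x); // number of tiles in x
    const int nty = ceil_div(nrows_x,   mmq_y); // number of tiles in y
    printf("ntx=%d nty=%d (using ncols_y instead would give ntx=%d)\n",
           ntx, nty, ceil_div(ncols_y, mmq_x));
    return 0;
}
```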