Commit 2d1e6e7
1 Parent(s): b6dc6a1

CUDA: mul_mat_v support for batch sizes > 1 (llama/14262)

* CUDA: mul_mat_v support for batch sizes > 1
* use 64 bit math for initial offset calculation
- ggml/src/ggml-cuda/common.cuh +4 -0
- ggml/src/ggml-cuda/ggml-cuda.cu +11 -13
- ggml/src/ggml-cuda/mmv.cu +239 -87
- ggml/src/ggml-cuda/mmv.cuh +2 -3
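Note on the second commit-message bullet: after this change the kernel's block indices and strides are 32-bit ints, so a product such as sample_x*stride_sample_x is computed in 32 bits and can wrap for large tensors; the kernel therefore widens the first factor to int64_t before multiplying. A hedged, illustrative sketch of the failure mode and the fix (not the committed code; names are made up):

__global__ void offset_demo(const float * x, float * dst, const int stride_sample) {
    const int sample = blockIdx.z;
    // wrong: 32-bit multiply, wraps once sample*stride_sample exceeds INT_MAX
    // x += sample*stride_sample;
    // right, the pattern now used in mmv.cu: widen one operand before multiplying
    x += int64_t(sample)*stride_sample;
    dst[int64_t(sample)*gridDim.x + blockIdx.x] = x[blockIdx.x];
}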
ggml/src/ggml-cuda/common.cuh
CHANGED

@@ -262,6 +262,10 @@ static bool fp16_mma_hardware_available(const int cc) {
         GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }
 
+static bool bf16_mma_hardware_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE;
+}
+
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
 static bool new_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
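The new helper mirrors the existing fp16_mma_hardware_available: BF16 tensor-core MMA first shipped with Ampere, hence the cc >= GGML_CUDA_CC_AMPERE check. A hedged sketch of how a caller might use it (dispatch_bf16 is a hypothetical name, not part of the commit; the 3/8 cutoffs mirror the shape of the heuristic added in mmv.cu below):

static bool dispatch_bf16(const int cc, const int64_t batch_size) {
    if (bf16_mma_hardware_available(cc)) {
        // tensor cores win quickly as the batch grows; keep the custom kernel for thin batches
        return batch_size <= 3;
    }
    return batch_size <= 8; // without BF16 MMA the vector kernel stays competitive longer
}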
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED

@@ -1943,16 +1943,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
 
     bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
     bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
     bool any_gpus_with_slow_fp16 = false;
-    bool any_gpus_without_fp16_mma = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1963,16 +1961,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
                 continue;
             }
 
             const int cc = ggml_cuda_info().devices[id].cc;
             use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
-            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
         }
     } else {
         const int cc = ggml_cuda_info().devices[ctx.device].cc;
         use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
-        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
     }
 
     // debug helpers
@@ -1983,7 +1981,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
         ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
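With this change the per-batch-size cutoffs move out of the call site: the hard src1->ne[1] == 1 requirement and the MMV_MAX_ROWS constant are gone, and use_mul_mat_vec is instead AND-ed with ggml_cuda_should_use_mmv once per participating device, exactly as use_mul_mat_q is already gated by ggml_cuda_should_use_mmq. A hedged sketch of the pattern (illustrative names, not the committed code):

// Each candidate kernel starts from a coarse type check and is then AND-ed with a
// per-device heuristic, so the kernel is only chosen if every GPU agrees.
static bool pick_custom_kernel(const int * device_cc, int n_devices, int64_t batch_size,
                               bool types_ok, bool (*heuristic_ok)(int cc, int64_t n)) {
    bool use_custom = types_ok;
    for (int id = 0; id < n_devices; ++id) {
        use_custom = use_custom && heuristic_ok(device_cc[id], batch_size);
    }
    return use_custom; // otherwise fall back to cuBLAS / MMQ paths
}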
ggml/src/ggml-cuda/mmv.cu
CHANGED

@@ -2,25 +2,26 @@
 #include "common.cuh"
 #include "mmv.cuh"
 
-template <typename T, typename type_acc, int block_size>
+template <typename T, typename type_acc, int ncols_dst, int block_size>
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
-        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
-        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
-    const int64_t row         = blockIdx.x;
-    const int64_t channel_dst = blockIdx.y;
-    const int64_t channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
-    const int64_t channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
-    const int64_t sample_dst  = blockIdx.z;
-    const int64_t sample_x    = sample_dst / sample_ratio;
-    const int64_t sample_y    = sample_dst;
-    const int     tid         = threadIdx.x;
+        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+    const int row         = blockIdx.x;
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
+    const int tid         = threadIdx.x;
+
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-    x   += sample_x *stride_sample_x   + channel_x *stride_channel_x   + row*stride_row;
-    y   += sample_y *stride_sample_y   + channel_y *stride_channel_y;
-    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
+    x   += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
+    y   += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
@@ -34,81 +35,108 @@ static __global__ void mul_mat_vec(
         __syncthreads();
     }
 
-    float sumf = 0.0f;
+    float sumf[ncols_dst] = {0.0f};
 
     if constexpr (std::is_same<T, float>::value) {
         const float2 * x2 = (const float2 *) x;
 
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const float2 tmpx = x2[col2];
-            const float2 tmpy = y2[col2];
-            sumf += tmpx.x*tmpy.x;
-            sumf += tmpx.y*tmpy.y;
+
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += tmpx.x*tmpy.x;
+                sumf[j] += tmpx.y*tmpy.y;
+            }
         }
     } else if constexpr (std::is_same<T, half>::value) {
         const half2 * x2 = (const half2 *) x;
 
         if (std::is_same<type_acc, float>::value) {
-            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                 const float2 tmpx = __half22float2(x2[col2]);
-                const float2 tmpy = y2[col2];
-                sumf += tmpx.x * tmpy.x;
-                sumf += tmpx.y * tmpy.y;
+
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumf[j] += tmpx.x * tmpy.x;
+                    sumf[j] += tmpx.y * tmpy.y;
+                }
             }
         } else {
 #ifdef FP16_AVAILABLE
-            half2 sumh2 = make_half2(0.0f, 0.0f);
+            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
+
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const half2 tmpx = x2[col2];
 
-            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-                const float2 tmp = y2[col2];
-                sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
+                }
             }
 
-            sumf = __low2float(sumh2) + __high2float(sumh2);
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
+            }
#else
             NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
         }
     } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
         const int * x2 = (const int *) x;
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-            const int tmpx = x2[col2];
-            const float2 tmpy = y2[col2];
-            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
-            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int tmpx = x2[col2];
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+            }
         }
     } else {
         static_assert(std::is_same<T, void>::value, "unsupported type");
     }
 
-    sumf = warp_reduce_sum<warp_size>(sumf);
-
-    if (block_size > warp_size) {
-        buf_iw[tid/warp_size] = sumf;
-        __syncthreads();
-        if (tid >= warp_size) {
-            return;
+#pragma unroll
+    for (int j = 0; j < ncols_dst; ++j) {
+        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+
+        if (block_size > warp_size) {
+            buf_iw[tid/warp_size] = sumf[j];
+            __syncthreads();
+            if (tid < warp_size) {
+                sumf[j] = buf_iw[tid];
+                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+            }
+            if (j < ncols_dst) {
+                __syncthreads();
+            }
         }
-        sumf = buf_iw[tid];
-        sumf = warp_reduce_sum<warp_size>(sumf);
     }
 
-    if (tid != 0) {
+    if (tid >= ncols_dst) {
         return;
     }
 
-    dst[row] = sumf;
+    dst[tid*stride_col_dst + row] = sumf[tid];
 }
 
-template <typename T, typename type_acc>
+template <typename T, typename type_acc, int ncols_dst>
 static void launch_mul_mat_vec_cuda(
     const T * x, const float * y, const int32_t * ids, float * dst,
-    const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+    const int64_t ncols, const int64_t nrows,
+    const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+    const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
     const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
     const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
     cudaStream_t stream) {
-    GGML_ASSERT(ncols      % 2 == 0);
-    GGML_ASSERT(stride_row % 2 == 0);
+    GGML_ASSERT(ncols        % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
     GGML_ASSERT(       nsamples_dst % nsamples_x == 0);
     const int64_t channel_ratio = nchannels_dst / nchannels_x;
@@ -138,44 +166,52 @@ static void launch_mul_mat_vec_cuda(
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case 32: {
-            mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 64: {
-            mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 96: {
-            mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 128: {
-            mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 160: {
-            mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 192: {
-            mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 224: {
-            mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 256: {
-            mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
            GGML_ABORT("fatal error");
        } break;
    }
 }
 
@@ -183,23 +219,91 @@ static void launch_mul_mat_vec_cuda(
     }
 }
 
+template <typename T, typename type_acc>
+static void mul_mat_vec_cuda_switch_ncols_dst(
+    const T * x, const float * y, const int32_t * ids, float * dst,
+    const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+    const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+    const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+    const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+    const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+    cudaStream_t stream) {
+    switch (ncols_dst) {
+        case 1:
+            launch_mul_mat_vec_cuda<T, type_acc, 1>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 2:
+            launch_mul_mat_vec_cuda<T, type_acc, 2>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 3:
+            launch_mul_mat_vec_cuda<T, type_acc, 3>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 4:
+            launch_mul_mat_vec_cuda<T, type_acc, 4>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 5:
+            launch_mul_mat_vec_cuda<T, type_acc, 5>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 6:
+            launch_mul_mat_vec_cuda<T, type_acc, 6>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 7:
+            launch_mul_mat_vec_cuda<T, type_acc, 7>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 8:
+            launch_mul_mat_vec_cuda<T, type_acc, 8>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
 template<typename T>
 static void mul_mat_vec_cuda(
     const T * x, const float * y, const int32_t * ids, float * dst,
-    const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+    const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+    const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
+    const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
     const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
     const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
     enum ggml_prec prec, cudaStream_t stream) {
     if constexpr(std::is_same<T, half>::value) {
         if (prec == GGML_PREC_DEFAULT) {
-            launch_mul_mat_vec_cuda<T, half>
-                (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+            mul_mat_vec_cuda_switch_ncols_dst<T, half>
+                (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             return;
         }
     }
-    launch_mul_mat_vec_cuda<T, float>
-        (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+    mul_mat_vec_cuda_switch_ncols_dst<T, float>
+        (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
 }
 
@@ -246,24 +350,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int64_t stride_channel_dst = ids ? s1  : s2;
     const int64_t stride_channel_y   = ids ? s11 : s12;
 
-    GGML_ASSERT(ncols_dst == 1);
+    GGML_ASSERT(!ids || ncols_dst == 1);
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03, s13, s3, prec, ctx.stream());
         } break;
@@ -282,16 +386,19 @@ void ggml_cuda_op_mul_mat_vec(
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne0  = dst->ne[0];
     const int64_t row_diff = row_high - row_low;
 
-    GGML_ASSERT(src1_ncols == 1);
-
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
 
 
     // ggml_cuda_op provides single, contiguous matrices
     const int64_t stride_row = ne00;
+    const int64_t stride_col_y = ne10;
+    const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
     const int64_t nchannels_x = 1;
     const int64_t nchannels_y = 1;
     const int64_t nchannels_dst = 1;
@@ -307,19 +414,19 @@ void ggml_cuda_op_mul_mat_vec(
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
        } break;
@@ -334,3 +441,48 @@ void ggml_cuda_op_mul_mat_vec(
     GGML_UNUSED(src1_ncols);
     GGML_UNUSED(src1_padded_row_size);
 }
+
+bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % 2 != 0) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return ne11 <= 8;
+                }
+                if (cc >= GGML_CUDA_CC_TURING) {
+                    return ne11 <= 4;
+                }
+                return ne11 <= 3;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_F16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (fp16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_BF16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (bf16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        default:
+            return false;
+    }
+}
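The heart of the change is the new ncols_dst template parameter: instead of one dot product per block, each thread now keeps an array of ncols_dst partial sums, loads each element of x once and reuses it against up to eight columns of y, then reduces each accumulator across the block and lets thread j write destination column j. A simplified, self-contained sketch of that access pattern (hedged: single warp, float only, none of the channel/sample indexing or shared-memory reduction of the real kernel):

#include <cuda_runtime.h>

// One thread block (one warp of 32 threads) per row of x; each thread accumulates
// NCOLS partial dot products so every load of x is reused NCOLS times.
template <int NCOLS>
__global__ void mat_vec_batched(const float * x, const float * y, float * dst,
                                const int ncols, const int stride_col_y) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const float * xrow = x + int64_t(row)*ncols; // 64-bit offset, as in the commit

    float sumf[NCOLS] = {0.0f};
    for (int col = tid; col < ncols; col += blockDim.x) {
        const float xv = xrow[col];              // loaded once ...
#pragma unroll
        for (int j = 0; j < NCOLS; ++j) {
            sumf[j] += xv * y[j*stride_col_y + col]; // ... reused for every dst column
        }
    }

    // butterfly reduction within the warp, per destination column
#pragma unroll
    for (int j = 0; j < NCOLS; ++j) {
#pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            sumf[j] += __shfl_xor_sync(0xffffffff, sumf[j], offset);
        }
    }

    if (tid < NCOLS) {
        dst[tid*gridDim.x + row] = sumf[tid];    // thread j writes column j, as in the commit
    }
}

A possible launch, assuming y is ncols x 4 column-major and dst is 4 columns of nrows each: mat_vec_batched<4><<<nrows, 32>>>(x, y, dst, ncols, ncols);. The runtime batch size is mapped onto this compile-time parameter by mul_mat_vec_cuda_switch_ncols_dst above, one template instantiation per ncols_dst in 1..8.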
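The exported ggml_cuda_should_use_mmv encodes, per source type and compute capability, the largest batch size ne11 for which the custom kernel is expected to beat the cuBLAS/tensor-core path, plus the requirement that ne00 be even. A hedged usage sketch of how the dispatch in ggml-cuda.cu consults it (values are illustrative; src0_ne is the usual ggml ne[4] array):

const int64_t src0_ne[4] = {4096, 4096, 1, 1};   // hypothetical 4096x4096 F16 weight
const int     cc         = GGML_CUDA_CC_ADA_LOVELACE;
const int64_t ne11       = 4;                    // batch size (columns of src1)

if (ggml_cuda_should_use_mmv(GGML_TYPE_F16, cc, src0_ne, ne11)) {
    // taken here: ne02*ne03 == 1 makes src0_small true, and on Ada the
    // F16 branch accepts batches up to ne11 <= 4
}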
ggml/src/ggml-cuda/mmv.cuh
CHANGED

@@ -1,8 +1,5 @@
 #include "common.cuh"
 
-// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
-#define MMV_MAX_ROWS 512
-
 void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
 void ggml_cuda_op_mul_mat_vec(
@@ -10,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);