Commit d13b876 (parent: 9e9f2fe)

CUDA: fix non-cont. inputs for batched mat mul (llama/13155)

Files changed:
- ggml/src/ggml-cuda/convert.cu   (+41 -12)
- ggml/src/ggml-cuda/convert.cuh  (+11 -1)
- ggml/src/ggml-cuda/ggml-cuda.cu (+42 -28)
ggml/src/ggml-cuda/convert.cu (changed)

@@ -1,6 +1,8 @@
 #include "convert.cuh"
 #include "dequantize.cuh"
 
+#include <cstdint>
+
 #define CUDA_Q8_0_NE_ALIGN 2048
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>

@@ -570,30 +572,46 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
 }
 
 template <typename src_t, typename dst_t>
+static __global__ void convert_unary(
+        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t s01, const int64_t s02, const int64_t s03) {
+    const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
+    if (i00 >= ne00) {
         return;
     }
 
+    const int64_t i01 = blockIdx.y;
+    const int64_t i02 = blockIdx.z % ne02;
+    const int64_t i03 = blockIdx.z / ne02;
+
     const src_t * x = (const src_t *) vx;
 
+    const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
+    const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
+    y[iy] = float(x[ix]);
 }
 
 template <typename src_t, typename dst_t>
+static void convert_unary_cuda(const void * vx, dst_t * y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
+    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
+    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
+        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    convert_unary_cuda<src_t>(vx, y, k, 1, 1, 1, k, k, k, stream);
 }
 
 to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
+            return convert_unary_cont_cuda<float>;
         case GGML_TYPE_F16:
+            return convert_unary_cont_cuda<half>;
         default:
             return nullptr;
     }

@@ -643,9 +661,9 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F32:
+            return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:
+            return convert_unary_cont_cuda<nv_bfloat16>;
         default:
             return nullptr;
     }

@@ -692,7 +710,18 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
+            return convert_unary_cont_cuda<half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cont_cuda<nv_bfloat16>;
+        default:
+            return nullptr;
+    }
+}
+
+to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float>;
         case GGML_TYPE_BF16:
             return convert_unary_cuda<nv_bfloat16>;
         default:
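The core of the change in this file is the index math in the new convert_unary: the source is read through element strides s01/s02/s03 while the destination is written contiguously. The following CPU reference is an illustrative sketch only (convert_unary_ref is not part of the commit; its parameter names mirror the kernel's):

// Illustrative CPU reference for the index mapping used by the new convert_unary
// kernel (sketch, not commit code). x is read through element strides s01/s02/s03;
// y is written densely in (i03, i02, i01, i00) order.
template <typename src_t, typename dst_t>
static void convert_unary_ref(const src_t * x, dst_t * y,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t s01, const int64_t s02, const int64_t s03) {
    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            for (int64_t i01 = 0; i01 < ne01; i01++) {
                for (int64_t i00 = 0; i00 < ne00; i00++) {
                    const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; // strided read
                    const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; // dense write
                    y[iy] = float(x[ix]); // same conversion as the kernel
                }
            }
        }
    }
}

For a fully contiguous tensor of k elements this degenerates to convert_unary_ref(x, y, k, 1, 1, 1, k, k, k), which is exactly what convert_unary_cont_cuda forwards to on the device side.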
ggml/src/ggml-cuda/convert.cuh (changed)

@@ -3,7 +3,7 @@
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 
 template<typename T>
+using to_t_cuda_t = void (*)(const void * x, T * y, int64_t k, cudaStream_t stream);
 
 typedef to_t_cuda_t<float> to_fp32_cuda_t;
 typedef to_t_cuda_t<half> to_fp16_cuda_t;

@@ -14,3 +14,13 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
 to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
 
 to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
+
+// TODO more general support for non-contiguous inputs
+
+template<typename T>
+using to_t_nc_cuda_t = void (*)(const void * x, T * y,
+    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
+    int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
+
+typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
+to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
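A minimal usage sketch of the new non-contiguous entry point. The helper below is hypothetical (not part of the commit); it only mirrors the pattern ggml-cuda.cu follows in the next file, deriving element strides from ggml's byte strides:

// Hypothetical helper (sketch): convert a possibly non-contiguous src tensor
// into a contiguous fp16 buffer dst_f16 on the given stream.
static void convert_to_f16_nc(const ggml_tensor * src, half * dst_f16, cudaStream_t stream) {
    const to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(src->type);
    GGML_ASSERT(to_fp16 != nullptr);   // per this commit, only F32 and BF16 are handled

    const size_t ts = ggml_type_size(src->type);
    GGML_ASSERT(src->nb[0] == ts);     // the innermost dimension must still be contiguous

    // ggml stores byte strides in nb[]; the conversion kernels take element strides.
    const int64_t s01 = src->nb[1] / ts;
    const int64_t s02 = src->nb[2] / ts;
    const int64_t s03 = src->nb[3] / ts;

    to_fp16(src->data, dst_f16,
            src->ne[0], src->ne[1], src->ne[2], src->ne[3],
            s01, s02, s03, stream);
}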
ggml/src/ggml-cuda/ggml-cuda.cu (changed)

@@ -1720,15 +1720,15 @@ static __global__ void k_compute_batched_ptrs(
     size_t nb12, size_t nb13,
     size_t nbd2, size_t nbd3,
     int64_t r2, int64_t r3) {
+    const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
 
     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }
 
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
 
     ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
     ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;

@@ -1742,6 +1742,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
 
+    // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
+    // As long as dst is contiguous this does not matter though.
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t ne_dst = ggml_nelements(dst);

@@ -1750,21 +1754,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
+    const half * src0_f16 = (const half *) src0->data;
+    float * dst_ddf = (float *) dst->data;
 
+    const half * src1_f16 = (const half *) src1->data;
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    GGML_ASSERT(nb10 == ts_src1);
+    int64_t s11 = nb11 / ts_src1;
+    int64_t s12 = nb12 / ts_src1;
+    int64_t s13 = nb13 / ts_src1;
     ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
+
+    // convert src1 to fp16
     if (src1->type != GGML_TYPE_F16) {
+        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
         GGML_ASSERT(to_fp16_cuda != nullptr);
+
+        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+
+        src1_f16 = src1_f16_alloc.get();
+        s11 = ne10;
+        s12 = ne11*s11;
+        s13 = ne12*s12;
     }
 
     ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
     char * dst_t;

@@ -1824,13 +1838,13 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
             int i02 = i12 / r2;
 
             CUBLAS_CHECK(
+            cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F,   nb01/sizeof(half),
+                       src1_f16 + i13*s13 + i12*s12,                  CUDA_R_16F,   s11,
+                beta,  (      char *)  dst_t + i13*nbd3 + i12*nbd2,   cu_data_type, ne0,
+                cu_compute_type,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
         }
     }
 }

@@ -1841,15 +1855,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     CUBLAS_CHECK(
     cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
         ne01, ne11, ne10,
+        alpha, src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00, // strideA
+               src1_f16, CUDA_R_16F,   s11,       s12,       // strideB
+        beta,     dst_t, cu_data_type, ne0,       ne1*ne0,   // strideC
         ne12*ne13,
         cu_compute_type,
         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 } else {
     // use cublasGemmBatchedEx
+    const int64_t ne23 = ne12*ne13;
 
     ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
     ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);

@@ -1861,8 +1875,8 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         ne12, ne13,
         ne23,
         nb02, nb03,
+        src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
+        src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
         nbd2, nbd3,
         r2, r3);
     CUDA_CHECK(cudaGetLastError());

@@ -1871,8 +1885,8 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
         ne01, ne11, ne10,
         alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
+               (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
+        beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
         ne23,
         cu_compute_type,
         CUBLAS_GEMM_DEFAULT_TENSOR_OP));

@@ -1936,7 +1950,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
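The stride bookkeeping that the batched path above relies on can be summarized with a worked example. The shapes and byte strides below are assumed for illustration only; they are not taken from the commit:

// Assumed non-contiguous F32 src1 view: ne10=4096, ne11=32, ne12=8, ne13=1,
// with byte strides nb10=4, nb11=32768, nb12=1048576, nb13=8388608 (rows padded).
//   ts_src1 = ggml_type_size(GGML_TYPE_F32) = 4
//   s11 = nb11/ts_src1 =    8192   // element stride between rows (> ne10: padded view)
//   s12 = nb12/ts_src1 =  262144   // element stride between matrices
//   s13 = nb13/ts_src1 = 2097152   // element stride between batch slices
// After to_fp16_cuda(...) packs src1 into a dense fp16 copy, the strides collapse
// to the contiguous values used by the cuBLAS calls:
//   s11 = ne10     =    4096
//   s12 = ne11*s11 =  131072
//   s13 = ne12*s12 = 1048576
// s11 is then passed to cublasGemm*Ex as the leading dimension of B (ldb) and
// s12 as strideB; the per-matrix pointer offsets use i13*s13 + i12*s12.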