Commit 6201c74 · Parent: aca04d5
cuda: add q8_0->f32 cpy operation (llama/9571)
llama: enable K-shift for quantized KV cache
It will fail on unsupported backends or quant types.
- ggml/src/ggml-cuda.cu +3 -0
- ggml/src/ggml-cuda/cpy.cu +51 -0
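
For reference, the new cpy_blck_q8_0_f32 kernel below consumes one q8_0 block per call. This is the block layout it assumes, as defined in ggml's common headers (shown here for orientation; it is not part of this diff):

    #define QK8_0 32
    typedef struct {
        ggml_half d;         // fp16 scale (delta)
        int8_t    qs[QK8_0]; // 32 signed 8-bit quants
    } block_q8_0;            // dequantizes as value[j] = (float)d * qs[j]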
ggml/src/ggml-cuda.cu CHANGED

@@ -2899,6 +2899,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
         return true;
     }
+    if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+        return true;
+    }
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
         return true;
     }
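
The branch added above is what lets the K-shift fail cleanly elsewhere: a backend that cannot run the quantized copy simply reports the op as unsupported. A minimal sketch of probing this from user code, assuming a ggml context and backend already exist (the helper name can_dequant_cpy is hypothetical; ggml_backend_supports_op, ggml_new_tensor_1d, and ggml_cpy are real ggml APIs):

    // Hypothetical probe, not from this commit: ask a backend whether it can
    // copy (dequantize) a q8_0 tensor into f32, i.e. whether the branch
    // added above matches.
    static bool can_dequant_cpy(ggml_backend_t backend, struct ggml_context * ctx) {
        struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_Q8_0, 32);
        struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  32);
        struct ggml_tensor * op  = ggml_cpy(ctx, src, dst); // builds a GGML_OP_CPY node
        // Routes to the backend's supports_op hook, e.g. the CUDA one patched here.
        return ggml_backend_supports_op(backend, op);
    }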
ggml/src/ggml-cuda/cpy.cu CHANGED

@@ -81,6 +81,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
     }
 }
 
+static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
+    const block_q8_0 * xi = (const block_q8_0 *) cxi;
+    float * dsti = (float *) cdsti;
+
+    const float d = (float)xi->d;
+
+    for (int j = 0; j < QK8_0; j++) {
+        dsti[j] = xi->qs[j] * d;
+    }
+}
+
 static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     block_q4_0 * dsti = (block_q4_0 *) cdsti;
@@ -288,6 +299,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
 static void ggml_cpy_f16_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -329,6 +366,16 @@ static void ggml_cpy_f32_q8_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q8_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -437,6 +484,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -471,6 +520,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
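
A note on the launch shape in ggml_cpy_q8_0_f32_cuda: it launches ne single-thread blocks, so each kernel instance computes i = blockIdx.x * QK8_0 and only blocks with i < ne pass the early-out guard; each surviving thread dequantizes a full 32-value block. The sketch below shows the equivalent, tighter grid (an observation about the code above, not a change made by this commit):

    // Observation only: with blockDim.x == 1, just ne/QK8_0 of the ne launched
    // blocks survive the `i >= ne` guard. One block per q8_0 block covers the
    // same elements:
    const int num_blocks = (ne + QK8_0 - 1) / QK8_0;
    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
         ne10, ne11, ne12, nb10, nb11, nb12, nb13);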