Alan Gray · slaren
Commit a2fdbe6 · Parent(s): 7d5f3d4
Simplify and improve CUDA graphs through use of indirect copy pointers (llama/9017)
* CUDA: Simplify and improve CUDA graphs through use of indirect copy pointers
Previously there was complexity in the CUDA graphs implementation due to frequently changing parameters to the copy kernels associated with the K and V cache pointers. This patch simplifies the implementation by using indirection to keep those kernel parameters fixed, avoiding the need for frequent graph updates.
Fixes #12152
* Addressed comments
* fix HIP builds
* properly sync to stream
* removed ggml_cuda_cpy_fn_ptrs
* move stream sync before free
* guard to only use indirection with graphs
* style fixes
* check for errors
---------
Co-authored-by: slaren <[email protected]>
- ggml/src/ggml-cuda/common.cuh +7 -1
- ggml/src/ggml-cuda/cpy.cu +89 -51
- ggml/src/ggml-cuda/cpy.cuh +2 -0
- ggml/src/ggml-cuda/ggml-cuda.cu +24 -69
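
The core idea of the change: each copy kernel receives, in addition to its direct destination pointer, a device-resident table of destination pointers plus a fixed index into that table. When CUDA graphs are in use, the kernel reads its destination through the table, so the kernel arguments captured in the graph never change between tokens; only the table contents are refreshed with a single async host-to-device copy. Below is a minimal standalone sketch of that pattern, not the ggml code itself (which follows in the diffs); the names copy_kernel and refresh_dest_table, and the launch parameters, are illustrative assumptions.

// Minimal sketch of the indirect-copy-pointer pattern (illustrative names, not ggml code).
#include <cuda_runtime.h>

__global__ void copy_kernel(const char * src, char * dst_direct,
                            char ** dst_indirect, int slot, int n) {
    // With an indirection table, the real destination is fetched from device memory,
    // so the kernel argument list captured in a CUDA graph stays constant.
    char * dst = (dst_indirect != nullptr) ? dst_indirect[slot] : dst_direct;
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = src[i];
    }
}

// Per token, only the table contents change: one async upload replaces what previously
// required updating the kernel parameters of every copy node in the captured graph.
static void refresh_dest_table(char ** dest_table_d, char ** host_ptrs, int count, cudaStream_t stream) {
    cudaMemcpyAsync(dest_table_d, host_ptrs, count * sizeof(char *),
                    cudaMemcpyHostToDevice, stream);
}
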
ggml/src/ggml-cuda/common.cuh
CHANGED
@@ -729,7 +729,13 @@ struct ggml_cuda_graph {
     bool disable_due_to_failed_graph_capture = false;
     int number_consecutive_updates = 0;
     std::vector<ggml_graph_node_properties> ggml_graph_properties;
-    std::vector<char **> updated_kernel_arg;
+    bool use_cpy_indirection = false;
+    std::vector<char *> cpy_dest_ptrs;
+    char ** dest_ptrs_d;
+    int dest_ptrs_size = 0;
+    // Index to allow each cpy kernel to be aware of it's position within the graph
+    // relative to other cpy nodes.
+    int graph_cpynode_index = -1;
 #endif
 };
 
ggml/src/ggml-cuda/cpy.cu
CHANGED
@@ -32,16 +32,18 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
 }
 
 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13) {
+    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
     }
 
+    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+
     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
     const int64_t i03 = i/(ne00 * ne01 * ne02);
@@ -288,16 +290,18 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13) {
+    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
+    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@@ -314,16 +318,18 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13) {
+    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
+    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@@ -339,66 +345,84 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
+// Copy destination pointers to GPU to be available when pointer indirection is in use
+
+void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
+        CUDA_CHECK(cudaStreamSynchronize(stream));
+        if (cuda_graph->dest_ptrs_d != nullptr) {
+            CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
+        }
+        CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
+        cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
+    }
+    // copy destination pointers to GPU
+    CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
+    cuda_graph->graph_cpynode_index = 0; // reset index
+#endif
+}
+
 static void ggml_cpy_f16_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
     const int num_blocks = ne / QK8_0;
     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q8_0_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
     const int num_blocks = ne / QK4_0;
     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q4_0_f32_cuda(
@@ -407,22 +431,22 @@ static void ggml_cpy_q4_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
     const int num_blocks = ne / QK4_1;
     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q4_1_f32_cuda(
@@ -431,22 +455,22 @@ static void ggml_cpy_q4_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK5_0 == 0);
     const int num_blocks = ne / QK5_0;
     cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q5_0_f32_cuda(
@@ -455,22 +479,22 @@ static void ggml_cpy_q5_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q5_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK5_1 == 0);
     const int num_blocks = ne / QK5_1;
     cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q5_1_f32_cuda(
@@ -479,32 +503,32 @@ static void ggml_cpy_q5_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK4_NL == 0);
     const int num_blocks = ne / QK4_NL;
     cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f16_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
@@ -541,46 +565,60 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
+    char ** dest_ptrs_d = nullptr;
+    int graph_cpynode_index = -1;
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+    if(ctx.cuda_graph->use_cpy_indirection) {
+        dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
+        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
+    }
+#endif
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                    ggml_type_name(src0->type), ggml_type_name(src1->type));
     }
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+    if(ctx.cuda_graph->use_cpy_indirection) {
+        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
+    }
+#endif
+
 }
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml/src/ggml-cuda/cpy.cuh
CHANGED
@@ -7,3 +7,5 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
+
+void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
@@ -2469,10 +2469,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 
 #ifdef USE_CUDA_GRAPH
 static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-    std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
+    bool use_cuda_graph) {
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2504,8 +2505,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         }
 
         if (node->op == GGML_OP_CPY) {
-            // store the copy op parameter which changes with each token.
-            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+
+            // Store the pointers which are updated for each token, such that these can be sent
+            // to the device and accessed using indirection from CUDA graph
+            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
+
             // store a pointer to each copy op CUDA kernel to identify it later
             void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
             if (!ptr) {
@@ -2513,10 +2517,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #ifndef NDEBUG
                 GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
 #endif
-            } else {
-                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-                }
             }
         }
 
@@ -2525,6 +2525,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         }
     }
 
+    if (use_cuda_graph) {
+        cuda_ctx->cuda_graph->use_cpy_indirection = true;
+        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
+        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
+    }
+
     return use_cuda_graph;
 }
 
@@ -2579,51 +2585,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
     return true;
 }
 
-static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
-
-    if (cuda_graph_update_required) {
-        // Extract nodes from graph
-        // First call with null argument gets number of nodes in graph
-        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-        // Subsequent call with non-null argument gets nodes
-        cuda_ctx->cuda_graph->nodes.clear();
-        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-        cuda_ctx->cuda_graph->params.clear();
-        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-        if (cuda_ctx->cuda_graph->num_nodes > 0) {
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
-            // Loop over nodes, and extract kernel parameters from each node
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                cudaGraphNodeType node_type;
-                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-                if (node_type == cudaGraphNodeTypeKernel) {
-                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-                    if (stat == cudaErrorInvalidDeviceFunction) {
-                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-                        // We don't need to update blas nodes, so clear error and move on.
-                        (void)cudaGetLastError();
-                    } else {
-                        GGML_ASSERT(stat == cudaSuccess);
-                    }
-                }
-            }
-        }
-    } else {
-        // One of the arguments to the copy kernel is updated for each token, hence we need to
-        // replace that argument with the updated value in the CUDA graph
-        // on update steps, the live parameters will already be captured
-        int k = 0;
-        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                *(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr;
-                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-            }
-        }
-    }
-}
-
 static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
 
     bool cuda_graph_update_required = false;
@@ -2683,8 +2644,7 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 #endif
 
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-    std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph,
-    bool & cuda_graph_update_required) {
+    bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
 
     while (!graph_evaluated_or_captured) {
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@@ -2734,13 +2694,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
         }
-
-        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
-
-        // Update graph executable
-        update_cuda_graph_executable(cuda_ctx);
-
+        if (cuda_graph_update_required) { // Update graph executable
+            update_cuda_graph_executable(cuda_ctx);
+        }
         // Launch graph
         CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
 #else
@@ -2754,10 +2710,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     ggml_cuda_set_device(cuda_ctx->device);
 
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
-
 #ifdef USE_CUDA_GRAPH
     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
@@ -2791,8 +2743,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     if (use_cuda_graph) {
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
 
-        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
-            ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
+        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
 
         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
         if (use_cuda_graph && cuda_graph_update_required) {
@@ -2813,6 +2764,10 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
+    if (!use_cuda_graph) {
+        cuda_ctx->cuda_graph->use_cpy_indirection = false;
+    }
+
 #else
     bool use_cuda_graph = false;
     bool cuda_graph_update_required = false;
@@ -2820,7 +2775,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     bool graph_evaluated_or_captured = false;
 
-    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
+    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
 
     return GGML_STATUS_SUCCESS;
 }
|