agray3 committed on
Commit
6124287
·
1 Parent(s): 4ff3b72

Allow number of nodes in CUDA graph to change (llama/7738)

Browse files

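Previously the code failed to cope when the number of nodes in an
existing CUDA graph changed, because the node count was only queried
via cudaGraphGetNodes while the cached count was still zero. This fixes
the issue by removing that conditional, so the node count is re-queried
every time a CUDA graph update is required.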

Files changed (1) hide show
  1. ggml-cuda.cu +2 -4
ggml-cuda.cu CHANGED
@@ -2702,10 +2702,8 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2702
 
2703
  if (cuda_graph_update_required) {
2704
  // Extract nodes from graph
2705
- if (cuda_ctx->cuda_graph->num_nodes == 0) {
2706
- // First call with null argument gets number of nodes in graph
2707
- CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
2708
- }
2709
  // Subsequent call with non-null argument gets nodes
2710
  cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
2711
  cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
 
2702
 
2703
  if (cuda_graph_update_required) {
2704
  // Extract nodes from graph
2705
+ // First call with null argument gets number of nodes in graph
2706
+ CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
 
 
2707
  // Subsequent call with non-null argument gets nodes
2708
  cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
2709
  cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);