agray3 committed on
Commit
6b63eb1
·
1 Parent(s): 91c7734

Update CUDA graph on scale change plus clear nodes/params (llama/9550)

Browse files

* Avoid using saved CUDA graph if scale changes and reset nodes/params on update

Fixes https://github.com/ggerganov/llama.cpp/issues/9451

* Clear the CUDA graph nodes/params vectors before resizing them

ggml/src/ggml-cuda.cu CHANGED
@@ -2478,6 +2478,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
2478
  for (int i = 0; i < GGML_MAX_SRC; i++) {
2479
  graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
2480
  }
 
2481
  }
2482
 
2483
  static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -2509,6 +2510,12 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
2509
  return false;
2510
  }
2511
  }
 
 
 
 
 
 
2512
  return true;
2513
  }
2514
 
@@ -2720,7 +2727,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2720
  // First call with null argument gets number of nodes in graph
2721
  CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
2722
  // Subsequent call with non-null argument gets nodes
 
2723
  cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
 
2724
  cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
2725
  if (cuda_ctx->cuda_graph->num_nodes > 0) {
2726
  CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
 
2478
  for (int i = 0; i < GGML_MAX_SRC; i++) {
2479
  graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
2480
  }
2481
+ memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
2482
  }
2483
 
2484
  static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
 
2510
  return false;
2511
  }
2512
  }
2513
+
2514
+ if (node->op == GGML_OP_SCALE &&
2515
+ memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2516
+ return false;
2517
+ }
2518
+
2519
  return true;
2520
  }
2521
 
 
2727
  // First call with null argument gets number of nodes in graph
2728
  CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
2729
  // Subsequent call with non-null argument gets nodes
2730
+ cuda_ctx->cuda_graph->nodes.clear();
2731
  cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
2732
+ cuda_ctx->cuda_graph->params.clear();
2733
  cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
2734
  if (cuda_ctx->cuda_graph->num_nodes > 0) {
2735
  CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -569,6 +569,7 @@ struct ggml_graph_node_properties {
569
  int64_t ne[GGML_MAX_DIMS];
570
  size_t nb[GGML_MAX_DIMS];
571
  void * src_address[GGML_MAX_SRC];
 
572
  };
573
 
574
  struct ggml_cuda_graph {
 
569
  int64_t ne[GGML_MAX_DIMS];
570
  size_t nb[GGML_MAX_DIMS];
571
  void * src_address[GGML_MAX_SRC];
572
+ int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
573
  };
574
 
575
  struct ggml_cuda_graph {