Diego Devesa committed
Commit 39c4fa5 · 1 Parent(s): 4b0d2de

cuda : synchronize graph capture and cublas handle destruction (llama/14288)


Works around an issue that may cause CUDA graph capture to fail when a cuBLAS handle is destroyed in a different thread.
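The fix gates context teardown on in-flight graph captures: each capture increments an atomic counter under a mutex before cudaStreamBeginCapture and decrements it (notifying a condition variable) after cudaStreamEndCapture, while the context destructor waits until the counter reaches zero before destroying streams and cuBLAS handles. Below is a minimal standalone sketch of that pattern; begin_capture, end_capture, destroy_context, and the CUDA/cuBLAS calls mentioned in the comments are illustrative stand-ins, not the actual ggml API.

#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

// Shared gate: counts graph captures currently in flight.
static std::mutex              capture_mutex;
static std::condition_variable capture_cv;
static std::atomic<int>        capture_count{0};

// Bracket a capture on the compute thread (stand-in for the real
// cudaStreamBeginCapture / cudaStreamEndCapture calls).
static void begin_capture() {
    {
        std::lock_guard<std::mutex> lock(capture_mutex);
        capture_count.fetch_add(1, std::memory_order_relaxed);
    }
    // cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed) would run here.
}

static void end_capture() {
    // cudaStreamEndCapture(stream, &graph) would run here.
    std::lock_guard<std::mutex> lock(capture_mutex);
    if (capture_count.fetch_sub(1, std::memory_order_relaxed) == 1) {
        capture_cv.notify_all(); // last active capture finished, wake any waiting destructor
    }
}

// Run by whichever thread tears the context down: block until no capture is
// active, then it is safe to destroy streams and cuBLAS handles.
static void destroy_context() {
    std::unique_lock<std::mutex> lock(capture_mutex);
    capture_cv.wait(lock, [] {
        return capture_count.load(std::memory_order_relaxed) == 0;
    });
    // cublasDestroy(...) / cudaStreamDestroy(...) would run here.
    std::printf("context destroyed with no capture in flight\n");
}

int main() {
    std::thread compute([] {
        for (int i = 0; i < 1000; ++i) {
            begin_capture();
            end_capture();
        }
    });
    destroy_context(); // may block until any capture in progress has ended
    compute.join();
    return 0;
}

Relaxed memory ordering on the counter is sufficient here because every modification and the notified wait happen under the same mutex, which already provides the required synchronization.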

ggml/src/ggml-cuda/common.cuh CHANGED
@@ -19,10 +19,10 @@
 #endif
 #include "ggml-common.h"
 
-#include <cstdio>
 #include <array>
 #include <cassert>
 #include <cfloat>
+#include <cstdio>
 #include <string>
 #include <vector>
 
@@ -767,21 +767,7 @@ struct ggml_backend_cuda_context {
         name(GGML_CUDA_NAME + std::to_string(device)) {
     }
 
-    ~ggml_backend_cuda_context() {
-        if (copy_event != nullptr) {
-            CUDA_CHECK(cudaEventDestroy(copy_event));
-        }
-        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
-            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
-                if (streams[i][j] != nullptr) {
-                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
-                }
-            }
-            if (cublas_handles[i] != nullptr) {
-                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
-            }
-        }
-    }
+    ~ggml_backend_cuda_context();
 
     cudaStream_t stream(int device, int stream) {
         if (streams[device][stream] == nullptr) {
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -48,6 +48,7 @@
 #include <atomic>
 #include <charconv>
 #include <cinttypes>
+#include <condition_variable>
 #include <cstddef>
 #include <cstdint>
 #include <float.h>
@@ -55,9 +56,8 @@
 #include <map>
 #include <memory>
 #include <mutex>
-#include <stdint.h>
-#include <stdio.h>
 #include <stdarg.h>
+#include <stdio.h>
 #include <stdlib.h>
 #include <string>
 #include <vector>
@@ -515,6 +515,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }
 
+// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
+// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
+
+static std::mutex ggml_cuda_lock;
+static std::condition_variable ggml_cuda_lock_cv;
+static std::atomic<int> ggml_cuda_lock_counter;
+
+ggml_backend_cuda_context::~ggml_backend_cuda_context() {
+    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
+    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
+
+    if (copy_event != nullptr) {
+        CUDA_CHECK(cudaEventDestroy(copy_event));
+    }
+    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+            if (streams[i][j] != nullptr) {
+                CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+            }
+        }
+        if (cublas_handles[i] != nullptr) {
+            CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+        }
+    }
+}
+
+
 // cuda buffer
 
 struct ggml_backend_cuda_buffer_context {
@@ -2689,6 +2716,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
             CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+                ggml_cuda_lock_cv.notify_all();
+            }
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
         }
@@ -2764,7 +2796,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         }
     }
 
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+    if (use_cuda_graph && cuda_graph_update_required) {
+        // Start CUDA graph capture
+        {
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+        }
+
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 