stanimirovb commited on
Commit
93af41a
·
unverified ·
1 Parent(s): a8eb666

whisper : use ggml-cuda in mel calc, set appropriate device (#2236)

Browse files

* whisper : use ggml-cuda in mel calc, set appropriate device

* whisper : forbid cuda mel calc on devices with compute < 600, workaround for #2230

Files changed (2) hide show
  1. whisper-mel-cuda.cu +26 -25
  2. whisper.cpp +11 -6
whisper-mel-cuda.cu CHANGED
@@ -2,6 +2,9 @@
2
  #include "whisper-mel-cuda.hpp"
3
  #include "whisper.h"
4
 
 
 
 
5
  #include <cuda.h>
6
  #include <cuda_runtime.h>
7
  #include <cufft.h>
@@ -16,16 +19,9 @@
16
  #pragma warning(disable: 4324) // added padding
17
  #endif
18
 
19
- #ifndef NDEBUG
20
- # define DO_CHECKS 1
21
- #else
22
- # define DO_CHECKS 0
23
- #endif
24
-
25
  namespace {
26
 
27
- #if DO_CHECKS
28
- const char* cufftGetErrorString(cufftResult_t res) {
29
  switch (res) {
30
  case CUFFT_SUCCESS: return "The cuFFT operation was successful";
31
  case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
@@ -48,19 +44,6 @@ const char* cufftGetErrorString(cufftResult_t res) {
48
  }
49
  }
50
 
51
- # define CUDA_CHECK_GEN(err, success, error_fn) \
52
- do { \
53
- auto err_ = (err); \
54
- if (err_ != (success)) { \
55
- fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
56
- } \
57
- } while (0)
58
- #else
59
- # define CUDA_CHECK_GEN(err, success, error_fn) err
60
- #endif
61
-
62
- #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
63
- #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString)
64
  #define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
65
 
66
  __global__ void k_fill_stft_input(
@@ -81,7 +64,7 @@ __global__ void k_fill_stft_input(
81
  }
82
 
83
  __global__ void k_calc_magnitudes(
84
- const cuComplex* stft_out,
85
  const int n_frames,
86
  float * magnitudes
87
  ) {
@@ -133,7 +116,7 @@ void fill_stft_input(
133
  }
134
 
135
  void calc_magnitudes(
136
- const cuComplex* stft_out,
137
  int n_frames,
138
  float * magnitudes,
139
  cudaStream_t stream
@@ -169,6 +152,7 @@ class mel_calc_cuda : public whisper_mel_calc {
169
  const int m_n_mel;
170
 
171
  ggml_backend_t m_backend = nullptr;
 
172
 
173
  cudaStream_t m_stream = nullptr;
174
  cublasHandle_t m_cublas_handle = nullptr;
@@ -190,6 +174,18 @@ public:
190
  : m_n_mel(filters.n_mel)
191
  , m_backend(backend)
192
  {
 
 
 
 
 
 
 
 
 
 
 
 
193
  if (filters.n_fft != WHISPER_N_FFT_HALF) {
194
  throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
195
  }
@@ -219,6 +215,7 @@ public:
219
  }
220
 
221
  ~mel_calc_cuda() {
 
222
  CUDA_CHECK(cudaStreamSynchronize(m_stream));
223
  CUDA_CHECK(cudaStreamDestroy(m_stream));
224
  CUDA_CHECK(cudaFree(m_hann_window));
@@ -268,6 +265,7 @@ public:
268
  }
269
 
270
  virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
 
271
  ensure_working_areas(samples.len);
272
 
273
  const size_t mirror_pad = WHISPER_N_FFT / 2;
@@ -356,8 +354,11 @@ public:
356
  }
357
 
358
  whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
359
- if (filters.n_fft != WHISPER_N_FFT_HALF) {
 
 
 
 
360
  return nullptr;
361
  }
362
- return new mel_calc_cuda(backend, filters);
363
  }
 
2
  #include "whisper-mel-cuda.hpp"
3
  #include "whisper.h"
4
 
5
+ #include <ggml-cuda/common.cuh>
6
+ #include <ggml-backend-impl.h>
7
+
8
  #include <cuda.h>
9
  #include <cuda_runtime.h>
10
  #include <cufft.h>
 
19
  #pragma warning(disable: 4324) // added padding
20
  #endif
21
 
 
 
 
 
 
 
22
  namespace {
23
 
24
+ static const char* cufftGetErrorString(cufftResult_t res) {
 
25
  switch (res) {
26
  case CUFFT_SUCCESS: return "The cuFFT operation was successful";
27
  case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
 
44
  }
45
  }
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  #define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
48
 
49
  __global__ void k_fill_stft_input(
 
64
  }
65
 
66
  __global__ void k_calc_magnitudes(
67
+ const cuComplex * stft_out,
68
  const int n_frames,
69
  float * magnitudes
70
  ) {
 
116
  }
117
 
118
  void calc_magnitudes(
119
+ const cuComplex * stft_out,
120
  int n_frames,
121
  float * magnitudes,
122
  cudaStream_t stream
 
152
  const int m_n_mel;
153
 
154
  ggml_backend_t m_backend = nullptr;
155
+ int m_device = -1;
156
 
157
  cudaStream_t m_stream = nullptr;
158
  cublasHandle_t m_cublas_handle = nullptr;
 
174
  : m_n_mel(filters.n_mel)
175
  , m_backend(backend)
176
  {
177
+ ggml_backend_cuda_context* cuda_ctx = (ggml_backend_cuda_context*)m_backend->context;
178
+ m_device = cuda_ctx->device;
179
+
180
+ if (ggml_cuda_info().devices[m_device].cc < 600) {
181
+ // we've only tested on 6.0 and higher and we've had reports of crashes on 5.0:
182
+ // https://github.com/ggerganov/whisper.cpp/issues/2230
183
+ // to be safe forbid anything below 6.0
184
+ throw std::runtime_error("CUDA compute capability 6.0 or higher is required");
185
+ }
186
+
187
+ ggml_cuda_set_device(m_device);
188
+
189
  if (filters.n_fft != WHISPER_N_FFT_HALF) {
190
  throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
191
  }
 
215
  }
216
 
217
  ~mel_calc_cuda() {
218
+ ggml_cuda_set_device(m_device);
219
  CUDA_CHECK(cudaStreamSynchronize(m_stream));
220
  CUDA_CHECK(cudaStreamDestroy(m_stream));
221
  CUDA_CHECK(cudaFree(m_hann_window));
 
265
  }
266
 
267
  virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
268
+ ggml_cuda_set_device(m_device);
269
  ensure_working_areas(samples.len);
270
 
271
  const size_t mirror_pad = WHISPER_N_FFT / 2;
 
354
  }
355
 
356
  whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
357
+ try {
358
+ return new mel_calc_cuda(backend, filters);
359
+ }
360
+ catch (...) {
361
+ // TODO: log error (but for this we would have to expose the log state to be accessible here)
362
  return nullptr;
363
  }
 
364
  }
whisper.cpp CHANGED
@@ -3170,13 +3170,18 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper
3170
  #if defined(GGML_USE_CUDA) && !defined(GGML_USE_HIPBLAS)
3171
  if (ggml_backend_is_cuda(backend)) {
3172
  auto ret = whisper_mel_calc_create_cuda(backend, filters);
3173
- // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
3174
- const float warmup[256] = {0};
3175
- ret->calculate({warmup, 256}, 1);
3176
- return ret;
3177
- } else
 
 
3178
  #endif
3179
- return new mel_calc_cpu(backend, filters);
 
 
 
3180
  }
3181
 
3182
  // split text into tokens
 
3170
  #if defined(GGML_USE_CUDA) && !defined(GGML_USE_HIPBLAS)
3171
  if (ggml_backend_is_cuda(backend)) {
3172
  auto ret = whisper_mel_calc_create_cuda(backend, filters);
3173
+ if (ret) {
3174
+ // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
3175
+ const float warmup[256] = { 0 };
3176
+ ret->calculate({ warmup, 256 }, 1);
3177
+ return ret;
3178
+ }
3179
+ }
3180
  #endif
3181
+
3182
+ // a specialized mel_calc could not be created
3183
+ // fall back to CPU
3184
+ return new mel_calc_cpu(backend, filters);
3185
  }
3186
 
3187
  // split text into tokens