Spaces:
Running
Running
whisper : use ggml-cuda in mel calc, set appropriate device (#2236)
Browse files* whisper : use ggml-cuda in mel calc, set appropriate device
* whisper : forbid cuda mel calc on devices with compute < 600, workaround for #2230
- whisper-mel-cuda.cu +26 -25
- whisper.cpp +11 -6
whisper-mel-cuda.cu
CHANGED
|
@@ -2,6 +2,9 @@
|
|
| 2 |
#include "whisper-mel-cuda.hpp"
|
| 3 |
#include "whisper.h"
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
#include <cuda.h>
|
| 6 |
#include <cuda_runtime.h>
|
| 7 |
#include <cufft.h>
|
|
@@ -16,16 +19,9 @@
|
|
| 16 |
#pragma warning(disable: 4324) // added padding
|
| 17 |
#endif
|
| 18 |
|
| 19 |
-
#ifndef NDEBUG
|
| 20 |
-
# define DO_CHECKS 1
|
| 21 |
-
#else
|
| 22 |
-
# define DO_CHECKS 0
|
| 23 |
-
#endif
|
| 24 |
-
|
| 25 |
namespace {
|
| 26 |
|
| 27 |
-
|
| 28 |
-
const char* cufftGetErrorString(cufftResult_t res) {
|
| 29 |
switch (res) {
|
| 30 |
case CUFFT_SUCCESS: return "The cuFFT operation was successful";
|
| 31 |
case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
|
|
@@ -48,19 +44,6 @@ const char* cufftGetErrorString(cufftResult_t res) {
|
|
| 48 |
}
|
| 49 |
}
|
| 50 |
|
| 51 |
-
# define CUDA_CHECK_GEN(err, success, error_fn) \
|
| 52 |
-
do { \
|
| 53 |
-
auto err_ = (err); \
|
| 54 |
-
if (err_ != (success)) { \
|
| 55 |
-
fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
|
| 56 |
-
} \
|
| 57 |
-
} while (0)
|
| 58 |
-
#else
|
| 59 |
-
# define CUDA_CHECK_GEN(err, success, error_fn) err
|
| 60 |
-
#endif
|
| 61 |
-
|
| 62 |
-
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
|
| 63 |
-
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString)
|
| 64 |
#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
|
| 65 |
|
| 66 |
__global__ void k_fill_stft_input(
|
|
@@ -81,7 +64,7 @@ __global__ void k_fill_stft_input(
|
|
| 81 |
}
|
| 82 |
|
| 83 |
__global__ void k_calc_magnitudes(
|
| 84 |
-
const cuComplex* stft_out,
|
| 85 |
const int n_frames,
|
| 86 |
float * magnitudes
|
| 87 |
) {
|
|
@@ -133,7 +116,7 @@ void fill_stft_input(
|
|
| 133 |
}
|
| 134 |
|
| 135 |
void calc_magnitudes(
|
| 136 |
-
const cuComplex* stft_out,
|
| 137 |
int n_frames,
|
| 138 |
float * magnitudes,
|
| 139 |
cudaStream_t stream
|
|
@@ -169,6 +152,7 @@ class mel_calc_cuda : public whisper_mel_calc {
|
|
| 169 |
const int m_n_mel;
|
| 170 |
|
| 171 |
ggml_backend_t m_backend = nullptr;
|
|
|
|
| 172 |
|
| 173 |
cudaStream_t m_stream = nullptr;
|
| 174 |
cublasHandle_t m_cublas_handle = nullptr;
|
|
@@ -190,6 +174,18 @@ public:
|
|
| 190 |
: m_n_mel(filters.n_mel)
|
| 191 |
, m_backend(backend)
|
| 192 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
if (filters.n_fft != WHISPER_N_FFT_HALF) {
|
| 194 |
throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
|
| 195 |
}
|
|
@@ -219,6 +215,7 @@ public:
|
|
| 219 |
}
|
| 220 |
|
| 221 |
~mel_calc_cuda() {
|
|
|
|
| 222 |
CUDA_CHECK(cudaStreamSynchronize(m_stream));
|
| 223 |
CUDA_CHECK(cudaStreamDestroy(m_stream));
|
| 224 |
CUDA_CHECK(cudaFree(m_hann_window));
|
|
@@ -268,6 +265,7 @@ public:
|
|
| 268 |
}
|
| 269 |
|
| 270 |
virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
|
|
|
|
| 271 |
ensure_working_areas(samples.len);
|
| 272 |
|
| 273 |
const size_t mirror_pad = WHISPER_N_FFT / 2;
|
|
@@ -356,8 +354,11 @@ public:
|
|
| 356 |
}
|
| 357 |
|
| 358 |
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
|
| 359 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
return nullptr;
|
| 361 |
}
|
| 362 |
-
return new mel_calc_cuda(backend, filters);
|
| 363 |
}
|
|
|
|
| 2 |
#include "whisper-mel-cuda.hpp"
|
| 3 |
#include "whisper.h"
|
| 4 |
|
| 5 |
+
#include <ggml-cuda/common.cuh>
|
| 6 |
+
#include <ggml-backend-impl.h>
|
| 7 |
+
|
| 8 |
#include <cuda.h>
|
| 9 |
#include <cuda_runtime.h>
|
| 10 |
#include <cufft.h>
|
|
|
|
| 19 |
#pragma warning(disable: 4324) // added padding
|
| 20 |
#endif
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
namespace {
|
| 23 |
|
| 24 |
+
static const char* cufftGetErrorString(cufftResult_t res) {
|
|
|
|
| 25 |
switch (res) {
|
| 26 |
case CUFFT_SUCCESS: return "The cuFFT operation was successful";
|
| 27 |
case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
|
|
|
|
| 44 |
}
|
| 45 |
}
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
|
| 48 |
|
| 49 |
__global__ void k_fill_stft_input(
|
|
|
|
| 64 |
}
|
| 65 |
|
| 66 |
__global__ void k_calc_magnitudes(
|
| 67 |
+
const cuComplex * stft_out,
|
| 68 |
const int n_frames,
|
| 69 |
float * magnitudes
|
| 70 |
) {
|
|
|
|
| 116 |
}
|
| 117 |
|
| 118 |
void calc_magnitudes(
|
| 119 |
+
const cuComplex * stft_out,
|
| 120 |
int n_frames,
|
| 121 |
float * magnitudes,
|
| 122 |
cudaStream_t stream
|
|
|
|
| 152 |
const int m_n_mel;
|
| 153 |
|
| 154 |
ggml_backend_t m_backend = nullptr;
|
| 155 |
+
int m_device = -1;
|
| 156 |
|
| 157 |
cudaStream_t m_stream = nullptr;
|
| 158 |
cublasHandle_t m_cublas_handle = nullptr;
|
|
|
|
| 174 |
: m_n_mel(filters.n_mel)
|
| 175 |
, m_backend(backend)
|
| 176 |
{
|
| 177 |
+
ggml_backend_cuda_context* cuda_ctx = (ggml_backend_cuda_context*)m_backend->context;
|
| 178 |
+
m_device = cuda_ctx->device;
|
| 179 |
+
|
| 180 |
+
if (ggml_cuda_info().devices[m_device].cc < 600) {
|
| 181 |
+
// we've only tesed on 6.0 and higher and we've had reports of crashes on 5.0:
|
| 182 |
+
// https://github.com/ggerganov/whisper.cpp/issues/2230
|
| 183 |
+
// to be safe forbid anything below 6.0
|
| 184 |
+
throw std::runtime_error("CUDA compute capability 6.0 or higher is required");
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
ggml_cuda_set_device(m_device);
|
| 188 |
+
|
| 189 |
if (filters.n_fft != WHISPER_N_FFT_HALF) {
|
| 190 |
throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
|
| 191 |
}
|
|
|
|
| 215 |
}
|
| 216 |
|
| 217 |
~mel_calc_cuda() {
|
| 218 |
+
ggml_cuda_set_device(m_device);
|
| 219 |
CUDA_CHECK(cudaStreamSynchronize(m_stream));
|
| 220 |
CUDA_CHECK(cudaStreamDestroy(m_stream));
|
| 221 |
CUDA_CHECK(cudaFree(m_hann_window));
|
|
|
|
| 265 |
}
|
| 266 |
|
| 267 |
virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
|
| 268 |
+
ggml_cuda_set_device(m_device);
|
| 269 |
ensure_working_areas(samples.len);
|
| 270 |
|
| 271 |
const size_t mirror_pad = WHISPER_N_FFT / 2;
|
|
|
|
| 354 |
}
|
| 355 |
|
| 356 |
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
|
| 357 |
+
try {
|
| 358 |
+
return new mel_calc_cuda(backend, filters);
|
| 359 |
+
}
|
| 360 |
+
catch (...) {
|
| 361 |
+
// TODO: log error (but for this we would have to expose the log state to be accessible here)
|
| 362 |
return nullptr;
|
| 363 |
}
|
|
|
|
| 364 |
}
|
whisper.cpp
CHANGED
|
@@ -3170,13 +3170,18 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper
|
|
| 3170 |
#if defined(GGML_USE_CUDA) && !defined(GGML_USE_HIPBLAS)
|
| 3171 |
if (ggml_backend_is_cuda(backend)) {
|
| 3172 |
auto ret = whisper_mel_calc_create_cuda(backend, filters);
|
| 3173 |
-
|
| 3174 |
-
|
| 3175 |
-
|
| 3176 |
-
|
| 3177 |
-
|
|
|
|
|
|
|
| 3178 |
#endif
|
| 3179 |
-
|
|
|
|
|
|
|
|
|
|
| 3180 |
}
|
| 3181 |
|
| 3182 |
// split text into tokens
|
|
|
|
| 3170 |
#if defined(GGML_USE_CUDA) && !defined(GGML_USE_HIPBLAS)
|
| 3171 |
if (ggml_backend_is_cuda(backend)) {
|
| 3172 |
auto ret = whisper_mel_calc_create_cuda(backend, filters);
|
| 3173 |
+
if (ret) {
|
| 3174 |
+
// run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
|
| 3175 |
+
const float warmup[256] = { 0 };
|
| 3176 |
+
ret->calculate({ warmup, 256 }, 1);
|
| 3177 |
+
return ret;
|
| 3178 |
+
}
|
| 3179 |
+
}
|
| 3180 |
#endif
|
| 3181 |
+
|
| 3182 |
+
// a specialized mel_calc could not be created
|
| 3183 |
+
// fall back to CPU
|
| 3184 |
+
return new mel_calc_cpu(backend, filters);
|
| 3185 |
}
|
| 3186 |
|
| 3187 |
// split text into tokens
|