whisper : significantly improve the inference quality (#1148)
* Fix MSVC compile error C3688
Instead of simply using 'add_compile_options(/utf-8)' to address MSVC compile error C3688, a better approach is to handle it in a way that avoids passing '/utf-8' to NVCC.
* Significantly improve inference quality
In the function `log_mel_spectrogram_worker_thread`, there is an array out-of-bounds access during the calculation of the complex-number moduli. It disrupts the FFT spectrum, which in turn degrades inference quality.
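As a rough illustration of the bounds involved (not the committed code; the helper name is made up): an FFT output buffer stores complex bins interleaved as re/im pairs, so reading pair j touches indices 2*j and 2*j + 1, and j must stay within the number of complex bins actually present or the read runs past the end of the buffer and corrupts the spectrum.

    #include <vector>

    // Illustrative sketch only: squared moduli of an interleaved complex spectrum.
    // fft_out holds n_bins complex values as [re0, im0, re1, im1, ...] (2*n_bins floats).
    static std::vector<float> power_spectrum(const std::vector<float> & fft_out) {
        const int n_bins = (int) fft_out.size() / 2;
        std::vector<float> power(n_bins);
        for (int j = 0; j < n_bins; j++) {      // j is bounded by the bin count, not the float count
            const float re = fft_out[2*j + 0];
            const float im = fft_out[2*j + 1];
            power[j] = re*re + im*im;           // modulus^2; the mel filters do not need the sqrt
        }
        return power;
    }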
* Significantly improve inference quality
At last, I've pinpointed the actual source of the problem. Because the frequency spectrum computed from real input data is symmetric around the Nyquist frequency, there is a for-loop in `log_mel_spectrogram_worker_thread` that folds the spectrum onto its lower half. A bug in that loop was introducing a frame shift in the spectrum. The previous attempt to remedy this, using `fft_size + 1` when calculating the modulus, was merely a band-aid and did not address the underlying issue.
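To make the symmetry concrete, here is a small self-contained check (a naive DFT, purely illustrative and unrelated to the project's own fft()): for real input of length N, bin k and bin N - k have equal moduli, so a folding loop must pair index k with exactly N - k; any off-by-one pairs the wrong bins and shifts the folded spectrum.

    #include <cmath>
    #include <complex>
    #include <cstdio>
    #include <vector>

    int main() {
        const int   N  = 8;
        const float PI = 3.14159265358979f;
        const std::vector<float> x = {0.0f, 1.0f, 0.5f, -0.25f, 0.75f, -1.0f, 0.25f, 0.5f};

        // naive O(N^2) DFT, good enough for a demonstration
        std::vector<std::complex<float>> X(N);
        for (int k = 0; k < N; k++) {
            for (int n = 0; n < N; n++) {
                const float phi = -2.0f*PI*k*n/N;
                X[k] += x[n]*std::complex<float>(cosf(phi), sinf(phi));
            }
        }

        // for real input, X[N - k] is the complex conjugate of X[k], so the moduli match
        for (int k = 1; k < N/2; k++) {
            printf("|X[%d]| = %.4f   |X[%d]| = %.4f\n", k, std::abs(X[k]), N - k, std::abs(X[N - k]));
        }
        return 0;
    }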
* Addressed a few minor issues
Fixed the issue of `fft_out` continuously expanding. Resolved the regression caused by using `break` instead of setting `fft_in[j] = 0`.
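The reasoning behind the second fix, sketched out (this mirrors the shape of the final windowing loop but is not a verbatim excerpt, and the helper name is made up): fft_in is reused for every frame, so bailing out with break near the end of the audio leaves the previous frame's tail in the buffer, whereas explicitly clearing the remainder keeps each frame properly zero-padded.

    #include <algorithm>
    #include <vector>

    // Sketch: window one frame of audio into fft_in, zero-padding past the end of the signal.
    static void window_frame(const std::vector<float> & hann, const std::vector<float> & samples,
                             int offset, int frame_size, std::vector<float> & fft_in) {
        const int n_avail = std::max(0, std::min(frame_size, (int) samples.size() - offset));
        for (int j = 0; j < n_avail; j++) {
            fft_in[j] = hann[j]*samples[offset + j];
        }
        if (n_avail < frame_size) {
            // a 'break' here would leave stale samples from the previous frame in the tail
            std::fill(fft_in.begin() + n_avail, fft_in.end(), 0.0f);
        }
    }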
* Significantly improve inference quality
Thanks for your patience, everyone. It's finally sorted out. Now the right side of the FFT spectrum is flipped over to the left, the amplitudes at corresponding positions on the left and right are added together (the spectrum on the left needs to be shifted by one position), and the average is taken. `fft_out[0]` is no longer discarded, making full use of the limited space to pack in more information.
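A compact sketch of the folding described above, with assumed indexing (the exact offsets in the committed loop may differ, and the function name is made up): the upper half of the squared-magnitude spectrum is mirrored onto the lower half and the two are averaged, while bin 0 is kept rather than discarded.

    #include <vector>

    // Sketch: fold an N-point squared-magnitude spectrum of a real signal onto its lower half.
    static void fold_spectrum(std::vector<float> & mag) {
        const int N = (int) mag.size();
        for (int k = 1; k < N/2; k++) {
            mag[k] = 0.5f*(mag[k] + mag[N - k]); // mirror bin N - k onto bin k and average
        }
        // bins 0 .. N/2 now carry the folded spectrum; bin 0 (DC) is not discarded
    }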
* Add annotation and performance improvement
* Calculate FFT only when fft_in are not all zero
* Some minor performance improvement
* Fixed a bug impacting inference quality
* The first version after all the analysis is completed.
* Fix some bugs and add debug mode
* Fixed several bugs
* Temporarily disable speed-up mode and add debug mode.
* Add debug mode
* Disable speed-up mode and add debug mode
* Fix CI error (#1)
* Fix error
* Fix error
* Fixed several bugs, including the [BLANK_AUDIO] problem
* Remove hard-coded Hann window
* Some Final Fix (#2)
* Fix error
* Fix error
* Probably the last commit
* Probably the last commit
* whisper : minor coding style changes
* whisper : remove debug from public API
---------
Co-authored-by: Georgi Gerganov <[email protected]>
- examples/main/main.cpp +6 -2
- whisper.cpp +114 -82
- whisper.h +1 -0

examples/main/main.cpp:

@@ -70,6 +70,7 @@ struct whisper_params {
     float logprob_thold = -1.00f;

     bool speed_up        = false;
+    bool debug_mode      = false;
     bool translate       = false;
     bool detect_language = false;
     bool diarize         = false;
@@ -135,7 +136,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-wt"    || arg == "--word-thold")    { params.word_thold    = std::stof(argv[++i]); }
         else if (arg == "-et"    || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
         else if (arg == "-lpt"   || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
-        else if (arg == "-su"    || arg == "--speed-up")      { params.speed_up      = true; }
+        // else if (arg == "-su"  || arg == "--speed-up")     { params.speed_up      = true; }
+        else if (arg == "-debug" || arg == "--debug-mode")    { params.debug_mode    = true; }
         else if (arg == "-tr"    || arg == "--translate")     { params.translate     = true; }
         else if (arg == "-di"    || arg == "--diarize")       { params.diarize       = true; }
         else if (arg == "-tdrz"  || arg == "--tinydiarize")   { params.tinydiarize   = true; }
@@ -190,7 +192,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
     fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
     fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
     fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
-    fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
+    // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
+    fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
     fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
     fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
     fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
@@ -915,6 +918,7 @@ int main(int argc, char ** argv) {
         wparams.split_on_word   = params.split_on_word;

         wparams.speed_up        = params.speed_up;
+        wparams.debug_mode      = params.debug_mode;

         wparams.tdrz_enable     = params.tinydiarize; // [TDRZ]

whisper.cpp:

@@ -2445,40 +2445,50 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     }
 }

+static bool hann_window(int length, bool periodic, std::vector<float> & output) {
+    if (output.size() < length) {
+        output.resize(length);
+    }
+    int offset = -1;
+    if (periodic) {
+        offset = 0;
+    }
+    for (int i = 0; i < length; i++) {
+        output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
+    }
+
+    return true;
+}
+
+static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
+                                              int n_samples, int frame_size, int frame_step, int n_threads,
+                                              const whisper_filters & filters, whisper_mel & mel) {
+    std::vector<float> fft_in(frame_size, 0.0);
+    std::vector<float> fft_out(2 * frame_step);
+    // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
+    int n_fft = 1 + (frame_size / 2);
+    int i = ith;
+
+    // calculate FFT only when fft_in are not all zero
+    for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
+        const int offset = i * frame_step;
+
+        // apply Hanning window (~10% faster)
-        for (int j = 0; j < fft_size; j++) {
-            if (offset + j < n_samples) {
-                fft_in[j] = hann[j] * samples[offset + j];
-            } else {
-                fft_in[j] = 0.0;
-            }
-        }
+        for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
+            fft_in[j] = hann[j] * samples[offset + j];
         }
+        // fill the rest with zeros
+        if (n_samples - offset < frame_size) {
+            std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
         }

+        // FFT
+        fft(fft_in, fft_out);
+
+        // Calculate modulus^2 of complex numbers
+        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
+        for (int j = 0; j < frame_size; j++) {
+            fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
         }

         // mel spectrogram
@@ -2489,10 +2499,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
             int k = 0;
             for (k = 0; k < n_fft - 3; k += 4) {
                 sum +=
+                    fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
+                    fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
+                    fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
+                    fft_out[k + 3] * filters.data[j * n_fft + k + 3];
             }

             // handle n_fft remainder
@@ -2505,68 +2515,73 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
             mel.data[j * mel.n_len + i] = sum;
         }
     }
+
+    // Otherwise fft_out are all zero
+    double sum = log10(1e-10);
+    for (; i < mel.n_len; i += n_threads) {
+        for (int j = 0; j < mel.n_mel; j++) {
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
 }

-// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#
+// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
 static bool log_mel_spectrogram(
+              whisper_state & wstate,
+              const float * samples,
               const int   n_samples,
               const int   /*sample_rate*/,
-              const int   fft_size,
-              const int   fft_step,
+              const int   frame_size,
+              const int   frame_step,
               const int   n_mel,
               const int   n_threads,
+              const whisper_filters & filters,
+              const bool   debug,
+              whisper_mel & mel) {
     const int64_t t_start_us = ggml_time_us();

-    // Hanning window
+    // Hanning window (Use cosf to eliminate difference)
+    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
+    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
     std::vector<float> hann;
-    for (int i = 0; i < fft_size; i++) {
-        hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
-    }
-
-    mel.n_mel = n_mel;
-    mel.n_len = n_samples/fft_step;
-    mel.n_len_org = mel.n_len;
-
-    std::vector<float> samples_padded;
-    memset(samples_padded.data() + n_samples, 0, (mel.n_len*fft_step - n_samples)*sizeof(float));
-
-    //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
-    //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
+    hann_window(frame_size, true, hann);
+
+    // Calculate the length of padding
+    int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
+    int64_t stage_2_pad = frame_size / 2;
+
+    // Initialize a vector and copy data from C array to it.
+    std::vector<float> samples_padded;
+    samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
+    std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
+
+    // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
+    std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
+
+    // reflective pad 200 samples at the beginning of audio
+    std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
+
+    mel.n_mel     = n_mel;
+    // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
+    // Calculate number of frames + remove the last frame
+    mel.n_len     = (samples_padded.size() - frame_size) / frame_step;
+    // Calculate semi-padded sample length to ensure compatibility
+    mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
+    mel.data.resize(mel.n_mel * mel.n_len);

     {
         std::vector<std::thread> workers(n_threads - 1);
         for (int iw = 0; iw < n_threads - 1; ++iw) {
             workers[iw] = std::thread(
-                    log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann),
-                    n_samples,
-                    std::cref(filters),
+                    log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
+                    n_samples + stage_2_pad, frame_size, frame_step, n_threads,
+                    std::cref(filters), std::ref(mel));
         }

         // main thread
-        log_mel_spectrogram_worker_thread(0, hann,
+        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);

         for (int iw = 0; iw < n_threads - 1; ++iw) {
             workers[iw].join();
@@ -2580,7 +2595,6 @@ static bool log_mel_spectrogram(
             mmax = mel.data[i];
         }
     }
-    //printf("%s: max = %f\n", __func__, mmax);

     mmax -= 8.0;

@@ -2594,7 +2608,16 @@ static bool log_mel_spectrogram(

     wstate.t_mel_us += ggml_time_us() - t_start_us;

+    // Dump log_mel_spectrogram
+    if (debug) {
+        std::ofstream outFile("log_mel_spectrogram.json");
+        outFile << "[";
+        for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
+            outFile << mel.data[i] << ", ";
+        }
+        outFile << mel.data[mel.data.size() - 1] << "]";
+        outFile.close();
+    }

     return true;
 }
@@ -3026,9 +3049,9 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
     return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
 }

-// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
+// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
 int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters,
+    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
         log("%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
@@ -3036,11 +3059,20 @@ int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     return 0;
 }

-// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
+// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
 int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
     return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
 }

+// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
+// TODO
+
+// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
+// TODO
+
+// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
+// TODO
+
 int whisper_set_mel_with_state(
         struct whisper_context * /*ctx*/,
           struct whisper_state * state,
@@ -3492,6 +3524,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
         /*.max_tokens   =*/ 0,

         /*.speed_up     =*/ false,
+        /*.debug_mode   =*/ false,
         /*.audio_ctx    =*/ 0,

         /*.tdrz_enable  =*/ false,
@@ -3653,7 +3686,7 @@ static void whisper_process_logits(
     WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);

     // extract the logits for the last token
-    // we will be mutating and therefore we don't want to use the ctx.logits buffer directly
+    // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly
     auto & probs    = decoder.probs;
     auto & logits   = decoder.logits;
     auto & logprobs = decoder.logprobs;
@@ -4056,10 +4089,9 @@ int whisper_full_with_state(

     // compute log mel spectrogram
     if (params.speed_up) {
+        // TODO: Replace PV with more advanced algorithm
+        log("%s: failed to compute log mel spectrogram\n", __func__);
+        return -1;
     } else {
         if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
             log("%s: failed to compute log mel spectrogram\n", __func__);
@@ -4095,8 +4127,8 @@ int whisper_full_with_state(
     const int seek_start = params.offset_ms/10;
     const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;

-    // if length of spectrogram is less than
-    // basically don't process anything that is less than
+    // if length of spectrogram is less than 1.0s (100 frames), then return
+    // basically don't process anything that is less than 1.0s
     // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
     if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
         return 0;

whisper.h:

@@ -375,6 +375,7 @@ extern "C" {
         // [EXPERIMENTAL] speed-up techniques
         // note: these can significantly reduce the quality of the output
         bool speed_up;    // speed-up the audio by 2x using Phase Vocoder
+        bool debug_mode;  // enable debug_mode provides extra info (eg. Dump log_mel)
         int  audio_ctx;   // overwrite the audio context size (0 = use default)

         // [EXPERIMENTAL] [TDRZ] tinydiarize