Spaces:
Running
Running
Prompt previous tokens for streaming (#163)
Browse files* feat: prompt previous tokens for streaming
I used a vector pointer instead of vector itself because it gave weird errors, and why not
* convert vector to use with C api
* feat: remove old refs, check for prompt size
* feat: use better way of getting the pointer
- examples/stream/stream.cpp +14 -0
- whisper.cpp +15 -0
- whisper.h +4 -0
examples/stream/stream.cpp
CHANGED
|
@@ -234,6 +234,7 @@ int main(int argc, char ** argv) {
|
|
| 234 |
std::vector<float> pcmf32(n_samples_30s, 0.0f);
|
| 235 |
std::vector<float> pcmf32_old;
|
| 236 |
|
|
|
|
| 237 |
const int n_new_line = params.length_ms / params.step_ms - 1;
|
| 238 |
|
| 239 |
// print some info about the processing
|
|
@@ -344,6 +345,9 @@ int main(int argc, char ** argv) {
|
|
| 344 |
wparams.audio_ctx = params.audio_ctx;
|
| 345 |
wparams.speed_up = params.speed_up;
|
| 346 |
|
|
|
|
|
|
|
|
|
|
| 347 |
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
| 348 |
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
| 349 |
return 6;
|
|
@@ -393,6 +397,16 @@ int main(int argc, char ** argv) {
|
|
| 393 |
|
| 394 |
// keep part of the audio for next iteration to try to mitigate word boundary issues
|
| 395 |
pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
}
|
| 397 |
}
|
| 398 |
}
|
|
|
|
| 234 |
std::vector<float> pcmf32(n_samples_30s, 0.0f);
|
| 235 |
std::vector<float> pcmf32_old;
|
| 236 |
|
| 237 |
+
std::vector<whisper_token> prompt_tokens;
|
| 238 |
const int n_new_line = params.length_ms / params.step_ms - 1;
|
| 239 |
|
| 240 |
// print some info about the processing
|
|
|
|
| 345 |
wparams.audio_ctx = params.audio_ctx;
|
| 346 |
wparams.speed_up = params.speed_up;
|
| 347 |
|
| 348 |
+
wparams.prompt_tokens = prompt_tokens.data();
|
| 349 |
+
wparams.prompt_n_tokens = prompt_tokens.size();
|
| 350 |
+
|
| 351 |
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
| 352 |
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
| 353 |
return 6;
|
|
|
|
| 397 |
|
| 398 |
// keep part of the audio for next iteration to try to mitigate word boundary issues
|
| 399 |
pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
|
| 400 |
+
|
| 401 |
+
// Add tokens of the last full length segment as the prompt
|
| 402 |
+
prompt_tokens.clear();
|
| 403 |
+
const int n_segments = whisper_full_n_segments(ctx);
|
| 404 |
+
for (int i = 0; i < n_segments; ++i) {
|
| 405 |
+
const int token_count = whisper_full_n_tokens(ctx, i);
|
| 406 |
+
for (int j = 0; j < token_count; ++j) {
|
| 407 |
+
prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
|
| 408 |
+
}
|
| 409 |
+
}
|
| 410 |
}
|
| 411 |
}
|
| 412 |
}
|
whisper.cpp
CHANGED
|
@@ -2412,6 +2412,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
| 2412 |
/*.speed_up =*/ false,
|
| 2413 |
/*.audio_ctx =*/ 0,
|
| 2414 |
|
|
|
|
|
|
|
|
|
|
| 2415 |
/*.language =*/ "en",
|
| 2416 |
|
| 2417 |
/*.greedy =*/ {
|
|
@@ -2455,6 +2458,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
| 2455 |
/*.speed_up =*/ false,
|
| 2456 |
/*.audio_ctx =*/ 0,
|
| 2457 |
|
|
|
|
|
|
|
|
|
|
| 2458 |
/*.language =*/ "en",
|
| 2459 |
|
| 2460 |
/*.greedy =*/ {
|
|
@@ -2584,6 +2590,15 @@ int whisper_full(
|
|
| 2584 |
prompt_past.clear();
|
| 2585 |
}
|
| 2586 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2587 |
// overwrite audio_ctx
|
| 2588 |
ctx->exp_n_audio_ctx = params.audio_ctx;
|
| 2589 |
|
|
|
|
| 2412 |
/*.speed_up =*/ false,
|
| 2413 |
/*.audio_ctx =*/ 0,
|
| 2414 |
|
| 2415 |
+
/*.prompt_tokens =*/ nullptr,
|
| 2416 |
+
/*.prompt_n_tokens =*/ 0,
|
| 2417 |
+
|
| 2418 |
/*.language =*/ "en",
|
| 2419 |
|
| 2420 |
/*.greedy =*/ {
|
|
|
|
| 2458 |
/*.speed_up =*/ false,
|
| 2459 |
/*.audio_ctx =*/ 0,
|
| 2460 |
|
| 2461 |
+
/*.prompt_tokens =*/ nullptr,
|
| 2462 |
+
/*.prompt_n_tokens =*/ 0,
|
| 2463 |
+
|
| 2464 |
/*.language =*/ "en",
|
| 2465 |
|
| 2466 |
/*.greedy =*/ {
|
|
|
|
| 2590 |
prompt_past.clear();
|
| 2591 |
}
|
| 2592 |
|
| 2593 |
+
// Prepend the prompt tokens to the prompt_past
|
| 2594 |
+
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
|
| 2595 |
+
// Parse tokens from the pointer (it points to an std::vector)
|
| 2596 |
+
for (int i = 0; i < params.prompt_n_tokens; i++) {
|
| 2597 |
+
prompt_past.push_back(params.prompt_tokens[i]);
|
| 2598 |
+
}
|
| 2599 |
+
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
| 2600 |
+
}
|
| 2601 |
+
|
| 2602 |
// overwrite audio_ctx
|
| 2603 |
ctx->exp_n_audio_ctx = params.audio_ctx;
|
| 2604 |
|
whisper.h
CHANGED
|
@@ -208,6 +208,10 @@ extern "C" {
|
|
| 208 |
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
|
| 209 |
int audio_ctx; // overwrite the audio context size (0 = use default)
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
const char * language;
|
| 212 |
|
| 213 |
struct {
|
|
|
|
| 208 |
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
|
| 209 |
int audio_ctx; // overwrite the audio context size (0 = use default)
|
| 210 |
|
| 211 |
+
// std::vector<whisper_token>: tokens to provide the whisper model as initial prompt
|
| 212 |
+
const whisper_token * prompt_tokens;
|
| 213 |
+
int prompt_n_tokens;
|
| 214 |
+
|
| 215 |
const char * language;
|
| 216 |
|
| 217 |
struct {
|