meakbiyik commited on
Commit
8ad3dbf
·
unverified ·
1 Parent(s): 5cb7243

Prompt previous tokens for streaming (#163)

Browse files

* feat: prompt previous tokens for streaming

I used a vector pointer instead of vector itself because it gave weird errors, and why not

* convert vector to use with C api

* feat: remove old refs, check for prompt size

* feat: use better way of getting the pointer

Files changed (3) hide show
  1. examples/stream/stream.cpp +14 -0
  2. whisper.cpp +15 -0
  3. whisper.h +4 -0
examples/stream/stream.cpp CHANGED
@@ -234,6 +234,7 @@ int main(int argc, char ** argv) {
234
  std::vector<float> pcmf32(n_samples_30s, 0.0f);
235
  std::vector<float> pcmf32_old;
236
 
 
237
  const int n_new_line = params.length_ms / params.step_ms - 1;
238
 
239
  // print some info about the processing
@@ -344,6 +345,9 @@ int main(int argc, char ** argv) {
344
  wparams.audio_ctx = params.audio_ctx;
345
  wparams.speed_up = params.speed_up;
346
 
 
 
 
347
  if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
348
  fprintf(stderr, "%s: failed to process audio\n", argv[0]);
349
  return 6;
@@ -393,6 +397,16 @@ int main(int argc, char ** argv) {
393
 
394
  // keep part of the audio for next iteration to try to mitigate word boundary issues
395
  pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
 
 
 
 
 
 
 
 
 
 
396
  }
397
  }
398
  }
 
234
  std::vector<float> pcmf32(n_samples_30s, 0.0f);
235
  std::vector<float> pcmf32_old;
236
 
237
+ std::vector<whisper_token> prompt_tokens;
238
  const int n_new_line = params.length_ms / params.step_ms - 1;
239
 
240
  // print some info about the processing
 
345
  wparams.audio_ctx = params.audio_ctx;
346
  wparams.speed_up = params.speed_up;
347
 
348
+ wparams.prompt_tokens = prompt_tokens.data();
349
+ wparams.prompt_n_tokens = prompt_tokens.size();
350
+
351
  if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
352
  fprintf(stderr, "%s: failed to process audio\n", argv[0]);
353
  return 6;
 
397
 
398
  // keep part of the audio for next iteration to try to mitigate word boundary issues
399
  pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
400
+
401
+ // Add tokens of the last full length segment as the prompt
402
+ prompt_tokens.clear();
403
+ const int n_segments = whisper_full_n_segments(ctx);
404
+ for (int i = 0; i < n_segments; ++i) {
405
+ const int token_count = whisper_full_n_tokens(ctx, i);
406
+ for (int j = 0; j < token_count; ++j) {
407
+ prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
408
+ }
409
+ }
410
  }
411
  }
412
  }
whisper.cpp CHANGED
@@ -2412,6 +2412,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2412
  /*.speed_up =*/ false,
2413
  /*.audio_ctx =*/ 0,
2414
 
 
 
 
2415
  /*.language =*/ "en",
2416
 
2417
  /*.greedy =*/ {
@@ -2455,6 +2458,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2455
  /*.speed_up =*/ false,
2456
  /*.audio_ctx =*/ 0,
2457
 
 
 
 
2458
  /*.language =*/ "en",
2459
 
2460
  /*.greedy =*/ {
@@ -2584,6 +2590,15 @@ int whisper_full(
2584
  prompt_past.clear();
2585
  }
2586
 
 
 
 
 
 
 
 
 
 
2587
  // overwrite audio_ctx
2588
  ctx->exp_n_audio_ctx = params.audio_ctx;
2589
 
 
2412
  /*.speed_up =*/ false,
2413
  /*.audio_ctx =*/ 0,
2414
 
2415
+ /*.prompt_tokens =*/ nullptr,
2416
+ /*.prompt_n_tokens =*/ 0,
2417
+
2418
  /*.language =*/ "en",
2419
 
2420
  /*.greedy =*/ {
 
2458
  /*.speed_up =*/ false,
2459
  /*.audio_ctx =*/ 0,
2460
 
2461
+ /*.prompt_tokens =*/ nullptr,
2462
+ /*.prompt_n_tokens =*/ 0,
2463
+
2464
  /*.language =*/ "en",
2465
 
2466
  /*.greedy =*/ {
 
2590
  prompt_past.clear();
2591
  }
2592
 
2593
+ // Prepend the prompt tokens to the prompt_past
2594
+ if (params.prompt_tokens && params.prompt_n_tokens > 0) {
2595
+ // Parse tokens from the pointer (it points to an std::vector)
2596
+ for (int i = 0; i < params.prompt_n_tokens; i++) {
2597
+ prompt_past.push_back(params.prompt_tokens[i]);
2598
+ }
2599
+ std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
2600
+ }
2601
+
2602
  // overwrite audio_ctx
2603
  ctx->exp_n_audio_ctx = params.audio_ctx;
2604
 
whisper.h CHANGED
@@ -208,6 +208,10 @@ extern "C" {
208
  bool speed_up; // speed-up the audio by 2x using Phase Vocoder
209
  int audio_ctx; // overwrite the audio context size (0 = use default)
210
 
 
 
 
 
211
  const char * language;
212
 
213
  struct {
 
208
  bool speed_up; // speed-up the audio by 2x using Phase Vocoder
209
  int audio_ctx; // overwrite the audio context size (0 = use default)
210
 
211
+ // std::vector<whisper_token>: tokens to provide the whisper model as initial prompt
212
+ const whisper_token * prompt_tokens;
213
+ int prompt_n_tokens;
214
+
215
  const char * language;
216
 
217
  struct {