ggerganov commited on
Commit
54e85c7
·
unverified ·
1 Parent(s): 3564c33

whisper : do not provide past prompt when n_max_text_ctx == 0

Browse files
Files changed (1) hide show
  1. whisper.cpp +6 -6
whisper.cpp CHANGED
@@ -3524,7 +3524,7 @@ int whisper_full(
3524
  prompt.clear();
3525
 
3526
  // if we have already generated some text, use it as a prompt to condition the next generation
3527
- if (!prompt_past.empty() && t_cur < 0.5f) {
3528
  int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
3529
 
3530
  prompt = { whisper_token_prev(ctx) };
@@ -3535,11 +3535,11 @@ int whisper_full(
3535
  prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
3536
 
3537
  // print the prompt
3538
- //WHISPER_PRINT_DEBUG("\n\n");
3539
- //for (int i = 0; i < (int) prompt.size(); i++) {
3540
- // WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
3541
- //}
3542
- //WHISPER_PRINT_DEBUG("\n\n");
3543
 
3544
  if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
3545
  fprintf(stderr, "%s: failed to decode\n", __func__);
 
3524
  prompt.clear();
3525
 
3526
  // if we have already generated some text, use it as a prompt to condition the next generation
3527
+ if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
3528
  int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
3529
 
3530
  prompt = { whisper_token_prev(ctx) };
 
3535
  prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
3536
 
3537
  // print the prompt
3538
+ WHISPER_PRINT_DEBUG("\n\n");
3539
+ for (int i = 0; i < (int) prompt.size(); i++) {
3540
+ WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
3541
+ }
3542
+ WHISPER_PRINT_DEBUG("\n\n");
3543
 
3544
  if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
3545
  fprintf(stderr, "%s: failed to decode\n", __func__);