ggerganov commited on
Commit
821a538
·
unverified ·
1 Parent(s): dc8eb84

whisper : add whisper_n_audio_ctx and check for invalid audio_ctx

Browse files
Files changed (2) hide show
  1. whisper.cpp +9 -1
  2. whisper.h +1 -0
whisper.cpp CHANGED
@@ -2497,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
2497
  return ctx->model.hparams.n_text_ctx;
2498
  }
2499
 
 
 
 
 
2500
  int whisper_is_multilingual(struct whisper_context * ctx) {
2501
  return ctx->vocab.is_multilingual() ? 1 : 0;
2502
  }
@@ -2822,7 +2826,11 @@ int whisper_full(
2822
  std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
2823
  }
2824
 
2825
- // overwrite audio_ctx
 
 
 
 
2826
  ctx->exp_n_audio_ctx = params.audio_ctx;
2827
 
2828
  // these tokens determine the task that will be performed
 
2497
  return ctx->model.hparams.n_text_ctx;
2498
  }
2499
 
2500
+ int whisper_n_audio_ctx(struct whisper_context * ctx) {
2501
+ return ctx->model.hparams.n_audio_ctx;
2502
+ }
2503
+
2504
  int whisper_is_multilingual(struct whisper_context * ctx) {
2505
  return ctx->vocab.is_multilingual() ? 1 : 0;
2506
  }
 
2826
  std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
2827
  }
2828
 
2829
+ // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
2830
+ if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
2831
+ fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
2832
+ return -4;
2833
+ }
2834
  ctx->exp_n_audio_ctx = params.audio_ctx;
2835
 
2836
  // these tokens determine the task that will be performed
whisper.h CHANGED
@@ -177,6 +177,7 @@ extern "C" {
177
  WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
178
  WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
179
  WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
 
180
  WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
181
 
182
  // The probabilities for the next token
 
177
  WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
178
  WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
179
  WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
180
+ WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
181
  WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
182
 
183
  // The probabilities for the next token