Spaces:
Running
Running
whisper : add whisper_n_audio_ctx and check for invalid audio_ctx
Browse files
- whisper.cpp +9 -1
- whisper.h +1 -0
whisper.cpp
CHANGED
|
@@ -2497,6 +2497,10 @@ int whisper_n_text_ctx(struct whisper_context * ctx) {
|
|
| 2497 |
return ctx->model.hparams.n_text_ctx;
|
| 2498 |
}
|
| 2499 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2500 |
int whisper_is_multilingual(struct whisper_context * ctx) {
|
| 2501 |
return ctx->vocab.is_multilingual() ? 1 : 0;
|
| 2502 |
}
|
|
@@ -2822,7 +2826,11 @@ int whisper_full(
|
|
| 2822 |
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
| 2823 |
}
|
| 2824 |
|
| 2825 |
-
// overwrite audio_ctx
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2826 |
ctx->exp_n_audio_ctx = params.audio_ctx;
|
| 2827 |
|
| 2828 |
// these tokens determine the task that will be performed
|
|
|
|
| 2497 |
return ctx->model.hparams.n_text_ctx;
|
| 2498 |
}
|
| 2499 |
|
| 2500 |
+
int whisper_n_audio_ctx(struct whisper_context * ctx) {
|
| 2501 |
+
return ctx->model.hparams.n_audio_ctx;
|
| 2502 |
+
}
|
| 2503 |
+
|
| 2504 |
int whisper_is_multilingual(struct whisper_context * ctx) {
|
| 2505 |
return ctx->vocab.is_multilingual() ? 1 : 0;
|
| 2506 |
}
|
|
|
|
| 2826 |
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
|
| 2827 |
}
|
| 2828 |
|
| 2829 |
+
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
|
| 2830 |
+
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
|
| 2831 |
+
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
| 2832 |
+
return -4;
|
| 2833 |
+
}
|
| 2834 |
ctx->exp_n_audio_ctx = params.audio_ctx;
|
| 2835 |
|
| 2836 |
// these tokens determine the task that will be performed
|
whisper.h
CHANGED
|
@@ -177,6 +177,7 @@ extern "C" {
|
|
| 177 |
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
|
| 178 |
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
|
| 179 |
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
|
|
|
|
| 180 |
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
|
| 181 |
|
| 182 |
// The probabilities for the next token
|
|
|
|
| 177 |
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
|
| 178 |
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
|
| 179 |
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
|
| 180 |
+
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
|
| 181 |
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
|
| 182 |
|
| 183 |
// The probabilities for the next token
|