Spaces:
Running
Running
whisper : add API for applying custom logits filters during decoding
Browse files- whisper.cpp +9 -3
- whisper.h +14 -0
whisper.cpp
CHANGED
|
@@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
| 805 |
MEM_REQ_SCRATCH3.at (model.type) +
|
| 806 |
scale*MEM_REQ_MODEL.at (model.type) +
|
| 807 |
scale*MEM_REQ_KV_CROSS.at(model.type) +
|
| 808 |
-
scale*std::max(MEM_REQ_ENCODE.at(model.type),
|
| 809 |
|
| 810 |
// this is the memory required by one decoder
|
| 811 |
const size_t mem_required_decoder =
|
|
@@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
| 2962 |
|
| 2963 |
/*.encoder_begin_callback =*/ nullptr,
|
| 2964 |
/*.encoder_begin_callback_user_data =*/ nullptr,
|
|
|
|
|
|
|
|
|
|
| 2965 |
};
|
| 2966 |
|
| 2967 |
switch (strategy) {
|
|
@@ -3089,7 +3092,7 @@ static const std::vector<std::string> non_speech_tokens = {
|
|
| 3089 |
// - applies logit filters
|
| 3090 |
// - computes logprobs and probs
|
| 3091 |
static void whisper_process_logits(
|
| 3092 |
-
|
| 3093 |
const struct whisper_full_params params,
|
| 3094 |
struct whisper_decoder & decoder,
|
| 3095 |
float temperature) {
|
|
@@ -3145,6 +3148,9 @@ static void whisper_process_logits(
|
|
| 3145 |
logits[vocab.token_translate] = -INFINITY;
|
| 3146 |
logits[vocab.token_transcribe] = -INFINITY;
|
| 3147 |
|
|
|
|
|
|
|
|
|
|
| 3148 |
|
| 3149 |
// suppress non-speech tokens
|
| 3150 |
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
|
@@ -3848,7 +3854,7 @@ int whisper_full(
|
|
| 3848 |
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
|
| 3849 |
});
|
| 3850 |
|
| 3851 |
-
|
| 3852 |
|
| 3853 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3854 |
auto & decoder = ctx->decoders[j];
|
|
|
|
| 805 |
MEM_REQ_SCRATCH3.at (model.type) +
|
| 806 |
scale*MEM_REQ_MODEL.at (model.type) +
|
| 807 |
scale*MEM_REQ_KV_CROSS.at(model.type) +
|
| 808 |
+
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
|
| 809 |
|
| 810 |
// this is the memory required by one decoder
|
| 811 |
const size_t mem_required_decoder =
|
|
|
|
| 2962 |
|
| 2963 |
/*.encoder_begin_callback =*/ nullptr,
|
| 2964 |
/*.encoder_begin_callback_user_data =*/ nullptr,
|
| 2965 |
+
|
| 2966 |
+
/*.logits_filter_callback =*/ nullptr,
|
| 2967 |
+
/*.logits_filter_callback_user_data =*/ nullptr,
|
| 2968 |
};
|
| 2969 |
|
| 2970 |
switch (strategy) {
|
|
|
|
| 3092 |
// - applies logit filters
|
| 3093 |
// - computes logprobs and probs
|
| 3094 |
static void whisper_process_logits(
|
| 3095 |
+
struct whisper_context & ctx,
|
| 3096 |
const struct whisper_full_params params,
|
| 3097 |
struct whisper_decoder & decoder,
|
| 3098 |
float temperature) {
|
|
|
|
| 3148 |
logits[vocab.token_translate] = -INFINITY;
|
| 3149 |
logits[vocab.token_transcribe] = -INFINITY;
|
| 3150 |
|
| 3151 |
+
if (params.logits_filter_callback) {
|
| 3152 |
+
params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
| 3153 |
+
}
|
| 3154 |
|
| 3155 |
// suppress non-speech tokens
|
| 3156 |
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
|
|
|
| 3854 |
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
|
| 3855 |
});
|
| 3856 |
|
| 3857 |
+
uint32_t cur_c = 0;
|
| 3858 |
|
| 3859 |
for (int j = 0; j < n_decoders_cur; ++j) {
|
| 3860 |
auto & decoder = ctx->decoders[j];
|
whisper.h
CHANGED
|
@@ -243,6 +243,16 @@ extern "C" {
|
|
| 243 |
// If it returns false, the computation is aborted
|
| 244 |
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
|
| 245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
// Parameters for the whisper_full() function
|
| 247 |
// If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
|
| 248 |
// whisper_full_default_params()
|
|
@@ -315,6 +325,10 @@ extern "C" {
|
|
| 315 |
// called each time before the encoder starts
|
| 316 |
whisper_encoder_begin_callback encoder_begin_callback;
|
| 317 |
void * encoder_begin_callback_user_data;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
};
|
| 319 |
|
| 320 |
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
|
|
|
| 243 |
// If it returns false, the computation is aborted
|
| 244 |
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
|
| 245 |
|
| 246 |
+
// Logits filter callback
|
| 247 |
+
// Can be used to modify the logits before sampling
|
| 248 |
+
// If not NULL, called after applying temperature to logits
|
| 249 |
+
typedef void (*whisper_logits_filter_callback)(
|
| 250 |
+
struct whisper_context * ctx,
|
| 251 |
+
const whisper_token_data * tokens,
|
| 252 |
+
int n_tokens,
|
| 253 |
+
float * logits,
|
| 254 |
+
void * user_data);
|
| 255 |
+
|
| 256 |
// Parameters for the whisper_full() function
|
| 257 |
// If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
|
| 258 |
// whisper_full_default_params()
|
|
|
|
| 325 |
// called each time before the encoder starts
|
| 326 |
whisper_encoder_begin_callback encoder_begin_callback;
|
| 327 |
void * encoder_begin_callback_user_data;
|
| 328 |
+
|
| 329 |
+
// called by each decoder to filter obtained logits
|
| 330 |
+
whisper_logits_filter_callback logits_filter_callback;
|
| 331 |
+
void * logits_filter_callback_user_data;
|
| 332 |
};
|
| 333 |
|
| 334 |
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|