ggerganov committed
Commit d85b238 · unverified · 1 Parent(s): 1080ab7

whisper : add API for applying custom logits filters during decoding

Files changed (2):
  1. whisper.cpp +9 -3
  2. whisper.h +14 -0
whisper.cpp CHANGED

@@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
               MEM_REQ_SCRATCH3.at (model.type) +
         scale*MEM_REQ_MODEL.at (model.type) +
         scale*MEM_REQ_KV_CROSS.at(model.type) +
-        scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
+        scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));

     // this is the memory required by one decoder
     const size_t mem_required_decoder =
@@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

         /*.encoder_begin_callback =*/ nullptr,
         /*.encoder_begin_callback_user_data =*/ nullptr,
+
+        /*.logits_filter_callback =*/ nullptr,
+        /*.logits_filter_callback_user_data =*/ nullptr,
     };

     switch (strategy) {
@@ -3089,7 +3092,7 @@ static const std::vector<std::string> non_speech_tokens = {
 // - applies logit filters
 // - computes logprobs and probs
 static void whisper_process_logits(
-        const struct whisper_context & ctx,
+        struct whisper_context & ctx,
         const struct whisper_full_params params,
         struct whisper_decoder & decoder,
         float temperature) {
@@ -3145,6 +3148,9 @@ static void whisper_process_logits(
         logits[vocab.token_translate] = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;

+        if (params.logits_filter_callback) {
+            params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
+        }

         // suppress non-speech tokens
         // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
@@ -3848,7 +3854,7 @@ int whisper_full(
             return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
         });

-        unsigned int cur_c = 0;
+        uint32_t cur_c = 0;

         for (int j = 0; j < n_decoders_cur; ++j) {
             auto & decoder = ctx->decoders[j];
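
For reference, the new call site in whisper_process_logits hands the callback the tokens decoded so far (tokens_cur) and the logits array for the current decoding step, after the task-token suppression shown above and before the non-speech-token suppression. A minimal sketch of a matching filter is shown below; the banned token id passed through user_data is a hypothetical example, not part of this commit.

#include <math.h>    // for INFINITY
#include "whisper.h"

// Sketch of a filter with the whisper_logits_filter_callback signature added in whisper.h below.
// The token id to suppress is a hypothetical example handed in via user_data.
static void my_logits_filter(
        struct whisper_context * ctx,
        const whisper_token_data * tokens,
        int n_tokens,
        float * logits,
        void * user_data) {
    (void) ctx;
    (void) tokens;
    (void) n_tokens;

    const int banned_id = *(const int *) user_data;
    logits[banned_id] = -INFINITY; // the sampler will never pick this token
}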
whisper.h CHANGED

@@ -243,6 +243,16 @@ extern "C" {
     // If it returns false, the computation is aborted
     typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);

+    // Logits filter callback
+    // Can be used to modify the logits before sampling
+    // If not NULL, called after applying temperature to logits
+    typedef void (*whisper_logits_filter_callback)(
+        struct whisper_context * ctx,
+        const whisper_token_data * tokens,
+        int n_tokens,
+        float * logits,
+        void * user_data);
+
     // Parameters for the whisper_full() function
     // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
     // whisper_full_default_params()
@@ -315,6 +325,10 @@ extern "C" {
         // called each time before the encoder starts
         whisper_encoder_begin_callback encoder_begin_callback;
         void * encoder_begin_callback_user_data;
+
+        // called by each decoder to filter obtained logits
+        whisper_logits_filter_callback logits_filter_callback;
+        void * logits_filter_callback_user_data;
     };

     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
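
Tying the two new whisper_full_params fields to the typedef above, a caller could enable the filter roughly as follows. This is a sketch only: ctx, pcm and n_samples are assumed to come from the usual context-initialization and audio-loading code, and my_logits_filter / banned_id are the hypothetical names from the sketch after the whisper.cpp diff.

// Sketch: enabling a custom logits filter for a whisper_full() run.
// Assumes ctx, pcm and n_samples have already been set up elsewhere.
int banned_id = 1234; // hypothetical token id to suppress

struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

params.logits_filter_callback           = my_logits_filter;
params.logits_filter_callback_user_data = &banned_id;

if (whisper_full(ctx, params, pcm, n_samples) != 0) {
    // handle the failure
}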