josharian committed
Commit 6e9276c · unverified · 1 Parent(s): 3dee0de

whisper : improve beam search candidate diversity (#1947)


As of #1486, whisper.cpp uses a unified KV cache with KQ masking.
As a result, depending on their location in the batch,
identical sequences in a batch can have slightly different outputs
due to floating point rounding errors during reduction.
See the discussion in #1941 for more details.

The beam search code used "has identical sum of log probabilities"
as a shorthand for "is an identical token sequence". However, per above,
identical tokens do not necessarily result in identical probabilities.

Instead, explicitly compare on sequences.
This is linear in cost when they are identical,
but the lengths are always small and the comparisons are cheap.

This increases diversity during beam search.

This improves output quality for some short samples I've been working
with, at no detectable performance cost.
I haven't checked against larger corpora.

Fixes #1941

Files changed (1): whisper.cpp (+14, -1)
@@ -4759,6 +4759,19 @@ static void whisper_process_logits(
 #endif
 }
 
+static bool whisper_sequence_tokens_equal(const whisper_sequence & a, const whisper_sequence & b) {
+    if (a.tokens.size() != b.tokens.size()) {
+        return false;
+    }
+    // sequences are more likely to diverge at the end
+    for (int i = a.tokens.size() - 1; i >= 0; i--) {
+        if (a.tokens[i].id != b.tokens[i].id) {
+            return false;
+        }
+    }
+    return true;
+}
+
 static whisper_token_data whisper_sample_token(
         whisper_context & ctx,
         const whisper_decoder & decoder,
@@ -5378,7 +5391,7 @@ int whisper_full_with_state(
 
                     auto & cur = beam_candidates[cur_c++];
 
-                    while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
+                    while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0) {
                         ++cur_c;
                     }
5397