Matija Pevec commited on
Commit
6b6bdd4
·
unverified ·
1 Parent(s): 07e1dc7

whisper : add "split_on_word" flag when using using "max_len" option (#455)

Browse files

* Update whisper.cpp

* fix: trim function

* feat: added flag to split on word

* fix: arguments for main

Files changed (3) hide show
  1. examples/main/main.cpp +4 -0
  2. whisper.cpp +34 -5
  3. whisper.h +1 -0
examples/main/main.cpp CHANGED
@@ -69,6 +69,7 @@ struct whisper_params {
69
  bool speed_up = false;
70
  bool translate = false;
71
  bool diarize = false;
 
72
  bool no_fallback = false;
73
  bool output_txt = false;
74
  bool output_vtt = false;
@@ -118,6 +119,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
118
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
119
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
120
  else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
 
121
  else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
122
  else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
123
  else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
@@ -156,6 +158,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
156
  fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
157
  fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
158
  fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
 
159
  fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
160
  fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
161
  fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
@@ -651,6 +654,7 @@ int main(int argc, char ** argv) {
651
  wparams.token_timestamps = params.output_wts || params.max_len > 0;
652
  wparams.thold_pt = params.word_thold;
653
  wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
 
654
 
655
  wparams.speed_up = params.speed_up;
656
 
 
69
  bool speed_up = false;
70
  bool translate = false;
71
  bool diarize = false;
72
+ bool split_on_word = false;
73
  bool no_fallback = false;
74
  bool output_txt = false;
75
  bool output_vtt = false;
 
119
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
120
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
121
  else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
122
+ else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
123
  else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
124
  else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
125
  else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
 
158
  fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
159
  fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
160
  fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
161
+ fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
162
  fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
163
  fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
164
  fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
 
654
  wparams.token_timestamps = params.output_wts || params.max_len > 0;
655
  wparams.thold_pt = params.word_thold;
656
  wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
657
+ wparams.split_on_word = params.split_on_word;
658
 
659
  wparams.speed_up = params.speed_up;
660
 
whisper.cpp CHANGED
@@ -2922,6 +2922,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2922
  /*.thold_pt =*/ 0.01f,
2923
  /*.thold_ptsum =*/ 0.01f,
2924
  /*.max_len =*/ 0,
 
2925
  /*.max_tokens =*/ 0,
2926
 
2927
  /*.speed_up =*/ false,
@@ -2988,9 +2989,36 @@ static void whisper_exp_compute_token_level_timestamps(
2988
  float thold_pt,
2989
  float thold_ptsum);
2990
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2991
  // wrap the last segment to max_len characters
2992
  // returns the number of new segments
2993
- static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
2994
  auto segment = ctx.result_all.back();
2995
 
2996
  int res = 1;
@@ -3005,11 +3033,11 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
3005
  }
3006
 
3007
  const auto txt = whisper_token_to_str(&ctx, token.id);
3008
-
3009
  const int cur = strlen(txt);
3010
 
3011
- if (acc + cur > max_len && i > 0) {
3012
  // split here
 
3013
  ctx.result_all.back().text = std::move(text);
3014
  ctx.result_all.back().t1 = token.t0;
3015
  ctx.result_all.back().tokens.resize(i);
@@ -3037,6 +3065,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
3037
  }
3038
  }
3039
 
 
3040
  ctx.result_all.back().text = std::move(text);
3041
 
3042
  return res;
@@ -4069,7 +4098,7 @@ int whisper_full(
4069
  *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4070
 
4071
  if (params.max_len > 0) {
4072
- n_new = whisper_wrap_segment(*ctx, params.max_len);
4073
  }
4074
  }
4075
  if (params.new_segment_callback) {
@@ -4113,7 +4142,7 @@ int whisper_full(
4113
  *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4114
 
4115
  if (params.max_len > 0) {
4116
- n_new = whisper_wrap_segment(*ctx, params.max_len);
4117
  }
4118
  }
4119
  if (params.new_segment_callback) {
 
2922
  /*.thold_pt =*/ 0.01f,
2923
  /*.thold_ptsum =*/ 0.01f,
2924
  /*.max_len =*/ 0,
2925
+ /*.split_on_word =*/ false,
2926
  /*.max_tokens =*/ 0,
2927
 
2928
  /*.speed_up =*/ false,
 
2989
  float thold_pt,
2990
  float thold_ptsum);
2991
 
2992
+ // trim from start (in place)
2993
+ static inline void ltrim(std::string &s) {
2994
+ s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
2995
+ return !std::isspace(ch);
2996
+ }));
2997
+ }
2998
+
2999
+ // trim from end (in place)
3000
+ static inline void rtrim(std::string &s) {
3001
+ s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
3002
+ return !std::isspace(ch);
3003
+ }).base(), s.end());
3004
+ }
3005
+
3006
+ // trim from both ends (in place)
3007
+ static inline void trim(std::string &s) {
3008
+ rtrim(s);
3009
+ ltrim(s);
3010
+ }
3011
+
3012
+ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
3013
+ if (!split_on_word) return true;
3014
+
3015
+ std::string s = txt;
3016
+ return s.substr(0, 1) == " ";
3017
+ }
3018
+
3019
  // wrap the last segment to max_len characters
3020
  // returns the number of new segments
3021
+ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
3022
  auto segment = ctx.result_all.back();
3023
 
3024
  int res = 1;
 
3033
  }
3034
 
3035
  const auto txt = whisper_token_to_str(&ctx, token.id);
 
3036
  const int cur = strlen(txt);
3037
 
3038
+ if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
3039
  // split here
3040
+ trim(text);
3041
  ctx.result_all.back().text = std::move(text);
3042
  ctx.result_all.back().t1 = token.t0;
3043
  ctx.result_all.back().tokens.resize(i);
 
3065
  }
3066
  }
3067
 
3068
+ trim(text);
3069
  ctx.result_all.back().text = std::move(text);
3070
 
3071
  return res;
 
4098
  *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4099
 
4100
  if (params.max_len > 0) {
4101
+ n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
4102
  }
4103
  }
4104
  if (params.new_segment_callback) {
 
4142
  *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4143
 
4144
  if (params.max_len > 0) {
4145
+ n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
4146
  }
4147
  }
4148
  if (params.new_segment_callback) {
whisper.h CHANGED
@@ -257,6 +257,7 @@ extern "C" {
257
  float thold_pt; // timestamp token probability threshold (~0.01)
258
  float thold_ptsum; // timestamp token sum probability threshold (~0.01)
259
  int max_len; // max segment length in characters
 
260
  int max_tokens; // max tokens per segment (0 = no limit)
261
 
262
  // [EXPERIMENTAL] speed-up techniques
 
257
  float thold_pt; // timestamp token probability threshold (~0.01)
258
  float thold_ptsum; // timestamp token sum probability threshold (~0.01)
259
  int max_len; // max segment length in characters
260
+ bool split_on_word; // split on word rather than on token (when used with max_len)
261
  int max_tokens; // max tokens per segment (0 = no limit)
262
 
263
  // [EXPERIMENTAL] speed-up techniques