Matija Pevec committed

whisper : add "split_on_word" flag when using "max_len" option (#455)

* Update whisper.cpp
* fix: trim function
* feat: added flag to split on word
* fix: arguments for main

Files changed:
- examples/main/main.cpp +4 -0
- whisper.cpp +34 -5
- whisper.h +1 -0
examples/main/main.cpp
CHANGED
@@ -69,6 +69,7 @@ struct whisper_params {
     bool speed_up      = false;
     bool translate     = false;
     bool diarize       = false;
+    bool split_on_word = false;
     bool no_fallback   = false;
     bool output_txt    = false;
     bool output_vtt    = false;
@@ -118,6 +119,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-su"   || arg == "--speed-up")      { params.speed_up      = true; }
         else if (arg == "-tr"   || arg == "--translate")     { params.translate     = true; }
         else if (arg == "-di"   || arg == "--diarize")       { params.diarize       = true; }
+        else if (arg == "-sow"  || arg == "--split-on-word") { params.split_on_word = true; }
         else if (arg == "-nf"   || arg == "--no-fallback")   { params.no_fallback   = true; }
         else if (arg == "-otxt" || arg == "--output-txt")    { params.output_txt    = true; }
         else if (arg == "-ovtt" || arg == "--output-vtt")    { params.output_vtt    = true; }
@@ -156,6 +158,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -d N,     --duration N    [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
     fprintf(stderr, "  -mc N,    --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
     fprintf(stderr, "  -ml N,    --max-len N     [%-7d] maximum segment length in characters\n", params.max_len);
+    fprintf(stderr, "  -sow,     --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
     fprintf(stderr, "  -bo N,    --best-of N     [%-7d] number of best candidates to keep\n", params.best_of);
     fprintf(stderr, "  -bs N,    --beam-size N   [%-7d] beam size for beam search\n", params.beam_size);
     fprintf(stderr, "  -wt N,    --word-thold N  [%-7.2f] word timestamp probability threshold\n", params.word_thold);
@@ -651,6 +654,7 @@ int main(int argc, char ** argv) {
         wparams.token_timestamps = params.output_wts || params.max_len > 0;
         wparams.thold_pt         = params.word_thold;
         wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+        wparams.split_on_word    = params.split_on_word;

         wparams.speed_up         = params.speed_up;
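Note: the new flag only changes behaviour when segment wrapping is active. main enables token-level timestamps whenever max_len > 0, and whisper_full only wraps segments under that same condition, so "-sow" without "-ml N" is a no-op. A minimal sketch of the equivalent wiring through the C API (the helper name and the 16-character value are illustrative, not part of this commit, and the output_wts handling from main.cpp is omitted):

// Sketch only: mirrors the parameter wiring added to examples/main/main.cpp above.
// apply_wrap_options is a hypothetical helper, not part of whisper.cpp.
#include "whisper.h"

static void apply_wrap_options(whisper_full_params & wparams, int max_len, bool split_on_word) {
    wparams.token_timestamps = max_len > 0;   // wrapping needs token-level timestamps
    wparams.max_len          = max_len;       // e.g. 16, as passed via "-ml 16"
    wparams.split_on_word    = split_on_word; // "-sow": only honoured when max_len > 0
}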
whisper.cpp
CHANGED
@@ -2922,6 +2922,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.thold_pt         =*/ 0.01f,
         /*.thold_ptsum      =*/ 0.01f,
         /*.max_len          =*/ 0,
+        /*.split_on_word    =*/ false,
         /*.max_tokens       =*/ 0,

         /*.speed_up         =*/ false,
@@ -2988,9 +2989,36 @@ static void whisper_exp_compute_token_level_timestamps(
         float thold_pt,
         float thold_ptsum);

+// trim from start (in place)
+static inline void ltrim(std::string &s) {
+    s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
+        return !std::isspace(ch);
+    }));
+}
+
+// trim from end (in place)
+static inline void rtrim(std::string &s) {
+    s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
+        return !std::isspace(ch);
+    }).base(), s.end());
+}
+
+// trim from both ends (in place)
+static inline void trim(std::string &s) {
+    rtrim(s);
+    ltrim(s);
+}
+
+static inline bool should_split_on_word(const char * txt, bool split_on_word) {
+    if (!split_on_word) return true;
+
+    std::string s = txt;
+    return s.substr(0, 1) == " ";
+}
+
 // wrap the last segment to max_len characters
 // returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
+static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
     auto segment = ctx.result_all.back();

     int res = 1;
@@ -3005,11 +3033,11 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
         }

         const auto txt = whisper_token_to_str(&ctx, token.id);
-
         const int cur = strlen(txt);

-        if (acc + cur > max_len && i > 0) {
+        if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
             // split here
+            trim(text);
             ctx.result_all.back().text = std::move(text);
             ctx.result_all.back().t1 = token.t0;
             ctx.result_all.back().tokens.resize(i);
@@ -3037,6 +3065,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
         }
     }

+    trim(text);
     ctx.result_all.back().text = std::move(text);

     return res;
@@ -4069,7 +4098,7 @@ int whisper_full(
                         *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);

                 if (params.max_len > 0) {
-                    n_new = whisper_wrap_segment(*ctx, params.max_len);
+                    n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
                 }
             }
             if (params.new_segment_callback) {
@@ -4113,7 +4142,7 @@ int whisper_full(
                         *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);

                 if (params.max_len > 0) {
-                    n_new = whisper_wrap_segment(*ctx, params.max_len);
+                    n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
                 }
             }
             if (params.new_segment_callback) {
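The should_split_on_word check relies on how the Whisper vocabulary typically marks word boundaries: a token that starts a new word usually begins with a space, so refusing to split before a token without a leading space keeps whole words inside one segment, and trim() then removes the stray spaces left at the segment edges. A standalone sketch of that heuristic on a made-up token list (the tokens and the 10-character limit are illustrative only, and the simplified trim stands in for the ltrim/rtrim pair above):

// Standalone sketch of the word-boundary heuristic used by whisper_wrap_segment:
// only break before a token that begins with a space, then trim the pieces.
#include <cctype>
#include <cstdio>
#include <string>
#include <vector>

static void trim(std::string & s) {
    while (!s.empty() && std::isspace((unsigned char) s.front())) s.erase(s.begin());
    while (!s.empty() && std::isspace((unsigned char) s.back()))  s.pop_back();
}

int main() {
    const std::vector<std::string> tokens = { " The", " quick", "est", " brown", " fox" };
    const int max_len = 10; // illustrative character limit

    std::string text;
    int acc = 0;

    for (const auto & tok : tokens) {
        const int cur = (int) tok.size();
        // split only at a word boundary, i.e. before a token starting with a space
        if (acc + cur > max_len && !text.empty() && tok.front() == ' ') {
            trim(text);
            std::printf("segment: '%s'\n", text.c_str());
            text.clear();
            acc = 0;
        }
        text += tok;
        acc  += cur;
    }
    trim(text);
    std::printf("segment: '%s'\n", text.c_str());
}

This prints "The quickest" followed by "brown fox"; without the word check the wrap triggers as soon as the character limit is exceeded, so "quickest" would be split across two segments.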
whisper.h
CHANGED
@@ -257,6 +257,7 @@ extern "C" {
         float thold_pt;      // timestamp token probability threshold (~0.01)
         float thold_ptsum;   // timestamp token sum probability threshold (~0.01)
         int   max_len;       // max segment length in characters
+        bool  split_on_word; // split on word rather than on token (when used with max_len)
         int   max_tokens;    // max tokens per segment (0 = no limit)

         // [EXPERIMENTAL] speed-up techniques
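For programs that call the library directly, the new field is set on whisper_full_params before invoking whisper_full. A minimal sketch against the C API as of this commit (the model path and the empty audio buffer are placeholders; real code would load a model file and 16 kHz mono float PCM first):

// Sketch only: enable word-aligned segment wrapping through the public API.
#include "whisper.h"
#include <cstdio>
#include <vector>

int main() {
    struct whisper_context * ctx = whisper_init("ggml-base.en.bin"); // placeholder path
    if (!ctx) return 1;

    std::vector<float> pcmf32; // 16 kHz mono samples, filled elsewhere

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.token_timestamps = true;  // required for max_len wrapping
    wparams.max_len          = 16;    // wrap segments at roughly 16 characters
    wparams.split_on_word    = true;  // the field added by this commit

    if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) == 0) {
        for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
            std::printf("[%lld -> %lld] %s\n",
                    (long long) whisper_full_get_segment_t0(ctx, i),
                    (long long) whisper_full_get_segment_t1(ctx, i),
                    whisper_full_get_segment_text(ctx, i));
        }
    }

    whisper_free(ctx);
    return 0;
}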