Spaces:
Sleeping
Sleeping
whisper : add new-segment callback
Browse filesCan be used to process new segments as they are being generated.
Sample usage in main, for printing the resulting segments during the
inference.
- main.cpp +56 -39
- whisper.cpp +16 -0
- whisper.h +9 -0
main.cpp
CHANGED
|
@@ -141,6 +141,55 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
|
| 141 |
fprintf(stderr, "\n");
|
| 142 |
}
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
bool output_txt(struct whisper_context * ctx, const char * fname) {
|
| 145 |
std::ofstream fout(fname);
|
| 146 |
if (!fout.is_open()) {
|
|
@@ -294,7 +343,7 @@ int main(int argc, char ** argv) {
|
|
| 294 |
{
|
| 295 |
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
| 296 |
|
| 297 |
-
wparams.print_realtime =
|
| 298 |
wparams.print_progress = false;
|
| 299 |
wparams.print_timestamps = !params.no_timestamps;
|
| 300 |
wparams.print_special_tokens = params.print_special_tokens;
|
|
@@ -303,49 +352,17 @@ int main(int argc, char ** argv) {
|
|
| 303 |
wparams.n_threads = params.n_threads;
|
| 304 |
wparams.offset_ms = params.offset_t_ms;
|
| 305 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
| 307 |
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
| 308 |
return 7;
|
| 309 |
}
|
| 310 |
|
| 311 |
-
// print result
|
| 312 |
-
if (!wparams.print_realtime) {
|
| 313 |
-
printf("\n");
|
| 314 |
-
|
| 315 |
-
const int n_segments = whisper_full_n_segments(ctx);
|
| 316 |
-
for (int i = 0; i < n_segments; ++i) {
|
| 317 |
-
if (params.no_timestamps) {
|
| 318 |
-
if (params.print_colors) {
|
| 319 |
-
// TODO
|
| 320 |
-
} else {
|
| 321 |
-
const char * text = whisper_full_get_segment_text(ctx, i);
|
| 322 |
-
printf("%s", text);
|
| 323 |
-
fflush(stdout);
|
| 324 |
-
}
|
| 325 |
-
} else {
|
| 326 |
-
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
| 327 |
-
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
| 328 |
-
|
| 329 |
-
if (params.print_colors) {
|
| 330 |
-
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
| 331 |
-
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
| 332 |
-
const char * text = whisper_full_get_token_text(ctx, i, j);
|
| 333 |
-
const float p = whisper_full_get_token_p (ctx, i, j);
|
| 334 |
-
|
| 335 |
-
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
| 336 |
-
|
| 337 |
-
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
| 338 |
-
}
|
| 339 |
-
printf("\n");
|
| 340 |
-
} else {
|
| 341 |
-
const char * text = whisper_full_get_segment_text(ctx, i);
|
| 342 |
-
|
| 343 |
-
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
| 344 |
-
}
|
| 345 |
-
}
|
| 346 |
-
}
|
| 347 |
-
}
|
| 348 |
-
|
| 349 |
printf("\n");
|
| 350 |
|
| 351 |
// output to text file
|
|
|
|
| 141 |
fprintf(stderr, "\n");
|
| 142 |
}
|
| 143 |
|
| 144 |
+
void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
|
| 145 |
+
const whisper_params & params = *(whisper_params *) user_data;
|
| 146 |
+
|
| 147 |
+
const int n_segments = whisper_full_n_segments(ctx);
|
| 148 |
+
|
| 149 |
+
// print the last segment
|
| 150 |
+
const int i = n_segments - 1;
|
| 151 |
+
if (i == 0) {
|
| 152 |
+
printf("\n");
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
if (params.no_timestamps) {
|
| 156 |
+
if (params.print_colors) {
|
| 157 |
+
// TODO
|
| 158 |
+
} else {
|
| 159 |
+
const char * text = whisper_full_get_segment_text(ctx, i);
|
| 160 |
+
printf("%s", text);
|
| 161 |
+
fflush(stdout);
|
| 162 |
+
}
|
| 163 |
+
} else {
|
| 164 |
+
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
| 165 |
+
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
| 166 |
+
|
| 167 |
+
if (params.print_colors) {
|
| 168 |
+
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
| 169 |
+
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
| 170 |
+
if (params.print_special_tokens == false) {
|
| 171 |
+
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
| 172 |
+
if (id >= whisper_token_eot(ctx)) {
|
| 173 |
+
continue;
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
const char * text = whisper_full_get_token_text(ctx, i, j);
|
| 178 |
+
const float p = whisper_full_get_token_p (ctx, i, j);
|
| 179 |
+
|
| 180 |
+
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
| 181 |
+
|
| 182 |
+
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
| 183 |
+
}
|
| 184 |
+
printf("\n");
|
| 185 |
+
} else {
|
| 186 |
+
const char * text = whisper_full_get_segment_text(ctx, i);
|
| 187 |
+
|
| 188 |
+
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
bool output_txt(struct whisper_context * ctx, const char * fname) {
|
| 194 |
std::ofstream fout(fname);
|
| 195 |
if (!fout.is_open()) {
|
|
|
|
| 343 |
{
|
| 344 |
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
| 345 |
|
| 346 |
+
wparams.print_realtime = false;
|
| 347 |
wparams.print_progress = false;
|
| 348 |
wparams.print_timestamps = !params.no_timestamps;
|
| 349 |
wparams.print_special_tokens = params.print_special_tokens;
|
|
|
|
| 352 |
wparams.n_threads = params.n_threads;
|
| 353 |
wparams.offset_ms = params.offset_t_ms;
|
| 354 |
|
| 355 |
+
// this callback is called on each new segment
|
| 356 |
+
if (!wparams.print_realtime) {
|
| 357 |
+
wparams.new_segment_callback = whisper_print_segment_callback;
|
| 358 |
+
wparams.new_segment_callback_user_data = ¶ms;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
| 362 |
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
| 363 |
return 7;
|
| 364 |
}
|
| 365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
printf("\n");
|
| 367 |
|
| 368 |
// output to text file
|
whisper.cpp
CHANGED
|
@@ -2320,6 +2320,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
| 2320 |
/*.beam_width =*/ -1,
|
| 2321 |
/*.n_best =*/ -1,
|
| 2322 |
},
|
|
|
|
|
|
|
|
|
|
| 2323 |
};
|
| 2324 |
} break;
|
| 2325 |
case WHISPER_SAMPLING_BEAM_SEARCH:
|
|
@@ -2348,6 +2351,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
| 2348 |
/*.beam_width =*/ 10,
|
| 2349 |
/*.n_best =*/ 5,
|
| 2350 |
},
|
|
|
|
|
|
|
|
|
|
| 2351 |
};
|
| 2352 |
} break;
|
| 2353 |
}
|
|
@@ -2549,6 +2555,9 @@ int whisper_full(
|
|
| 2549 |
for (int j = i0; j <= i; j++) {
|
| 2550 |
result_all.back().tokens.push_back(tokens_cur[j]);
|
| 2551 |
}
|
|
|
|
|
|
|
|
|
|
| 2552 |
}
|
| 2553 |
text = "";
|
| 2554 |
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
|
|
@@ -2576,6 +2585,9 @@ int whisper_full(
|
|
| 2576 |
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
| 2577 |
result_all.back().tokens.push_back(tokens_cur[j]);
|
| 2578 |
}
|
|
|
|
|
|
|
|
|
|
| 2579 |
}
|
| 2580 |
}
|
| 2581 |
|
|
@@ -2609,6 +2621,10 @@ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_seg
|
|
| 2609 |
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
|
| 2610 |
}
|
| 2611 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2612 |
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 2613 |
return ctx->result_all[i_segment].tokens[i_token].p;
|
| 2614 |
}
|
|
|
|
| 2320 |
/*.beam_width =*/ -1,
|
| 2321 |
/*.n_best =*/ -1,
|
| 2322 |
},
|
| 2323 |
+
|
| 2324 |
+
/*.new_segment_callback =*/ nullptr,
|
| 2325 |
+
/*.new_segment_callback_user_data =*/ nullptr,
|
| 2326 |
};
|
| 2327 |
} break;
|
| 2328 |
case WHISPER_SAMPLING_BEAM_SEARCH:
|
|
|
|
| 2351 |
/*.beam_width =*/ 10,
|
| 2352 |
/*.n_best =*/ 5,
|
| 2353 |
},
|
| 2354 |
+
|
| 2355 |
+
/*.new_segment_callback =*/ nullptr,
|
| 2356 |
+
/*.new_segment_callback_user_data =*/ nullptr,
|
| 2357 |
};
|
| 2358 |
} break;
|
| 2359 |
}
|
|
|
|
| 2555 |
for (int j = i0; j <= i; j++) {
|
| 2556 |
result_all.back().tokens.push_back(tokens_cur[j]);
|
| 2557 |
}
|
| 2558 |
+
if (params.new_segment_callback) {
|
| 2559 |
+
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
|
| 2560 |
+
}
|
| 2561 |
}
|
| 2562 |
text = "";
|
| 2563 |
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
|
|
|
|
| 2585 |
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
| 2586 |
result_all.back().tokens.push_back(tokens_cur[j]);
|
| 2587 |
}
|
| 2588 |
+
if (params.new_segment_callback) {
|
| 2589 |
+
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
|
| 2590 |
+
}
|
| 2591 |
}
|
| 2592 |
}
|
| 2593 |
|
|
|
|
| 2621 |
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
|
| 2622 |
}
|
| 2623 |
|
| 2624 |
+
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 2625 |
+
return ctx->result_all[i_segment].tokens[i_token].id;
|
| 2626 |
+
}
|
| 2627 |
+
|
| 2628 |
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
|
| 2629 |
return ctx->result_all[i_segment].tokens[i_token].p;
|
| 2630 |
}
|
whisper.h
CHANGED
|
@@ -160,6 +160,11 @@ extern "C" {
|
|
| 160 |
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
|
| 161 |
};
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
struct whisper_full_params {
|
| 164 |
enum whisper_sampling_strategy strategy;
|
| 165 |
|
|
@@ -184,6 +189,9 @@ extern "C" {
|
|
| 184 |
int beam_width;
|
| 185 |
int n_best;
|
| 186 |
} beam_search;
|
|
|
|
|
|
|
|
|
|
| 187 |
};
|
| 188 |
|
| 189 |
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
|
@@ -212,6 +220,7 @@ extern "C" {
|
|
| 212 |
|
| 213 |
// Get the token text of the specified token in the specified segment.
|
| 214 |
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
|
|
|
|
| 215 |
|
| 216 |
// Get the probability of the specified token in the specified segment.
|
| 217 |
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
|
|
|
|
| 160 |
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
|
| 161 |
};
|
| 162 |
|
| 163 |
+
// Text segment callback
|
| 164 |
+
// Called on every newly generated text segment
|
| 165 |
+
// Use the whisper_full_...() functions to obtain the text segments
|
| 166 |
+
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
|
| 167 |
+
|
| 168 |
struct whisper_full_params {
|
| 169 |
enum whisper_sampling_strategy strategy;
|
| 170 |
|
|
|
|
| 189 |
int beam_width;
|
| 190 |
int n_best;
|
| 191 |
} beam_search;
|
| 192 |
+
|
| 193 |
+
whisper_new_segment_callback new_segment_callback;
|
| 194 |
+
void * new_segment_callback_user_data;
|
| 195 |
};
|
| 196 |
|
| 197 |
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
|
|
|
| 220 |
|
| 221 |
// Get the token text of the specified token in the specified segment.
|
| 222 |
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
|
| 223 |
+
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
|
| 224 |
|
| 225 |
// Get the probability of the specified token in the specified segment.
|
| 226 |
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
|