ggerganov commited on
Commit
83926f7
·
1 Parent(s): e004a9e

whisper : add new-segment callback

Browse files

Can be used to process new segments as they are being generated.
Sample usage in main, for printing the resulting segments during the
inference.

Files changed (3) hide show
  1. main.cpp +56 -39
  2. whisper.cpp +16 -0
  3. whisper.h +9 -0
main.cpp CHANGED
@@ -141,6 +141,55 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
141
  fprintf(stderr, "\n");
142
  }
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  bool output_txt(struct whisper_context * ctx, const char * fname) {
145
  std::ofstream fout(fname);
146
  if (!fout.is_open()) {
@@ -294,7 +343,7 @@ int main(int argc, char ** argv) {
294
  {
295
  whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
296
 
297
- wparams.print_realtime = !params.print_colors;
298
  wparams.print_progress = false;
299
  wparams.print_timestamps = !params.no_timestamps;
300
  wparams.print_special_tokens = params.print_special_tokens;
@@ -303,49 +352,17 @@ int main(int argc, char ** argv) {
303
  wparams.n_threads = params.n_threads;
304
  wparams.offset_ms = params.offset_t_ms;
305
 
 
 
 
 
 
 
306
  if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
307
  fprintf(stderr, "%s: failed to process audio\n", argv[0]);
308
  return 7;
309
  }
310
 
311
- // print result
312
- if (!wparams.print_realtime) {
313
- printf("\n");
314
-
315
- const int n_segments = whisper_full_n_segments(ctx);
316
- for (int i = 0; i < n_segments; ++i) {
317
- if (params.no_timestamps) {
318
- if (params.print_colors) {
319
- // TODO
320
- } else {
321
- const char * text = whisper_full_get_segment_text(ctx, i);
322
- printf("%s", text);
323
- fflush(stdout);
324
- }
325
- } else {
326
- const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
327
- const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
328
-
329
- if (params.print_colors) {
330
- printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
331
- for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
332
- const char * text = whisper_full_get_token_text(ctx, i, j);
333
- const float p = whisper_full_get_token_p (ctx, i, j);
334
-
335
- const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
336
-
337
- printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
338
- }
339
- printf("\n");
340
- } else {
341
- const char * text = whisper_full_get_segment_text(ctx, i);
342
-
343
- printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
344
- }
345
- }
346
- }
347
- }
348
-
349
  printf("\n");
350
 
351
  // output to text file
 
141
  fprintf(stderr, "\n");
142
  }
143
 
144
+ void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
145
+ const whisper_params & params = *(whisper_params *) user_data;
146
+
147
+ const int n_segments = whisper_full_n_segments(ctx);
148
+
149
+ // print the last segment
150
+ const int i = n_segments - 1;
151
+ if (i == 0) {
152
+ printf("\n");
153
+ }
154
+
155
+ if (params.no_timestamps) {
156
+ if (params.print_colors) {
157
+ // TODO
158
+ } else {
159
+ const char * text = whisper_full_get_segment_text(ctx, i);
160
+ printf("%s", text);
161
+ fflush(stdout);
162
+ }
163
+ } else {
164
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
165
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
166
+
167
+ if (params.print_colors) {
168
+ printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
169
+ for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
170
+ if (params.print_special_tokens == false) {
171
+ const whisper_token id = whisper_full_get_token_id(ctx, i, j);
172
+ if (id >= whisper_token_eot(ctx)) {
173
+ continue;
174
+ }
175
+ }
176
+
177
+ const char * text = whisper_full_get_token_text(ctx, i, j);
178
+ const float p = whisper_full_get_token_p (ctx, i, j);
179
+
180
+ const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
181
+
182
+ printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
183
+ }
184
+ printf("\n");
185
+ } else {
186
+ const char * text = whisper_full_get_segment_text(ctx, i);
187
+
188
+ printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
189
+ }
190
+ }
191
+ }
192
+
193
  bool output_txt(struct whisper_context * ctx, const char * fname) {
194
  std::ofstream fout(fname);
195
  if (!fout.is_open()) {
 
343
  {
344
  whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
345
 
346
+ wparams.print_realtime = false;
347
  wparams.print_progress = false;
348
  wparams.print_timestamps = !params.no_timestamps;
349
  wparams.print_special_tokens = params.print_special_tokens;
 
352
  wparams.n_threads = params.n_threads;
353
  wparams.offset_ms = params.offset_t_ms;
354
 
355
+ // this callback is called on each new segment
356
+ if (!wparams.print_realtime) {
357
+ wparams.new_segment_callback = whisper_print_segment_callback;
358
+ wparams.new_segment_callback_user_data = &params;
359
+ }
360
+
361
  if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
362
  fprintf(stderr, "%s: failed to process audio\n", argv[0]);
363
  return 7;
364
  }
365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  printf("\n");
367
 
368
  // output to text file
whisper.cpp CHANGED
@@ -2320,6 +2320,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2320
  /*.beam_width =*/ -1,
2321
  /*.n_best =*/ -1,
2322
  },
 
 
 
2323
  };
2324
  } break;
2325
  case WHISPER_SAMPLING_BEAM_SEARCH:
@@ -2348,6 +2351,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2348
  /*.beam_width =*/ 10,
2349
  /*.n_best =*/ 5,
2350
  },
 
 
 
2351
  };
2352
  } break;
2353
  }
@@ -2549,6 +2555,9 @@ int whisper_full(
2549
  for (int j = i0; j <= i; j++) {
2550
  result_all.back().tokens.push_back(tokens_cur[j]);
2551
  }
 
 
 
2552
  }
2553
  text = "";
2554
  while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
@@ -2576,6 +2585,9 @@ int whisper_full(
2576
  for (int j = i0; j < (int) tokens_cur.size(); j++) {
2577
  result_all.back().tokens.push_back(tokens_cur[j]);
2578
  }
 
 
 
2579
  }
2580
  }
2581
 
@@ -2609,6 +2621,10 @@ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_seg
2609
  return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
2610
  }
2611
 
 
 
 
 
2612
  float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
2613
  return ctx->result_all[i_segment].tokens[i_token].p;
2614
  }
 
2320
  /*.beam_width =*/ -1,
2321
  /*.n_best =*/ -1,
2322
  },
2323
+
2324
+ /*.new_segment_callback =*/ nullptr,
2325
+ /*.new_segment_callback_user_data =*/ nullptr,
2326
  };
2327
  } break;
2328
  case WHISPER_SAMPLING_BEAM_SEARCH:
 
2351
  /*.beam_width =*/ 10,
2352
  /*.n_best =*/ 5,
2353
  },
2354
+
2355
+ /*.new_segment_callback =*/ nullptr,
2356
+ /*.new_segment_callback_user_data =*/ nullptr,
2357
  };
2358
  } break;
2359
  }
 
2555
  for (int j = i0; j <= i; j++) {
2556
  result_all.back().tokens.push_back(tokens_cur[j]);
2557
  }
2558
+ if (params.new_segment_callback) {
2559
+ params.new_segment_callback(ctx, params.new_segment_callback_user_data);
2560
+ }
2561
  }
2562
  text = "";
2563
  while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
 
2585
  for (int j = i0; j < (int) tokens_cur.size(); j++) {
2586
  result_all.back().tokens.push_back(tokens_cur[j]);
2587
  }
2588
+ if (params.new_segment_callback) {
2589
+ params.new_segment_callback(ctx, params.new_segment_callback_user_data);
2590
+ }
2591
  }
2592
  }
2593
 
 
2621
  return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
2622
  }
2623
 
2624
+ whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
2625
+ return ctx->result_all[i_segment].tokens[i_token].id;
2626
+ }
2627
+
2628
  float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
2629
  return ctx->result_all[i_segment].tokens[i_token].p;
2630
  }
whisper.h CHANGED
@@ -160,6 +160,11 @@ extern "C" {
160
  WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
161
  };
162
 
 
 
 
 
 
163
  struct whisper_full_params {
164
  enum whisper_sampling_strategy strategy;
165
 
@@ -184,6 +189,9 @@ extern "C" {
184
  int beam_width;
185
  int n_best;
186
  } beam_search;
 
 
 
187
  };
188
 
189
  WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
@@ -212,6 +220,7 @@ extern "C" {
212
 
213
  // Get the token text of the specified token in the specified segment.
214
  WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
 
215
 
216
  // Get the probability of the specified token in the specified segment.
217
  WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
 
160
  WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
161
  };
162
 
163
+ // Text segment callback
164
+ // Called on every newly generated text segment
165
+ // Use the whisper_full_...() functions to obtain the text segments
166
+ typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
167
+
168
  struct whisper_full_params {
169
  enum whisper_sampling_strategy strategy;
170
 
 
189
  int beam_width;
190
  int n_best;
191
  } beam_search;
192
+
193
+ whisper_new_segment_callback new_segment_callback;
194
+ void * new_segment_callback_user_data;
195
  };
196
 
197
  WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
 
220
 
221
  // Get the token text of the specified token in the specified segment.
222
  WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
223
+ WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
224
 
225
  // Get the probability of the specified token in the specified segment.
226
  WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);