ggerganov committed on
Commit
d311de4
·
1 Parent(s): c7f9e5b

whisper : add mechanism for aborting the whisper_full() computation

Browse files
Files changed (3) hide show
  1. examples/main/main.cpp +13 -0
  2. whisper.cpp +13 -0
  3. whisper.h +11 -0
examples/main/main.cpp CHANGED
@@ -607,6 +607,19 @@ int main(int argc, char ** argv) {
607
  wparams.new_segment_callback_user_data = &user_data;
608
  }
609
 
 
 
 
 
 
 
 
 
 
 
 
 
 
610
  if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
611
  fprintf(stderr, "%s: failed to process audio\n", argv[0]);
612
  return 10;
 
607
  wparams.new_segment_callback_user_data = &user_data;
608
  }
609
 
610
+ // example for abort mechanism
611
+ // in this example, we do not abort the processing, but we could if the flag is set to true
612
+ // the callback is called before every encoder run - if it returns false, the processing is aborted
613
+ {
614
+ static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
615
+
616
+ wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
617
+ bool is_aborted = *(bool*)user_data;
618
+ return !is_aborted;
619
+ };
620
+ wparams.encoder_begin_callback_user_data = &is_aborted;
621
+ }
622
+
623
  if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
624
  fprintf(stderr, "%s: failed to process audio\n", argv[0]);
625
  return 10;
whisper.cpp CHANGED
@@ -2451,6 +2451,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2451
 
2452
  /*.new_segment_callback =*/ nullptr,
2453
  /*.new_segment_callback_user_data =*/ nullptr,
 
 
 
2454
  };
2455
  } break;
2456
  case WHISPER_SAMPLING_BEAM_SEARCH:
@@ -2497,6 +2500,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2497
 
2498
  /*.new_segment_callback =*/ nullptr,
2499
  /*.new_segment_callback_user_data =*/ nullptr,
 
 
 
2500
  };
2501
  } break;
2502
  }
@@ -2659,6 +2665,13 @@ int whisper_full(
2659
  break;
2660
  }
2661
 
 
 
 
 
 
 
 
2662
  // encode audio features starting at offset seek
2663
  if (whisper_encode(ctx, seek, params.n_threads) != 0) {
2664
  fprintf(stderr, "%s: failed to encode\n", __func__);
 
2451
 
2452
  /*.new_segment_callback =*/ nullptr,
2453
  /*.new_segment_callback_user_data =*/ nullptr,
2454
+
2455
+ /*.encoder_begin_callback =*/ nullptr,
2456
+ /*.encoder_begin_callback_user_data =*/ nullptr,
2457
  };
2458
  } break;
2459
  case WHISPER_SAMPLING_BEAM_SEARCH:
 
2500
 
2501
  /*.new_segment_callback =*/ nullptr,
2502
  /*.new_segment_callback_user_data =*/ nullptr,
2503
+
2504
+ /*.encoder_begin_callback =*/ nullptr,
2505
+ /*.encoder_begin_callback_user_data =*/ nullptr,
2506
  };
2507
  } break;
2508
  }
 
2665
  break;
2666
  }
2667
 
2668
+ if (params.encoder_begin_callback) {
2669
+ if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
2670
+ fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
2671
+ break;
2672
+ }
2673
+ }
2674
+
2675
  // encode audio features starting at offset seek
2676
  if (whisper_encode(ctx, seek, params.n_threads) != 0) {
2677
  fprintf(stderr, "%s: failed to encode\n", __func__);
whisper.h CHANGED
@@ -185,6 +185,14 @@ extern "C" {
185
  // Use the whisper_full_...() functions to obtain the text segments
186
  typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
187
 
 
 
 
 
 
 
 
 
188
  struct whisper_full_params {
189
  enum whisper_sampling_strategy strategy;
190
 
@@ -231,6 +239,9 @@ extern "C" {
231
 
232
  whisper_new_segment_callback new_segment_callback;
233
  void * new_segment_callback_user_data;
 
 
 
234
  };
235
 
236
  WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
 
185
  // Use the whisper_full_...() functions to obtain the text segments
186
  typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
187
 
188
+ // Encoder begin callback
189
+ // If not NULL, called before the encoder starts
190
+ // If it returns false, the computation is aborted
191
+ typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
192
+
193
+ // Parameters for the whisper_full() function
194
+ // If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
195
+ // whisper_full_default_params()
196
  struct whisper_full_params {
197
  enum whisper_sampling_strategy strategy;
198
 
 
239
 
240
  whisper_new_segment_callback new_segment_callback;
241
  void * new_segment_callback_user_data;
242
+
243
+ whisper_encoder_begin_callback encoder_begin_callback;
244
+ void * encoder_begin_callback_user_data;
245
  };
246
 
247
  WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);