mkiol commited on
Commit
776adfd
·
unverified ·
1 Parent(s): fb91f57

whisper : abort callback improvements (#1345)

Browse files

* whisper : initialize abort_callback to null

* whisper : add example how to use abort_callback

Files changed (2) hide show
  1. examples/main/main.cpp +14 -2
  2. whisper.cpp +3 -0
examples/main/main.cpp CHANGED
@@ -944,8 +944,9 @@ int main(int argc, char ** argv) {
944
  wparams.progress_callback_user_data = &user_data;
945
  }
946
 
947
- // example for abort mechanism
948
- // in this example, we do not abort the processing, but we could if the flag is set to true
 
949
  // the callback is called before every encoder run - if it returns false, the processing is aborted
950
  {
951
  static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
@@ -957,6 +958,17 @@ int main(int argc, char ** argv) {
957
  wparams.encoder_begin_callback_user_data = &is_aborted;
958
  }
959
 
 
 
 
 
 
 
 
 
 
 
 
960
  if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
961
  fprintf(stderr, "%s: failed to process audio\n", argv[0]);
962
  return 10;
 
944
  wparams.progress_callback_user_data = &user_data;
945
  }
946
 
947
+ // examples for abort mechanism
948
+ // in examples below, we do not abort the processing, but we could if the flag is set to true
949
+
950
  // the callback is called before every encoder run - if it returns false, the processing is aborted
951
  {
952
  static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
 
958
  wparams.encoder_begin_callback_user_data = &is_aborted;
959
  }
960
 
961
+ // the callback is called before every computation - if it returns true, the computation is aborted
962
+ {
963
+ static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
964
+
965
+ wparams.abort_callback = [](void * user_data) {
966
+ bool is_aborted = *(bool*)user_data;
967
+ return is_aborted;
968
+ };
969
+ wparams.abort_callback_user_data = &is_aborted;
970
+ }
971
+
972
  if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
973
  fprintf(stderr, "%s: failed to process audio\n", argv[0]);
974
  return 10;
whisper.cpp CHANGED
@@ -3773,6 +3773,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
3773
  /*.encoder_begin_callback =*/ nullptr,
3774
  /*.encoder_begin_callback_user_data =*/ nullptr,
3775
 
 
 
 
3776
  /*.logits_filter_callback =*/ nullptr,
3777
  /*.logits_filter_callback_user_data =*/ nullptr,
3778
  };
 
3773
  /*.encoder_begin_callback =*/ nullptr,
3774
  /*.encoder_begin_callback_user_data =*/ nullptr,
3775
 
3776
+ /*.abort_callback =*/ nullptr,
3777
+ /*.abort_callback_user_data =*/ nullptr,
3778
+
3779
  /*.logits_filter_callback =*/ nullptr,
3780
  /*.logits_filter_callback_user_data =*/ nullptr,
3781
  };