Lin Xiaodong linxiaodong committed on
Commit
3f6a806
·
unverified ·
1 Parent(s): 2220ea9

examples : support progress_callback API for addon.node (#2941)

Browse files

* feat: progress supported

* fix: missing params

* style: Format the code to improve readability

Unified code indentation ensures consistent coding style, enhancing code readability and maintainability.

* feat: support prompt api

---------

Co-authored-by: linxiaodong <[email protected]>

examples/addon.node/__test__/whisper.spec.js CHANGED
@@ -18,6 +18,7 @@ const whisperParamsMock = {
18
  translate: true,
19
  no_timestamps: false,
20
  audio_ctx: 0,
 
21
  };
22
 
23
  describe("Run whisper.node", () => {
 
18
  translate: true,
19
  no_timestamps: false,
20
  audio_ctx: 0,
21
+ max_len: 0,
22
  };
23
 
24
  describe("Run whisper.node", () => {
examples/addon.node/addon.cpp CHANGED
@@ -128,192 +128,227 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
128
 
129
  void cb_log_disable(enum ggml_log_level, const char *, void *) {}
130
 
131
- int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
132
- if (params.no_prints) {
133
- whisper_log_set(cb_log_disable, NULL);
 
 
 
 
 
 
 
 
 
 
 
134
  }
135
 
136
- if (params.fname_inp.empty() && params.pcmf32.empty()) {
137
- fprintf(stderr, "error: no input files or audio buffer specified\n");
138
- return 2;
 
 
139
  }
140
 
141
- if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
142
- fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
143
- exit(0);
144
  }
145
 
146
- // whisper init
147
-
148
- struct whisper_context_params cparams = whisper_context_default_params();
149
- cparams.use_gpu = params.use_gpu;
150
- cparams.flash_attn = params.flash_attn;
151
- struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
152
-
153
- if (ctx == nullptr) {
154
- fprintf(stderr, "error: failed to initialize whisper context\n");
155
- return 3;
 
156
  }
157
 
158
- // if params.pcmf32 is provided, set params.fname_inp to "buffer"
159
- // this is simpler than further modifications in the code
160
- if (!params.pcmf32.empty()) {
161
- fprintf(stderr, "info: using audio buffer as input\n");
162
- params.fname_inp.clear();
163
- params.fname_inp.emplace_back("buffer");
 
 
 
 
164
  }
165
 
166
- for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
167
- const auto fname_inp = params.fname_inp[f];
168
- const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
 
 
 
 
 
 
 
 
169
 
170
- std::vector<float> pcmf32; // mono-channel F32 PCM
171
- std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
 
 
172
 
173
- // read the input audio file if params.pcmf32 is not provided
174
- if (params.pcmf32.empty()) {
175
- if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, params.diarize)) {
176
- fprintf(stderr, "error: failed to read audio file '%s'\n", fname_inp.c_str());
177
- continue;
178
- }
179
- } else {
180
- pcmf32 = params.pcmf32;
181
  }
182
 
183
- // print system information
184
- if (!params.no_prints) {
185
- fprintf(stderr, "\n");
186
- fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
187
- params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
 
 
 
 
188
  }
189
 
190
- // print some info about the processing
191
- if (!params.no_prints) {
192
- fprintf(stderr, "\n");
193
- if (!whisper_is_multilingual(ctx)) {
194
- if (params.language != "en" || params.translate) {
195
- params.language = "en";
196
- params.translate = false;
197
- fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
 
 
 
 
 
 
 
 
 
 
 
198
  }
 
 
199
  }
200
- fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n",
201
- __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
202
- params.n_threads, params.n_processors,
203
- params.language.c_str(),
204
- params.translate ? "translate" : "transcribe",
205
- params.no_timestamps ? 0 : 1,
206
- params.audio_ctx);
207
-
208
- fprintf(stderr, "\n");
209
- }
210
 
211
- // run the inference
212
- {
213
- whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
 
 
214
 
215
- wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
- wparams.print_realtime = false;
218
- wparams.print_progress = params.print_progress;
219
- wparams.print_timestamps = !params.no_timestamps;
220
- wparams.print_special = params.print_special;
221
- wparams.translate = params.translate;
222
- wparams.language = params.language.c_str();
223
- wparams.n_threads = params.n_threads;
224
- wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
225
- wparams.offset_ms = params.offset_t_ms;
226
- wparams.duration_ms = params.duration_ms;
227
 
228
- wparams.token_timestamps = params.output_wts || params.max_len > 0;
229
- wparams.thold_pt = params.word_thold;
230
- wparams.entropy_thold = params.entropy_thold;
231
- wparams.logprob_thold = params.logprob_thold;
232
- wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
233
- wparams.audio_ctx = params.audio_ctx;
234
 
235
- wparams.greedy.best_of = params.best_of;
236
- wparams.beam_search.beam_size = params.beam_size;
 
 
 
 
 
 
 
 
237
 
238
- wparams.initial_prompt = params.prompt.c_str();
 
 
 
 
 
239
 
240
- wparams.no_timestamps = params.no_timestamps;
 
241
 
242
- whisper_print_user_data user_data = { &params, &pcmf32s };
243
 
244
- // this callback is called on each new segment
245
- if (!wparams.print_realtime) {
246
- wparams.new_segment_callback = whisper_print_segment_callback;
247
- wparams.new_segment_callback_user_data = &user_data;
248
- }
249
 
250
- // example for abort mechanism
251
- // in this example, we do not abort the processing, but we could if the flag is set to true
252
- // the callback is called before every encoder run - if it returns false, the processing is aborted
253
- {
254
- static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
255
 
256
- wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
257
- bool is_aborted = *(bool*)user_data;
258
- return !is_aborted;
 
 
 
 
 
 
 
259
  };
260
- wparams.encoder_begin_callback_user_data = &is_aborted;
261
- }
262
 
263
- if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
264
- fprintf(stderr, "failed to process audio\n");
265
- return 10;
266
- }
267
- }
268
- }
269
 
270
- const int n_segments = whisper_full_n_segments(ctx);
271
- result.resize(n_segments);
272
- for (int i = 0; i < n_segments; ++i) {
273
- const char * text = whisper_full_get_segment_text(ctx, i);
274
- const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
275
- const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
276
 
277
- result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
278
- result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
279
- result[i].emplace_back(text);
 
 
280
  }
281
 
282
- whisper_print_timings(ctx);
283
- whisper_free(ctx);
284
-
285
- return 0;
286
- }
 
287
 
288
- class Worker : public Napi::AsyncWorker {
289
- public:
290
- Worker(Napi::Function& callback, whisper_params params)
291
- : Napi::AsyncWorker(callback), params(params) {}
292
 
293
- void Execute() override {
294
- run(params, result);
295
- }
296
 
297
- void OnOK() override {
298
- Napi::HandleScope scope(Env());
299
- Napi::Object res = Napi::Array::New(Env(), result.size());
300
- for (uint64_t i = 0; i < result.size(); ++i) {
301
- Napi::Object tmp = Napi::Array::New(Env(), 3);
302
- for (uint64_t j = 0; j < 3; ++j) {
303
- tmp[j] = Napi::String::New(Env(), result[i][j]);
304
- }
305
- res[i] = tmp;
306
  }
307
- Callback().Call({Env().Null(), res});
308
- }
309
-
310
- private:
311
- whisper_params params;
312
- std::vector<std::vector<std::string>> result;
313
  };
314
 
315
-
316
-
317
  Napi::Value whisper(const Napi::CallbackInfo& info) {
318
  Napi::Env env = info.Env();
319
  if (info.Length() <= 0 || !info[0].IsObject()) {
@@ -332,6 +367,23 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
332
  int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
333
  bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
334
  int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
  Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
337
  std::vector<float> pcmf32_vec;
@@ -355,9 +407,12 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
355
  params.pcmf32 = pcmf32_vec;
356
  params.comma_in_time = comma_in_time;
357
  params.max_len = max_len;
 
 
358
 
359
  Napi::Function callback = info[1].As<Napi::Function>();
360
- Worker* worker = new Worker(callback, params);
 
361
  worker->Queue();
362
  return env.Undefined();
363
  }
 
128
 
129
  void cb_log_disable(enum ggml_log_level, const char *, void *) {}
130
 
131
+ class ProgressWorker : public Napi::AsyncWorker {
132
+ public:
133
+ ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
134
+ : Napi::AsyncWorker(callback), params(params), env(env) {
135
+ // Create thread-safe function
136
+ if (!progress_callback.IsEmpty()) {
137
+ tsfn = Napi::ThreadSafeFunction::New(
138
+ env,
139
+ progress_callback,
140
+ "Progress Callback",
141
+ 0,
142
+ 1
143
+ );
144
+ }
145
  }
146
 
147
+ ~ProgressWorker() {
148
+ if (tsfn) {
149
+ // Make sure to release the thread-safe function on destruction
150
+ tsfn.Release();
151
+ }
152
  }
153
 
154
+ void Execute() override {
155
+ // Use custom run function with progress callback support
156
+ run_with_progress(params, result);
157
  }
158
 
159
+ void OnOK() override {
160
+ Napi::HandleScope scope(Env());
161
+ Napi::Object res = Napi::Array::New(Env(), result.size());
162
+ for (uint64_t i = 0; i < result.size(); ++i) {
163
+ Napi::Object tmp = Napi::Array::New(Env(), 3);
164
+ for (uint64_t j = 0; j < 3; ++j) {
165
+ tmp[j] = Napi::String::New(Env(), result[i][j]);
166
+ }
167
+ res[i] = tmp;
168
+ }
169
+ Callback().Call({Env().Null(), res});
170
  }
171
 
172
+ // Progress callback function - using thread-safe function
173
+ void OnProgress(int progress) {
174
+ if (tsfn) {
175
+ // Use thread-safe function to call JavaScript callback
176
+ auto callback = [progress](Napi::Env env, Napi::Function jsCallback) {
177
+ jsCallback.Call({Napi::Number::New(env, progress)});
178
+ };
179
+
180
+ tsfn.BlockingCall(callback);
181
+ }
182
  }
183
 
184
+ private:
185
+ whisper_params params;
186
+ std::vector<std::vector<std::string>> result;
187
+ Napi::Env env;
188
+ Napi::ThreadSafeFunction tsfn;
189
+
190
+ // Custom run function with progress callback support
191
+ int run_with_progress(whisper_params &params, std::vector<std::vector<std::string>> &result) {
192
+ if (params.no_prints) {
193
+ whisper_log_set(cb_log_disable, NULL);
194
+ }
195
 
196
+ if (params.fname_inp.empty() && params.pcmf32.empty()) {
197
+ fprintf(stderr, "error: no input files or audio buffer specified\n");
198
+ return 2;
199
+ }
200
 
201
+ if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
202
+ fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
203
+ exit(0);
 
 
 
 
 
204
  }
205
 
206
+ // whisper init
207
+ struct whisper_context_params cparams = whisper_context_default_params();
208
+ cparams.use_gpu = params.use_gpu;
209
+ cparams.flash_attn = params.flash_attn;
210
+ struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
211
+
212
+ if (ctx == nullptr) {
213
+ fprintf(stderr, "error: failed to initialize whisper context\n");
214
+ return 3;
215
  }
216
 
217
+ // If params.pcmf32 is provided, set params.fname_inp to "buffer"
218
+ if (!params.pcmf32.empty()) {
219
+ fprintf(stderr, "info: using audio buffer as input\n");
220
+ params.fname_inp.clear();
221
+ params.fname_inp.emplace_back("buffer");
222
+ }
223
+
224
+ for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
225
+ const auto fname_inp = params.fname_inp[f];
226
+ const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
227
+
228
+ std::vector<float> pcmf32; // mono-channel F32 PCM
229
+ std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
230
+
231
+ // If params.pcmf32 is empty, read input audio file
232
+ if (params.pcmf32.empty()) {
233
+ if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, params.diarize)) {
234
+ fprintf(stderr, "error: failed to read audio file '%s'\n", fname_inp.c_str());
235
+ continue;
236
  }
237
+ } else {
238
+ pcmf32 = params.pcmf32;
239
  }
 
 
 
 
 
 
 
 
 
 
240
 
241
+ // Print system info
242
+ if (!params.no_prints) {
243
+ fprintf(stderr, "\n");
244
+ fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
245
+ params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
246
+ }
247
 
248
+ // Print processing info
249
+ if (!params.no_prints) {
250
+ fprintf(stderr, "\n");
251
+ if (!whisper_is_multilingual(ctx)) {
252
+ if (params.language != "en" || params.translate) {
253
+ params.language = "en";
254
+ params.translate = false;
255
+ fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
256
+ }
257
+ }
258
+ fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n",
259
+ __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
260
+ params.n_threads, params.n_processors,
261
+ params.language.c_str(),
262
+ params.translate ? "translate" : "transcribe",
263
+ params.no_timestamps ? 0 : 1,
264
+ params.audio_ctx);
265
+
266
+ fprintf(stderr, "\n");
267
+ }
268
 
269
+ // Run inference
270
+ {
271
+ whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
 
 
 
 
 
 
272
 
273
+ wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
 
 
 
 
 
274
 
275
+ wparams.print_realtime = false;
276
+ wparams.print_progress = params.print_progress;
277
+ wparams.print_timestamps = !params.no_timestamps;
278
+ wparams.print_special = params.print_special;
279
+ wparams.translate = params.translate;
280
+ wparams.language = params.language.c_str();
281
+ wparams.n_threads = params.n_threads;
282
+ wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
283
+ wparams.offset_ms = params.offset_t_ms;
284
+ wparams.duration_ms = params.duration_ms;
285
 
286
+ wparams.token_timestamps = params.output_wts || params.max_len > 0;
287
+ wparams.thold_pt = params.word_thold;
288
+ wparams.entropy_thold = params.entropy_thold;
289
+ wparams.logprob_thold = params.logprob_thold;
290
+ wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
291
+ wparams.audio_ctx = params.audio_ctx;
292
 
293
+ wparams.greedy.best_of = params.best_of;
294
+ wparams.beam_search.beam_size = params.beam_size;
295
 
296
+ wparams.initial_prompt = params.prompt.c_str();
297
 
298
+ wparams.no_timestamps = params.no_timestamps;
 
 
 
 
299
 
300
+ whisper_print_user_data user_data = { &params, &pcmf32s };
 
 
 
 
301
 
302
+ // This callback is called for each new segment
303
+ if (!wparams.print_realtime) {
304
+ wparams.new_segment_callback = whisper_print_segment_callback;
305
+ wparams.new_segment_callback_user_data = &user_data;
306
+ }
307
+
308
+ // Set progress callback
309
+ wparams.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
310
+ ProgressWorker* worker = static_cast<ProgressWorker*>(user_data);
311
+ worker->OnProgress(progress);
312
  };
313
+ wparams.progress_callback_user_data = this;
 
314
 
315
+ // Abort mechanism example
316
+ {
317
+ static bool is_aborted = false; // Note: this should be atomic to avoid data races
 
 
 
318
 
319
+ wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
320
+ bool is_aborted = *(bool*)user_data;
321
+ return !is_aborted;
322
+ };
323
+ wparams.encoder_begin_callback_user_data = &is_aborted;
324
+ }
325
 
326
+ if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
327
+ fprintf(stderr, "failed to process audio\n");
328
+ return 10;
329
+ }
330
+ }
331
  }
332
 
333
+ const int n_segments = whisper_full_n_segments(ctx);
334
+ result.resize(n_segments);
335
+ for (int i = 0; i < n_segments; ++i) {
336
+ const char * text = whisper_full_get_segment_text(ctx, i);
337
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
338
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
339
 
340
+ result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
341
+ result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
342
+ result[i].emplace_back(text);
343
+ }
344
 
345
+ whisper_print_timings(ctx);
346
+ whisper_free(ctx);
 
347
 
348
+ return 0;
 
 
 
 
 
 
 
 
349
  }
 
 
 
 
 
 
350
  };
351
 
 
 
352
  Napi::Value whisper(const Napi::CallbackInfo& info) {
353
  Napi::Env env = info.Env();
354
  if (info.Length() <= 0 || !info[0].IsObject()) {
 
367
  int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
368
  bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
369
  int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
370
+
371
+ // support prompt
372
+ std::string prompt = "";
373
+ if (whisper_params.Has("prompt") && whisper_params.Get("prompt").IsString()) {
374
+ prompt = whisper_params.Get("prompt").As<Napi::String>();
375
+ }
376
+
377
+ // Add support for print_progress
378
+ bool print_progress = false;
379
+ if (whisper_params.Has("print_progress")) {
380
+ print_progress = whisper_params.Get("print_progress").As<Napi::Boolean>();
381
+ }
382
+ // Add support for progress_callback
383
+ Napi::Function progress_callback;
384
+ if (whisper_params.Has("progress_callback") && whisper_params.Get("progress_callback").IsFunction()) {
385
+ progress_callback = whisper_params.Get("progress_callback").As<Napi::Function>();
386
+ }
387
 
388
  Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
389
  std::vector<float> pcmf32_vec;
 
407
  params.pcmf32 = pcmf32_vec;
408
  params.comma_in_time = comma_in_time;
409
  params.max_len = max_len;
410
+ params.print_progress = print_progress;
411
+ params.prompt = prompt;
412
 
413
  Napi::Function callback = info[1].As<Napi::Function>();
414
+ // Create a new Worker class with progress callback support
415
+ ProgressWorker* worker = new ProgressWorker(callback, params, progress_callback, env);
416
  worker->Queue();
417
  return env.Undefined();
418
  }
examples/addon.node/index.js CHANGED
@@ -19,6 +19,9 @@ const whisperParams = {
19
  no_timestamps: false,
20
  audio_ctx: 0,
21
  max_len: 0,
 
 
 
22
  };
23
 
24
  const arguments = process.argv.slice(2);
 
19
  no_timestamps: false,
20
  audio_ctx: 0,
21
  max_len: 0,
22
+ progress_callback: (progress) => {
23
+ console.log(`progress: ${progress}%`);
24
+ }
25
  };
26
 
27
  const arguments = process.argv.slice(2);