appleboy committed on
Commit
19d7f69
·
unverified ·
1 Parent(s): 5d4a28f

go : improve progress reporting and callback handling (#1024)

Browse files

- Rename `cb` to `callNewSegment` in the `Process` function
- Add `callProgress` as a new parameter to the `Process` function
- Introduce `ProgressCallback` type for reporting progress during processing
- Update `Whisper_full` function to include `progressCallback` parameter
- Add `registerProgressCallback` function and `cbProgress` map for handling progress callbacks

Signed-off-by: appleboy <[email protected]>

bindings/go/Makefile CHANGED
@@ -32,7 +32,7 @@ mkdir:
32
  modtidy:
33
  @go mod tidy
34
 
35
- clean:
36
  @echo Clean
37
  @rm -fr $(BUILD_DIR)
38
  @go clean
 
32
  modtidy:
33
  @go mod tidy
34
 
35
+ clean:
36
  @echo Clean
37
  @rm -fr $(BUILD_DIR)
38
  @go clean
bindings/go/pkg/whisper/context.go CHANGED
@@ -152,7 +152,11 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
152
  }
153
 
154
  // Process new sample data and return any errors
155
- func (context *context) Process(data []float32, cb SegmentCallback) error {
 
 
 
 
156
  if context.model.ctx == nil {
157
  return ErrInternalAppError
158
  }
@@ -165,24 +169,28 @@ func (context *context) Process(data []float32, cb SegmentCallback) error {
165
  processors := 0
166
  if processors > 1 {
167
  if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
168
- if cb != nil {
169
  num_segments := context.model.ctx.Whisper_full_n_segments()
170
  s0 := num_segments - new
171
  for i := s0; i < num_segments; i++ {
172
- cb(toSegment(context.model.ctx, i))
173
  }
174
  }
175
  }); err != nil {
176
  return err
177
  }
178
  } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
179
- if cb != nil {
180
  num_segments := context.model.ctx.Whisper_full_n_segments()
181
  s0 := num_segments - new
182
  for i := s0; i < num_segments; i++ {
183
- cb(toSegment(context.model.ctx, i))
184
  }
185
  }
 
 
 
 
186
  }); err != nil {
187
  return err
188
  }
 
152
  }
153
 
154
  // Process new sample data and return any errors
155
+ func (context *context) Process(
156
+ data []float32,
157
+ callNewSegment SegmentCallback,
158
+ callProgress ProgressCallback,
159
+ ) error {
160
  if context.model.ctx == nil {
161
  return ErrInternalAppError
162
  }
 
169
  processors := 0
170
  if processors > 1 {
171
  if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
172
+ if callNewSegment != nil {
173
  num_segments := context.model.ctx.Whisper_full_n_segments()
174
  s0 := num_segments - new
175
  for i := s0; i < num_segments; i++ {
176
+ callNewSegment(toSegment(context.model.ctx, i))
177
  }
178
  }
179
  }); err != nil {
180
  return err
181
  }
182
  } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
183
+ if callNewSegment != nil {
184
  num_segments := context.model.ctx.Whisper_full_n_segments()
185
  s0 := num_segments - new
186
  for i := s0; i < num_segments; i++ {
187
+ callNewSegment(toSegment(context.model.ctx, i))
188
  }
189
  }
190
+ }, func(progress int) {
191
+ if callProgress != nil {
192
+ callProgress(progress)
193
+ }
194
  }); err != nil {
195
  return err
196
  }
bindings/go/pkg/whisper/interface.go CHANGED
@@ -12,6 +12,10 @@ import (
12
  // time. It is called during the Process function
13
  type SegmentCallback func(Segment)
14
 
 
 
 
 
15
  // Model is the interface to a whisper model. Create a new model with the
16
  // function whisper.New(string)
17
  type Model interface {
@@ -47,7 +51,7 @@ type Context interface {
47
  // Process mono audio data and return any errors.
48
  // If defined, newly generated segments are passed to the
49
  // callback function during processing.
50
- Process([]float32, SegmentCallback) error
51
 
52
  // After process is called, return segments until the end of the stream
53
  // is reached, when io.EOF is returned.
 
12
  // time. It is called during the Process function
13
  type SegmentCallback func(Segment)
14
 
15
+ // ProgressCallback is the callback function for reporting progress during
16
+ // processing. It is called during the Process function
17
+ type ProgressCallback func(int)
18
+
19
  // Model is the interface to a whisper model. Create a new model with the
20
  // function whisper.New(string)
21
  type Model interface {
 
51
  // Process mono audio data and return any errors.
52
  // If defined, newly generated segments are passed to the
53
  // callback function during processing.
54
+ Process([]float32, SegmentCallback, ProgressCallback) error
55
 
56
  // After process is called, return segments until the end of the stream
57
  // is reached, when io.EOF is returned.
bindings/go/whisper.go CHANGED
@@ -15,6 +15,7 @@ import (
15
  #include <stdlib.h>
16
 
17
  extern void callNewSegment(void* user_data, int new);
 
18
  extern bool callEncoderBegin(void* user_data);
19
 
20
  // Text segment callback
@@ -26,6 +27,15 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s
26
  }
27
  }
28
 
 
 
 
 
 
 
 
 
 
29
  // Encoder begin callback
30
  // If not NULL, called before the encoder starts
31
  // If it returns false, the computation is aborted
@@ -43,6 +53,8 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_
43
  params.new_segment_callback_user_data = (void*)(ctx);
44
  params.encoder_begin_callback = whisper_encoder_begin_cb;
45
  params.encoder_begin_callback_user_data = (void*)(ctx);
 
 
46
  return params;
47
  }
48
  */
@@ -290,11 +302,19 @@ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Param
290
 
291
  // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
292
  // Uses the specified decoding strategy to obtain the text.
293
- func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
 
 
 
 
 
 
294
  registerEncoderBeginCallback(ctx, encoderBeginCallback)
295
  registerNewSegmentCallback(ctx, newSegmentCallback)
 
296
  defer registerEncoderBeginCallback(ctx, nil)
297
  defer registerNewSegmentCallback(ctx, nil)
 
298
  if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
299
  return nil
300
  } else {
@@ -370,6 +390,7 @@ func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
370
 
371
  var (
372
  cbNewSegment = make(map[unsafe.Pointer]func(int))
 
373
  cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
374
  )
375
 
@@ -381,6 +402,14 @@ func registerNewSegmentCallback(ctx *Context, fn func(int)) {
381
  }
382
  }
383
 
 
 
 
 
 
 
 
 
384
  func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
385
  if fn == nil {
386
  delete(cbEncoderBegin, unsafe.Pointer(ctx))
@@ -396,6 +425,13 @@ func callNewSegment(user_data unsafe.Pointer, new C.int) {
396
  }
397
  }
398
 
 
 
 
 
 
 
 
399
  //export callEncoderBegin
400
  func callEncoderBegin(user_data unsafe.Pointer) C.bool {
401
  if fn, ok := cbEncoderBegin[user_data]; ok {
 
15
  #include <stdlib.h>
16
 
17
  extern void callNewSegment(void* user_data, int new);
18
+ extern void callProgress(void* user_data, int progress);
19
  extern bool callEncoderBegin(void* user_data);
20
 
21
  // Text segment callback
 
27
  }
28
  }
29
 
30
+ // Progress callback
31
+ // Called by whisper to report processing progress
32
+ // The progress value is forwarded to the Go callProgress handler
33
+ static void whisper_progress_cb(struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
34
+ if(user_data != NULL && ctx != NULL) {
35
+ callProgress(user_data, progress);
36
+ }
37
+ }
38
+
39
  // Encoder begin callback
40
  // If not NULL, called before the encoder starts
41
  // If it returns false, the computation is aborted
 
53
  params.new_segment_callback_user_data = (void*)(ctx);
54
  params.encoder_begin_callback = whisper_encoder_begin_cb;
55
  params.encoder_begin_callback_user_data = (void*)(ctx);
56
+ params.progress_callback = whisper_progress_cb;
57
+ params.progress_callback_user_data = (void*)(ctx);
58
  return params;
59
  }
60
  */
 
302
 
303
  // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
304
  // Uses the specified decoding strategy to obtain the text.
305
+ func (ctx *Context) Whisper_full(
306
+ params Params,
307
+ samples []float32,
308
+ encoderBeginCallback func() bool,
309
+ newSegmentCallback func(int),
310
+ progressCallback func(int),
311
+ ) error {
312
  registerEncoderBeginCallback(ctx, encoderBeginCallback)
313
  registerNewSegmentCallback(ctx, newSegmentCallback)
314
+ registerProgressCallback(ctx, progressCallback)
315
  defer registerEncoderBeginCallback(ctx, nil)
316
  defer registerNewSegmentCallback(ctx, nil)
317
+ defer registerProgressCallback(ctx, nil)
318
  if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
319
  return nil
320
  } else {
 
390
 
391
  var (
392
  cbNewSegment = make(map[unsafe.Pointer]func(int))
393
+ cbProgress = make(map[unsafe.Pointer]func(int))
394
  cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
395
  )
396
 
 
402
  }
403
  }
404
 
405
+ func registerProgressCallback(ctx *Context, fn func(int)) {
406
+ if fn == nil {
407
+ delete(cbProgress, unsafe.Pointer(ctx))
408
+ } else {
409
+ cbProgress[unsafe.Pointer(ctx)] = fn
410
+ }
411
+ }
412
+
413
  func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
414
  if fn == nil {
415
  delete(cbEncoderBegin, unsafe.Pointer(ctx))
 
425
  }
426
  }
427
 
428
+ //export callProgress
429
+ func callProgress(user_data unsafe.Pointer, progress C.int) {
430
+ if fn, ok := cbProgress[user_data]; ok {
431
+ fn(int(progress))
432
+ }
433
+ }
434
+
435
  //export callEncoderBegin
436
  func callEncoderBegin(user_data unsafe.Pointer) C.bool {
437
  if fn, ok := cbEncoderBegin[user_data]; ok {
bindings/go/whisper_test.go CHANGED
@@ -52,7 +52,7 @@ func Test_Whisper_001(t *testing.T) {
52
  defer ctx.Whisper_free()
53
  params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
54
  data := buf.AsFloat32Buffer().Data
55
- err = ctx.Whisper_full(params, data, nil, nil)
56
  assert.NoError(err)
57
 
58
  // Print out tokens
 
52
  defer ctx.Whisper_free()
53
  params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
54
  data := buf.AsFloat32Buffer().Data
55
+ err = ctx.Whisper_full(params, data, nil, nil, nil)
56
  assert.NoError(err)
57
 
58
  // Print out tokens