Spaces:
Sleeping
Sleeping
go : improve progress reporting and callback handling (#1024)
Browse files- Rename `cb` to `callNewSegment` in the `Process` function
- Add `callProgress` as a new parameter to the `Process` function
- Introduce `ProgressCallback` type for reporting progress during processing
- Update `Whisper_full` function to include `progressCallback` parameter
- Add `registerProgressCallback` function and `cbProgress` map for handling progress callbacks
Signed-off-by: appleboy <[email protected]>
- bindings/go/Makefile +1 -1
- bindings/go/pkg/whisper/context.go +13 -5
- bindings/go/pkg/whisper/interface.go +5 -1
- bindings/go/whisper.go +37 -1
- bindings/go/whisper_test.go +1 -1
bindings/go/Makefile
CHANGED
|
@@ -32,7 +32,7 @@ mkdir:
|
|
| 32 |
modtidy:
|
| 33 |
@go mod tidy
|
| 34 |
|
| 35 |
-
clean:
|
| 36 |
@echo Clean
|
| 37 |
@rm -fr $(BUILD_DIR)
|
| 38 |
@go clean
|
|
|
|
| 32 |
modtidy:
|
| 33 |
@go mod tidy
|
| 34 |
|
| 35 |
+
clean:
|
| 36 |
@echo Clean
|
| 37 |
@rm -fr $(BUILD_DIR)
|
| 38 |
@go clean
|
bindings/go/pkg/whisper/context.go
CHANGED
|
@@ -152,7 +152,11 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
|
|
| 152 |
}
|
| 153 |
|
| 154 |
// Process new sample data and return any errors
|
| 155 |
-
func (context *context) Process(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
if context.model.ctx == nil {
|
| 157 |
return ErrInternalAppError
|
| 158 |
}
|
|
@@ -165,24 +169,28 @@ func (context *context) Process(data []float32, cb SegmentCallback) error {
|
|
| 165 |
processors := 0
|
| 166 |
if processors > 1 {
|
| 167 |
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
| 168 |
-
if
|
| 169 |
num_segments := context.model.ctx.Whisper_full_n_segments()
|
| 170 |
s0 := num_segments - new
|
| 171 |
for i := s0; i < num_segments; i++ {
|
| 172 |
-
|
| 173 |
}
|
| 174 |
}
|
| 175 |
}); err != nil {
|
| 176 |
return err
|
| 177 |
}
|
| 178 |
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
| 179 |
-
if
|
| 180 |
num_segments := context.model.ctx.Whisper_full_n_segments()
|
| 181 |
s0 := num_segments - new
|
| 182 |
for i := s0; i < num_segments; i++ {
|
| 183 |
-
|
| 184 |
}
|
| 185 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
}); err != nil {
|
| 187 |
return err
|
| 188 |
}
|
|
|
|
| 152 |
}
|
| 153 |
|
| 154 |
// Process new sample data and return any errors
|
| 155 |
+
func (context *context) Process(
|
| 156 |
+
data []float32,
|
| 157 |
+
callNewSegment SegmentCallback,
|
| 158 |
+
callProgress ProgressCallback,
|
| 159 |
+
) error {
|
| 160 |
if context.model.ctx == nil {
|
| 161 |
return ErrInternalAppError
|
| 162 |
}
|
|
|
|
| 169 |
processors := 0
|
| 170 |
if processors > 1 {
|
| 171 |
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
|
| 172 |
+
if callNewSegment != nil {
|
| 173 |
num_segments := context.model.ctx.Whisper_full_n_segments()
|
| 174 |
s0 := num_segments - new
|
| 175 |
for i := s0; i < num_segments; i++ {
|
| 176 |
+
callNewSegment(toSegment(context.model.ctx, i))
|
| 177 |
}
|
| 178 |
}
|
| 179 |
}); err != nil {
|
| 180 |
return err
|
| 181 |
}
|
| 182 |
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
|
| 183 |
+
if callNewSegment != nil {
|
| 184 |
num_segments := context.model.ctx.Whisper_full_n_segments()
|
| 185 |
s0 := num_segments - new
|
| 186 |
for i := s0; i < num_segments; i++ {
|
| 187 |
+
callNewSegment(toSegment(context.model.ctx, i))
|
| 188 |
}
|
| 189 |
}
|
| 190 |
+
}, func(progress int) {
|
| 191 |
+
if callProgress != nil {
|
| 192 |
+
callProgress(progress)
|
| 193 |
+
}
|
| 194 |
}); err != nil {
|
| 195 |
return err
|
| 196 |
}
|
bindings/go/pkg/whisper/interface.go
CHANGED
|
@@ -12,6 +12,10 @@ import (
|
|
| 12 |
// time. It is called during the Process function
|
| 13 |
type SegmentCallback func(Segment)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
// Model is the interface to a whisper model. Create a new model with the
|
| 16 |
// function whisper.New(string)
|
| 17 |
type Model interface {
|
|
@@ -47,7 +51,7 @@ type Context interface {
|
|
| 47 |
// Process mono audio data and return any errors.
|
| 48 |
// If defined, newly generated segments are passed to the
|
| 49 |
// callback function during processing.
|
| 50 |
-
Process([]float32, SegmentCallback) error
|
| 51 |
|
| 52 |
// After process is called, return segments until the end of the stream
|
| 53 |
// is reached, when io.EOF is returned.
|
|
|
|
| 12 |
// time. It is called during the Process function
|
| 13 |
type SegmentCallback func(Segment)
|
| 14 |
|
| 15 |
+
// ProgressCallback is the callback function for reporting progress during
|
| 16 |
+
// processing. It is called during the Process function
|
| 17 |
+
type ProgressCallback func(int)
|
| 18 |
+
|
| 19 |
// Model is the interface to a whisper model. Create a new model with the
|
| 20 |
// function whisper.New(string)
|
| 21 |
type Model interface {
|
|
|
|
| 51 |
// Process mono audio data and return any errors.
|
| 52 |
// If defined, newly generated segments are passed to the
|
| 53 |
// callback function during processing.
|
| 54 |
+
Process([]float32, SegmentCallback, ProgressCallback) error
|
| 55 |
|
| 56 |
// After process is called, return segments until the end of the stream
|
| 57 |
// is reached, when io.EOF is returned.
|
bindings/go/whisper.go
CHANGED
|
@@ -15,6 +15,7 @@ import (
|
|
| 15 |
#include <stdlib.h>
|
| 16 |
|
| 17 |
extern void callNewSegment(void* user_data, int new);
|
|
|
|
| 18 |
extern bool callEncoderBegin(void* user_data);
|
| 19 |
|
| 20 |
// Text segment callback
|
|
@@ -26,6 +27,15 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s
|
|
| 26 |
}
|
| 27 |
}
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
// Encoder begin callback
|
| 30 |
// If not NULL, called before the encoder starts
|
| 31 |
// If it returns false, the computation is aborted
|
|
@@ -43,6 +53,8 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_
|
|
| 43 |
params.new_segment_callback_user_data = (void*)(ctx);
|
| 44 |
params.encoder_begin_callback = whisper_encoder_begin_cb;
|
| 45 |
params.encoder_begin_callback_user_data = (void*)(ctx);
|
|
|
|
|
|
|
| 46 |
return params;
|
| 47 |
}
|
| 48 |
*/
|
|
@@ -290,11 +302,19 @@ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Param
|
|
| 290 |
|
| 291 |
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
| 292 |
// Uses the specified decoding strategy to obtain the text.
|
| 293 |
-
func (ctx *Context) Whisper_full(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
| 295 |
registerNewSegmentCallback(ctx, newSegmentCallback)
|
|
|
|
| 296 |
defer registerEncoderBeginCallback(ctx, nil)
|
| 297 |
defer registerNewSegmentCallback(ctx, nil)
|
|
|
|
| 298 |
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
| 299 |
return nil
|
| 300 |
} else {
|
|
@@ -370,6 +390,7 @@ func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
|
|
| 370 |
|
| 371 |
var (
|
| 372 |
cbNewSegment = make(map[unsafe.Pointer]func(int))
|
|
|
|
| 373 |
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
|
| 374 |
)
|
| 375 |
|
|
@@ -381,6 +402,14 @@ func registerNewSegmentCallback(ctx *Context, fn func(int)) {
|
|
| 381 |
}
|
| 382 |
}
|
| 383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
| 385 |
if fn == nil {
|
| 386 |
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
|
@@ -396,6 +425,13 @@ func callNewSegment(user_data unsafe.Pointer, new C.int) {
|
|
| 396 |
}
|
| 397 |
}
|
| 398 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
//export callEncoderBegin
|
| 400 |
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
| 401 |
if fn, ok := cbEncoderBegin[user_data]; ok {
|
|
|
|
| 15 |
#include <stdlib.h>
|
| 16 |
|
| 17 |
extern void callNewSegment(void* user_data, int new);
|
| 18 |
+
extern void callProgress(void* user_data, int progress);
|
| 19 |
extern bool callEncoderBegin(void* user_data);
|
| 20 |
|
| 21 |
// Text segment callback
|
|
|
|
| 27 |
}
|
| 28 |
}
|
| 29 |
|
| 30 |
+
// Progress callback
|
| 31 |
+
// Called on every newly generated text segment
|
| 32 |
+
// Use the whisper_full_...() functions to obtain the text segments
|
| 33 |
+
static void whisper_progress_cb(struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
|
| 34 |
+
if(user_data != NULL && ctx != NULL) {
|
| 35 |
+
callProgress(user_data, progress);
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
// Encoder begin callback
|
| 40 |
// If not NULL, called before the encoder starts
|
| 41 |
// If it returns false, the computation is aborted
|
|
|
|
| 53 |
params.new_segment_callback_user_data = (void*)(ctx);
|
| 54 |
params.encoder_begin_callback = whisper_encoder_begin_cb;
|
| 55 |
params.encoder_begin_callback_user_data = (void*)(ctx);
|
| 56 |
+
params.progress_callback = whisper_progress_cb;
|
| 57 |
+
params.progress_callback_user_data = (void*)(ctx);
|
| 58 |
return params;
|
| 59 |
}
|
| 60 |
*/
|
|
|
|
| 302 |
|
| 303 |
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
| 304 |
// Uses the specified decoding strategy to obtain the text.
|
| 305 |
+
func (ctx *Context) Whisper_full(
|
| 306 |
+
params Params,
|
| 307 |
+
samples []float32,
|
| 308 |
+
encoderBeginCallback func() bool,
|
| 309 |
+
newSegmentCallback func(int),
|
| 310 |
+
progressCallback func(int),
|
| 311 |
+
) error {
|
| 312 |
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
| 313 |
registerNewSegmentCallback(ctx, newSegmentCallback)
|
| 314 |
+
registerProgressCallback(ctx, progressCallback)
|
| 315 |
defer registerEncoderBeginCallback(ctx, nil)
|
| 316 |
defer registerNewSegmentCallback(ctx, nil)
|
| 317 |
+
defer registerProgressCallback(ctx, nil)
|
| 318 |
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
| 319 |
return nil
|
| 320 |
} else {
|
|
|
|
| 390 |
|
| 391 |
var (
|
| 392 |
cbNewSegment = make(map[unsafe.Pointer]func(int))
|
| 393 |
+
cbProgress = make(map[unsafe.Pointer]func(int))
|
| 394 |
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
|
| 395 |
)
|
| 396 |
|
|
|
|
| 402 |
}
|
| 403 |
}
|
| 404 |
|
| 405 |
+
func registerProgressCallback(ctx *Context, fn func(int)) {
|
| 406 |
+
if fn == nil {
|
| 407 |
+
delete(cbProgress, unsafe.Pointer(ctx))
|
| 408 |
+
} else {
|
| 409 |
+
cbProgress[unsafe.Pointer(ctx)] = fn
|
| 410 |
+
}
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
| 414 |
if fn == nil {
|
| 415 |
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
|
|
|
| 425 |
}
|
| 426 |
}
|
| 427 |
|
| 428 |
+
//export callProgress
|
| 429 |
+
func callProgress(user_data unsafe.Pointer, progress C.int) {
|
| 430 |
+
if fn, ok := cbProgress[user_data]; ok {
|
| 431 |
+
fn(int(progress))
|
| 432 |
+
}
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
//export callEncoderBegin
|
| 436 |
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
| 437 |
if fn, ok := cbEncoderBegin[user_data]; ok {
|
bindings/go/whisper_test.go
CHANGED
|
@@ -52,7 +52,7 @@ func Test_Whisper_001(t *testing.T) {
|
|
| 52 |
defer ctx.Whisper_free()
|
| 53 |
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
| 54 |
data := buf.AsFloat32Buffer().Data
|
| 55 |
-
err = ctx.Whisper_full(params, data, nil, nil)
|
| 56 |
assert.NoError(err)
|
| 57 |
|
| 58 |
// Print out tokens
|
|
|
|
| 52 |
defer ctx.Whisper_free()
|
| 53 |
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
| 54 |
data := buf.AsFloat32Buffer().Data
|
| 55 |
+
err = ctx.Whisper_full(params, data, nil, nil, nil)
|
| 56 |
assert.NoError(err)
|
| 57 |
|
| 58 |
// Print out tokens
|