djthorpe committed on
Commit
728fcbe
·
unverified ·
1 Parent(s): c59ce76

go : adding features to the go-whisper example, go ci, etc (#384)

Browse files

* Updated bindings so they can be used in third-party packages.

* Updated makefiles to optionally set the FMA flag, for Xeon E5 on Darwin

* Added test script

* Changes for examples

* Reverted

* Made the NewContext method private

.github/workflows/bindings.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Bindings Tests
2
+ on:
3
+ push:
4
+ paths:
5
+ - bindings/go/**
6
+
7
+ jobs:
8
+ ubuntu-latest:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - uses: actions/setup-go@v3
12
+ with:
13
+ go-version: '^1.19'
14
+ - uses: actions/checkout@v1
15
+ - run: |
16
+ cd bindings/go
17
+ make test
bindings/go/examples/go-whisper/color.go ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import "fmt"
4
+
5
+ ///////////////////////////////////////////////////////////////////////////////
6
+ // CONSTANTS
7
+
8
const (
	// Reset is the ANSI sequence that restores the default text attributes.
	Reset = "\033[0m"
	// RGBPrefix opens an 8-bit (256-colour) foreground colour sequence. It is
	// followed by a single palette index in decimal, then RGBSuffix.
	RGBPrefix = "\033[38;5;"
	// RGBSuffix terminates the colour sequence opened by RGBPrefix.
	RGBSuffix = "m"
)

///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS

// Colorize wraps text in an ANSI 8-bit grayscale foreground colour and appends
// Reset so that following output is unaffected.
//
// v selects one of the 24 grayscale steps, 0 to 23. Values outside that range
// wrap modulo 24; negative values wrap as well, so the emitted palette index
// always stays inside the grayscale ramp.
func Colorize(text string, v int) string {
	// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
	// Grayscale colors are in the range 232-255
	const (
		grayBase  = 232
		grayCount = 24
	)

	// Go's % keeps the sign of the dividend, so a negative v would otherwise
	// yield an index below grayBase, which is not a grayscale colour.
	step := v % grayCount
	if step < 0 {
		step += grayCount
	}
	return RGBPrefix + fmt.Sprint(grayBase+step) + RGBSuffix + text + Reset
}
bindings/go/examples/go-whisper/flags.go CHANGED
@@ -2,6 +2,12 @@ package main
2
 
3
  import (
4
  "flag"
 
 
 
 
 
 
5
  )
6
 
7
  ///////////////////////////////////////////////////////////////////////////////
@@ -42,6 +48,26 @@ func (flags *Flags) GetLanguage() string {
42
  return flags.Lookup("language").Value.String()
43
  }
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  func (flags *Flags) IsSpeedup() bool {
46
  return flags.Lookup("speedup").Value.String() == "true"
47
  }
@@ -50,12 +76,81 @@ func (flags *Flags) IsTokens() bool {
50
  return flags.Lookup("tokens").Value.String() == "true"
51
  }
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  ///////////////////////////////////////////////////////////////////////////////
54
  // PRIVATE METHODS
55
 
56
  func registerFlags(flag *Flags) {
57
  flag.String("model", "", "Path to the model file")
58
- flag.String("language", "", "Language")
 
 
 
 
59
  flag.Bool("speedup", false, "Enable speedup")
 
 
 
60
  flag.Bool("tokens", false, "Display tokens")
 
 
61
  }
 
2
 
3
  import (
4
  "flag"
5
+ "fmt"
6
+ "strings"
7
+ "time"
8
+
9
+ // Packages
10
+ whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
11
  )
12
 
13
  ///////////////////////////////////////////////////////////////////////////////
 
48
  return flags.Lookup("language").Value.String()
49
  }
50
 
51
+ func (flags *Flags) IsTranslate() bool {
52
+ return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
53
+ }
54
+
55
+ func (flags *Flags) GetOffset() time.Duration {
56
+ return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
57
+ }
58
+
59
+ func (flags *Flags) GetDuration() time.Duration {
60
+ return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
61
+ }
62
+
63
+ func (flags *Flags) GetThreads() uint {
64
+ return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
65
+ }
66
+
67
+ func (flags *Flags) GetOut() string {
68
+ return strings.ToLower(flags.Lookup("out").Value.String())
69
+ }
70
+
71
  func (flags *Flags) IsSpeedup() bool {
72
  return flags.Lookup("speedup").Value.String() == "true"
73
  }
 
76
  return flags.Lookup("tokens").Value.String() == "true"
77
  }
78
 
79
+ func (flags *Flags) IsColorize() bool {
80
+ return flags.Lookup("colorize").Value.String() == "true"
81
+ }
82
+
83
+ func (flags *Flags) GetMaxLen() uint {
84
+ return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
85
+ }
86
+
87
+ func (flags *Flags) GetMaxTokens() uint {
88
+ return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
89
+ }
90
+
91
+ func (flags *Flags) GetWordThreshold() float32 {
92
+ return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
93
+ }
94
+
95
+ func (flags *Flags) SetParams(context whisper.Context) error {
96
+ if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
97
+ fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
98
+ if err := context.SetLanguage(lang); err != nil {
99
+ return err
100
+ }
101
+ }
102
+ if flags.IsTranslate() && context.IsMultilingual() {
103
+ fmt.Fprintf(flags.Output(), "Setting translate to true\n")
104
+ context.SetTranslate(true)
105
+ }
106
+ if offset := flags.GetOffset(); offset != 0 {
107
+ fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
108
+ context.SetOffset(offset)
109
+ }
110
+ if duration := flags.GetDuration(); duration != 0 {
111
+ fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
112
+ context.SetDuration(duration)
113
+ }
114
+ if flags.IsSpeedup() {
115
+ fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
116
+ context.SetSpeedup(true)
117
+ }
118
+ if threads := flags.GetThreads(); threads != 0 {
119
+ fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
120
+ context.SetThreads(threads)
121
+ }
122
+ if max_len := flags.GetMaxLen(); max_len != 0 {
123
+ fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
124
+ context.SetMaxSegmentLength(max_len)
125
+ }
126
+ if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
127
+ fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
128
+ context.SetMaxTokensPerSegment(max_tokens)
129
+ }
130
+ if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
131
+ fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
132
+ context.SetTokenThreshold(word_threshold)
133
+ }
134
+
135
+ // Return success
136
+ return nil
137
+ }
138
+
139
  ///////////////////////////////////////////////////////////////////////////////
140
  // PRIVATE METHODS
141
 
142
  func registerFlags(flag *Flags) {
143
  flag.String("model", "", "Path to the model file")
144
+ flag.String("language", "", "Spoken language")
145
+ flag.Bool("translate", false, "Translate from source language to english")
146
+ flag.Duration("offset", 0, "Time offset")
147
+ flag.Duration("duration", 0, "Duration of audio to process")
148
+ flag.Uint("threads", 0, "Number of threads to use")
149
  flag.Bool("speedup", false, "Enable speedup")
150
+ flag.Uint("max-len", 0, "Maximum segment length in characters")
151
+ flag.Uint("max-tokens", 0, "Maximum tokens per segment")
152
+ flag.Float64("word-thold", 0, "Maximum segment score")
153
  flag.Bool("tokens", false, "Display tokens")
154
+ flag.Bool("colorize", false, "Colorize tokens")
155
+ flag.String("out", "", "Output format (srt, none or leave as empty string)")
156
  }
bindings/go/examples/go-whisper/main.go CHANGED
@@ -35,8 +35,7 @@ func main() {
35
 
36
  // Process files
37
  for _, filename := range flags.Args() {
38
- fmt.Println("Processing", filename)
39
- if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
40
  fmt.Fprintln(os.Stderr, err)
41
  continue
42
  }
 
35
 
36
  // Process files
37
  for _, filename := range flags.Args() {
38
+ if err := Process(model, filename, flags); err != nil {
 
39
  fmt.Fprintln(os.Stderr, err)
40
  continue
41
  }
bindings/go/examples/go-whisper/process.go CHANGED
@@ -11,7 +11,7 @@ import (
11
  wav "github.com/go-audio/wav"
12
  )
13
 
14
- func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
15
  var data []float32
16
 
17
  // Create processing context
@@ -20,14 +20,20 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
20
  return err
21
  }
22
 
 
 
 
 
 
23
  // Open the file
 
24
  fh, err := os.Open(path)
25
  if err != nil {
26
  return err
27
  }
28
  defer fh.Close()
29
 
30
- // Decode the WAV file
31
  dec := wav.NewDecoder(fh)
32
  if buf, err := dec.FullPCMBuffer(); err != nil {
33
  return err
@@ -39,42 +45,83 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
39
  data = buf.AsFloat32Buffer().Data
40
  }
41
 
42
- // Set the parameters
43
  var cb whisper.SegmentCallback
44
- if lang != "" {
45
- if err := context.SetLanguage(lang); err != nil {
46
- return err
47
- }
48
- }
49
- if speedup {
50
- context.SetSpeedup(true)
51
- }
52
- if tokens {
53
  cb = func(segment whisper.Segment) {
54
- fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
55
  for _, token := range segment.Tokens {
56
- fmt.Printf("%q ", token.Text)
 
 
 
 
57
  }
58
- fmt.Println("")
 
59
  }
60
  }
61
 
62
  // Process the data
 
63
  if err := context.Process(data, cb); err != nil {
64
  return err
65
  }
66
 
67
  // Print out the results
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  for {
69
  segment, err := context.NextSegment()
70
  if err == io.EOF {
71
- break
72
  } else if err != nil {
73
  return err
74
  }
75
- fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
 
 
 
 
76
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- // Return success
79
- return nil
 
80
  }
 
11
  wav "github.com/go-audio/wav"
12
  )
13
 
14
+ func Process(model whisper.Model, path string, flags *Flags) error {
15
  var data []float32
16
 
17
  // Create processing context
 
20
  return err
21
  }
22
 
23
+ // Set the parameters
24
+ if err := flags.SetParams(context); err != nil {
25
+ return err
26
+ }
27
+
28
  // Open the file
29
+ fmt.Fprintf(flags.Output(), "Loading %q\n", path)
30
  fh, err := os.Open(path)
31
  if err != nil {
32
  return err
33
  }
34
  defer fh.Close()
35
 
36
+ // Decode the WAV file - load the full buffer
37
  dec := wav.NewDecoder(fh)
38
  if buf, err := dec.FullPCMBuffer(); err != nil {
39
  return err
 
45
  data = buf.AsFloat32Buffer().Data
46
  }
47
 
48
+ // Segment callback when -tokens is specified
49
  var cb whisper.SegmentCallback
50
+ if flags.IsTokens() {
 
 
 
 
 
 
 
 
51
  cb = func(segment whisper.Segment) {
52
+ fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
53
  for _, token := range segment.Tokens {
54
+ if flags.IsColorize() && context.IsText(token) {
55
+ fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
56
+ } else {
57
+ fmt.Fprint(flags.Output(), token.Text, " ")
58
+ }
59
  }
60
+ fmt.Fprintln(flags.Output(), "")
61
+ fmt.Fprintln(flags.Output(), "")
62
  }
63
  }
64
 
65
  // Process the data
66
+ fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
67
  if err := context.Process(data, cb); err != nil {
68
  return err
69
  }
70
 
71
  // Print out the results
72
+ switch {
73
+ case flags.GetOut() == "srt":
74
+ return OutputSRT(os.Stdout, context)
75
+ case flags.GetOut() == "none":
76
+ return nil
77
+ default:
78
+ return Output(os.Stdout, context, flags.IsColorize())
79
+ }
80
+ }
81
+
82
+ // Output text as SRT file
83
+ func OutputSRT(w io.Writer, context whisper.Context) error {
84
+ n := 1
85
  for {
86
  segment, err := context.NextSegment()
87
  if err == io.EOF {
88
+ return nil
89
  } else if err != nil {
90
  return err
91
  }
92
+ fmt.Fprintln(w, n)
93
+ fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
94
+ fmt.Fprintln(w, segment.Text)
95
+ fmt.Fprintln(w, "")
96
+ n++
97
  }
98
+ }
99
+
100
+ // Output text to terminal
101
+ func Output(w io.Writer, context whisper.Context, colorize bool) error {
102
+ for {
103
+ segment, err := context.NextSegment()
104
+ if err == io.EOF {
105
+ return nil
106
+ } else if err != nil {
107
+ return err
108
+ }
109
+ fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
110
+ if colorize {
111
+ for _, token := range segment.Tokens {
112
+ if !context.IsText(token) {
113
+ continue
114
+ }
115
+ fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
116
+ }
117
+ fmt.Fprint(w, "\n")
118
+ } else {
119
+ fmt.Fprintln(w, " ", segment.Text)
120
+ }
121
+ }
122
+ }
123
 
124
// srtTimestamp formats t as "HH:MM:SS,mmm": hours, minutes and seconds padded
// to at least two digits, and milliseconds padded to three. Anything below a
// millisecond is discarded.
func srtTimestamp(t time.Duration) string {
	hours := t / time.Hour
	minutes := (t % time.Hour) / time.Minute
	seconds := (t % time.Minute) / time.Second
	millis := (t % time.Second) / time.Millisecond
	return fmt.Sprintf("%02d:%02d:%02d,%03d", hours, minutes, seconds, millis)
}
bindings/go/params.go CHANGED
@@ -47,6 +47,7 @@ func (p *Params) SetSpeedup(v bool) {
47
  p.speed_up = toBool(v)
48
  }
49
 
 
50
  func (p *Params) SetLanguage(lang int) error {
51
  str := C.whisper_lang_str(C.int(lang))
52
  if str == nil {
@@ -57,6 +58,7 @@ func (p *Params) SetLanguage(lang int) error {
57
  return nil
58
  }
59
 
 
60
  func (p *Params) Language() int {
61
  if p.language == nil {
62
  return -1
@@ -64,18 +66,41 @@ func (p *Params) Language() int {
64
  return int(C.whisper_lang_id(p.language))
65
  }
66
 
 
67
  func (p *Params) SetThreads(threads int) {
68
  p.n_threads = C.int(threads)
69
  }
70
 
 
71
  func (p *Params) SetOffset(offset_ms int) {
72
  p.offset_ms = C.int(offset_ms)
73
  }
74
 
 
75
  func (p *Params) SetDuration(duration_ms int) {
76
  p.duration_ms = C.int(duration_ms)
77
  }
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  ///////////////////////////////////////////////////////////////////////////////
80
  // PRIVATE METHODS
81
 
 
47
  p.speed_up = toBool(v)
48
  }
49
 
50
+ // Set language id
51
  func (p *Params) SetLanguage(lang int) error {
52
  str := C.whisper_lang_str(C.int(lang))
53
  if str == nil {
 
58
  return nil
59
  }
60
 
61
+ // Get language id
62
  func (p *Params) Language() int {
63
  if p.language == nil {
64
  return -1
 
66
  return int(C.whisper_lang_id(p.language))
67
  }
68
 
69
+ // Set number of threads to use
70
  func (p *Params) SetThreads(threads int) {
71
  p.n_threads = C.int(threads)
72
  }
73
 
74
+ // Set start offset in ms
75
  func (p *Params) SetOffset(offset_ms int) {
76
  p.offset_ms = C.int(offset_ms)
77
  }
78
 
79
+ // Set audio duration to process in ms
80
  func (p *Params) SetDuration(duration_ms int) {
81
  p.duration_ms = C.int(duration_ms)
82
  }
83
 
84
+ // Set timestamp token probability threshold (~0.01)
85
+ func (p *Params) SetTokenThreshold(t float32) {
86
+ p.thold_pt = C.float(t)
87
+ }
88
+
89
+ // Set timestamp token sum probability threshold (~0.01)
90
+ func (p *Params) SetTokenSumThreshold(t float32) {
91
+ p.thold_ptsum = C.float(t)
92
+ }
93
+
94
+ // Set max segment length in characters
95
+ func (p *Params) SetMaxSegmentLength(n int) {
96
+ p.max_len = C.int(n)
97
+ }
98
+
99
+ // Set max tokens per segment (0 = no limit)
100
+ func (p *Params) SetMaxTokensPerSegment(n int) {
101
+ p.max_tokens = C.int(n)
102
+ }
103
+
104
  ///////////////////////////////////////////////////////////////////////////////
105
  // PRIVATE METHODS
106
 
bindings/go/pkg/whisper/consts.go CHANGED
@@ -11,10 +11,11 @@ import (
11
  // ERRORS
12
 
13
  var (
14
- ErrUnableToLoadModel = errors.New("unable to load model")
15
- ErrInternalAppError = errors.New("internal application error")
16
- ErrProcessingFailed = errors.New("processing failed")
17
- ErrUnsupportedLanguage = errors.New("unsupported language")
 
18
  )
19
 
20
  ///////////////////////////////////////////////////////////////////////////////
 
11
  // ERRORS
12
 
13
  var (
14
+ ErrUnableToLoadModel = errors.New("unable to load model")
15
+ ErrInternalAppError = errors.New("internal application error")
16
+ ErrProcessingFailed = errors.New("processing failed")
17
+ ErrUnsupportedLanguage = errors.New("unsupported language")
18
+ ErrModelNotMultilingual = errors.New("model is not multilingual")
19
  )
20
 
21
  ///////////////////////////////////////////////////////////////////////////////
bindings/go/pkg/whisper/context.go CHANGED
@@ -24,7 +24,7 @@ var _ Context = (*context)(nil)
24
  ///////////////////////////////////////////////////////////////////////////////
25
  // LIFECYCLE
26
 
27
- func NewContext(model *model, params whisper.Params) (Context, error) {
28
  context := new(context)
29
  context.model = model
30
  context.params = params
@@ -41,6 +41,9 @@ func (context *context) SetLanguage(lang string) error {
41
  if context.model.ctx == nil {
42
  return ErrInternalAppError
43
  }
 
 
 
44
  if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
45
  return ErrUnsupportedLanguage
46
  } else if err := context.params.SetLanguage(id); err != nil {
@@ -50,16 +53,60 @@ func (context *context) SetLanguage(lang string) error {
50
  return nil
51
  }
52
 
 
 
 
 
53
  // Get language
54
  func (context *context) Language() string {
55
  return whisper.Whisper_lang_str(context.params.Language())
56
  }
57
 
 
 
 
 
 
58
  // Set speedup flag
59
  func (context *context) SetSpeedup(v bool) {
60
  context.params.SetSpeedup(v)
61
  }
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  // Process new sample data and return any errors
64
  func (context *context) Process(data []float32, cb SegmentCallback) error {
65
  if context.model.ctx == nil {
@@ -119,6 +166,65 @@ func (context *context) NextSegment() (Segment, error) {
119
  return result, nil
120
  }
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  ///////////////////////////////////////////////////////////////////////////////
123
  // PRIVATE METHODS
124
 
 
24
  ///////////////////////////////////////////////////////////////////////////////
25
  // LIFECYCLE
26
 
27
+ func newContext(model *model, params whisper.Params) (Context, error) {
28
  context := new(context)
29
  context.model = model
30
  context.params = params
 
41
  if context.model.ctx == nil {
42
  return ErrInternalAppError
43
  }
44
+ if !context.model.IsMultilingual() {
45
+ return ErrModelNotMultilingual
46
+ }
47
  if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
48
  return ErrUnsupportedLanguage
49
  } else if err := context.params.SetLanguage(id); err != nil {
 
53
  return nil
54
  }
55
 
56
+ func (context *context) IsMultilingual() bool {
57
+ return context.model.IsMultilingual()
58
+ }
59
+
60
  // Get language
61
  func (context *context) Language() string {
62
  return whisper.Whisper_lang_str(context.params.Language())
63
  }
64
 
65
+ // Set translate flag
66
+ func (context *context) SetTranslate(v bool) {
67
+ context.params.SetTranslate(v)
68
+ }
69
+
70
  // Set speedup flag
71
  func (context *context) SetSpeedup(v bool) {
72
  context.params.SetSpeedup(v)
73
  }
74
 
75
+ // Set number of threads to use
76
+ func (context *context) SetThreads(v uint) {
77
+ context.params.SetThreads(int(v))
78
+ }
79
+
80
+ // Set time offset
81
+ func (context *context) SetOffset(v time.Duration) {
82
+ context.params.SetOffset(int(v.Milliseconds()))
83
+ }
84
+
85
+ // Set duration of audio to process
86
+ func (context *context) SetDuration(v time.Duration) {
87
+ context.params.SetOffset(int(v.Milliseconds()))
88
+ }
89
+
90
+ // Set timestamp token probability threshold (~0.01)
91
+ func (context *context) SetTokenThreshold(t float32) {
92
+ context.params.SetTokenThreshold(t)
93
+ }
94
+
95
+ // Set timestamp token sum probability threshold (~0.01)
96
+ func (context *context) SetTokenSumThreshold(t float32) {
97
+ context.params.SetTokenSumThreshold(t)
98
+ }
99
+
100
+ // Set max segment length in characters
101
+ func (context *context) SetMaxSegmentLength(n uint) {
102
+ context.params.SetMaxSegmentLength(int(n))
103
+ }
104
+
105
+ // Set max tokens per segment (0 = no limit)
106
+ func (context *context) SetMaxTokensPerSegment(n uint) {
107
+ context.params.SetMaxTokensPerSegment(int(n))
108
+ }
109
+
110
  // Process new sample data and return any errors
111
  func (context *context) Process(data []float32, cb SegmentCallback) error {
112
  if context.model.ctx == nil {
 
166
  return result, nil
167
  }
168
 
169
+ // Test for text tokens
170
+ func (context *context) IsText(t Token) bool {
171
+ switch {
172
+ case context.IsBEG(t):
173
+ return false
174
+ case context.IsSOT(t):
175
+ return false
176
+ case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
177
+ return false
178
+ case context.IsPREV(t):
179
+ return false
180
+ case context.IsSOLM(t):
181
+ return false
182
+ case context.IsNOT(t):
183
+ return false
184
+ default:
185
+ return true
186
+ }
187
+ }
188
+
189
+ // Test for "begin" token
190
+ func (context *context) IsBEG(t Token) bool {
191
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
192
+ }
193
+
194
+ // Test for "start of transcription" token
195
+ func (context *context) IsSOT(t Token) bool {
196
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
197
+ }
198
+
199
+ // Test for "end of transcription" token
200
+ func (context *context) IsEOT(t Token) bool {
201
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
202
+ }
203
+
204
+ // Test for "start of prev" token
205
+ func (context *context) IsPREV(t Token) bool {
206
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
207
+ }
208
+
209
+ // Test for "start of lm" token
210
+ func (context *context) IsSOLM(t Token) bool {
211
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
212
+ }
213
+
214
+ // Test for "No timestamps" token
215
+ func (context *context) IsNOT(t Token) bool {
216
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
217
+ }
218
+
219
+ // Test for token associated with a specific language
220
+ func (context *context) IsLANG(t Token, lang string) bool {
221
+ if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
222
+ return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
223
+ } else {
224
+ return false
225
+ }
226
+ }
227
+
228
  ///////////////////////////////////////////////////////////////////////////////
229
  // PRIVATE METHODS
230
 
bindings/go/pkg/whisper/interface.go CHANGED
@@ -20,6 +20,9 @@ type Model interface {
20
  // Return a new speech-to-text context.
21
  NewContext() (Context, error)
22
 
 
 
 
23
  // Return all languages supported.
24
  Languages() []string
25
  }
@@ -27,8 +30,18 @@ type Model interface {
27
  // Context is the speach recognition context.
28
  type Context interface {
29
  SetLanguage(string) error // Set the language to use for speech recognition.
 
 
30
  Language() string // Get language
31
- SetSpeedup(bool) // Set speedup flag
 
 
 
 
 
 
 
 
32
 
33
  // Process mono audio data and return any errors.
34
  // If defined, newly generated segments are passed to the
@@ -38,6 +51,15 @@ type Context interface {
38
  // After process is called, return segments until the end of the stream
39
  // is reached, when io.EOF is returned.
40
  NextSegment() (Segment, error)
 
 
 
 
 
 
 
 
 
41
  }
42
 
43
  // Segment is the text result of a speech recognition.
 
20
  // Return a new speech-to-text context.
21
  NewContext() (Context, error)
22
 
23
+ // Return true if the model is multilingual.
24
+ IsMultilingual() bool
25
+
26
  // Return all languages supported.
27
  Languages() []string
28
  }
 
30
  // Context is the speech recognition context.
31
  type Context interface {
32
  SetLanguage(string) error // Set the language to use for speech recognition.
33
+ SetTranslate(bool) // Set translate flag
34
+ IsMultilingual() bool // Return true if the model is multilingual.
35
  Language() string // Get language
36
+
37
+ SetOffset(time.Duration) // Set offset
38
+ SetDuration(time.Duration) // Set duration
39
+ SetThreads(uint) // Set number of threads to use
40
+ SetSpeedup(bool) // Set speedup flag
41
+ SetTokenThreshold(float32) // Set timestamp token probability threshold
42
+ SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
43
+ SetMaxSegmentLength(uint) // Set max segment length in characters
44
+ SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
45
 
46
  // Process mono audio data and return any errors.
47
  // If defined, newly generated segments are passed to the
 
51
  // After process is called, return segments until the end of the stream
52
  // is reached, when io.EOF is returned.
53
  NextSegment() (Segment, error)
54
+
55
+ IsBEG(Token) bool // Test for "begin" token
56
+ IsSOT(Token) bool // Test for "start of transcription" token
57
+ IsEOT(Token) bool // Test for "end of transcription" token
58
+ IsPREV(Token) bool // Test for "start of prev" token
59
+ IsSOLM(Token) bool // Test for "start of lm" token
60
+ IsNOT(Token) bool // Test for "No timestamps" token
61
+ IsLANG(Token, string) bool // Test for token associated with a specific language
62
+ IsText(Token) bool // Test for text token
63
  }
64
 
65
  // Segment is the text result of a speech recognition.
bindings/go/pkg/whisper/model.go CHANGED
@@ -23,7 +23,7 @@ var _ Model = (*model)(nil)
23
  ///////////////////////////////////////////////////////////////////////////////
24
  // LIFECYCLE
25
 
26
- func New(path string) (*model, error) {
27
  model := new(model)
28
  if _, err := os.Stat(path); err != nil {
29
  return nil, err
@@ -64,6 +64,11 @@ func (model *model) String() string {
64
  ///////////////////////////////////////////////////////////////////////////////
65
  // PUBLIC METHODS
66
 
 
 
 
 
 
67
  // Return all recognized languages. Initially it is set to auto-detect
68
  func (model *model) Languages() []string {
69
  result := make([]string, 0, whisper.Whisper_lang_max_id())
@@ -91,5 +96,5 @@ func (model *model) NewContext() (Context, error) {
91
  params.SetThreads(runtime.NumCPU())
92
 
93
  // Return new context
94
- return NewContext(model, params)
95
  }
 
23
  ///////////////////////////////////////////////////////////////////////////////
24
  // LIFECYCLE
25
 
26
+ func New(path string) (Model, error) {
27
  model := new(model)
28
  if _, err := os.Stat(path); err != nil {
29
  return nil, err
 
64
  ///////////////////////////////////////////////////////////////////////////////
65
  // PUBLIC METHODS
66
 
67
+ // Return true if model is multilingual (language and translation options are supported)
68
+ func (model *model) IsMultilingual() bool {
69
+ return model.ctx.Whisper_is_multilingual() != 0
70
+ }
71
+
72
  // Return all recognized languages. Initially it is set to auto-detect
73
  func (model *model) Languages() []string {
74
  result := make([]string, 0, whisper.Whisper_lang_max_id())
 
96
  params.SetThreads(runtime.NumCPU())
97
 
98
  // Return new context
99
+ return newContext(model, params)
100
  }