prsyahmi ggerganov commited on
Commit
8060473
·
unverified ·
1 Parent(s): 6e57274

whisper : add loader class to allow loading from buffer and others (#353)

Browse files

* whisper : add loader to allow loading from other than file

* whisper : rename whisper_init to whisper_init_from_file

* whisper : add whisper_init_from_buffer

* android : Delete local.properties

* android : load models directly from assets

* whisper : adding <stddef.h> needed for size_t + code style

Co-authored-by: Georgi Gerganov <[email protected]>

bindings/go/whisper.go CHANGED
@@ -91,7 +91,7 @@ var (
91
  func Whisper_init(path string) *Context {
92
  cPath := C.CString(path)
93
  defer C.free(unsafe.Pointer(cPath))
94
- if ctx := C.whisper_init(cPath); ctx != nil {
95
  return (*Context)(ctx)
96
  } else {
97
  return nil
 
91
  func Whisper_init(path string) *Context {
92
  cPath := C.CString(path)
93
  defer C.free(unsafe.Pointer(cPath))
94
+ if ctx := C.whisper_init_from_file(cPath); ctx != nil {
95
  return (*Context)(ctx)
96
  } else {
97
  return nil
bindings/javascript/emscripten.cpp CHANGED
@@ -20,7 +20,7 @@ struct whisper_context * g_context;
20
  EMSCRIPTEN_BINDINGS(whisper) {
21
  emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
22
  if (g_context == nullptr) {
23
- g_context = whisper_init(path_model.c_str());
24
  if (g_context != nullptr) {
25
  return true;
26
  } else {
 
20
  EMSCRIPTEN_BINDINGS(whisper) {
21
  emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
22
  if (g_context == nullptr) {
23
+ g_context = whisper_init_from_file(path_model.c_str());
24
  if (g_context != nullptr) {
25
  return true;
26
  } else {
examples/bench.wasm/emscripten.cpp CHANGED
@@ -52,7 +52,7 @@ EMSCRIPTEN_BINDINGS(bench) {
52
  emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
53
  for (size_t i = 0; i < g_contexts.size(); ++i) {
54
  if (g_contexts[i] == nullptr) {
55
- g_contexts[i] = whisper_init(path_model.c_str());
56
  if (g_contexts[i] != nullptr) {
57
  if (g_worker.joinable()) {
58
  g_worker.join();
 
52
  emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
53
  for (size_t i = 0; i < g_contexts.size(); ++i) {
54
  if (g_contexts[i] == nullptr) {
55
+ g_contexts[i] = whisper_init_from_file(path_model.c_str());
56
  if (g_contexts[i] != nullptr) {
57
  if (g_worker.joinable()) {
58
  g_worker.join();
examples/bench/bench.cpp CHANGED
@@ -53,7 +53,7 @@ int main(int argc, char ** argv) {
53
 
54
  // whisper init
55
 
56
- struct whisper_context * ctx = whisper_init(params.model.c_str());
57
 
58
  {
59
  fprintf(stderr, "\n");
 
53
 
54
  // whisper init
55
 
56
+ struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
57
 
58
  {
59
  fprintf(stderr, "\n");
examples/command.wasm/emscripten.cpp CHANGED
@@ -324,7 +324,7 @@ EMSCRIPTEN_BINDINGS(command) {
324
  emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
325
  for (size_t i = 0; i < g_contexts.size(); ++i) {
326
  if (g_contexts[i] == nullptr) {
327
- g_contexts[i] = whisper_init(path_model.c_str());
328
  if (g_contexts[i] != nullptr) {
329
  g_running = true;
330
  if (g_worker.joinable()) {
 
324
  emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
325
  for (size_t i = 0; i < g_contexts.size(); ++i) {
326
  if (g_contexts[i] == nullptr) {
327
+ g_contexts[i] = whisper_init_from_file(path_model.c_str());
328
  if (g_contexts[i] != nullptr) {
329
  g_running = true;
330
  if (g_worker.joinable()) {
examples/command/command.cpp CHANGED
@@ -931,7 +931,7 @@ int main(int argc, char ** argv) {
931
 
932
  // whisper init
933
 
934
- struct whisper_context * ctx = whisper_init(params.model.c_str());
935
 
936
  // print some info about the processing
937
  {
 
931
 
932
  // whisper init
933
 
934
+ struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
935
 
936
  // print some info about the processing
937
  {
examples/main/main.cpp CHANGED
@@ -478,7 +478,7 @@ int main(int argc, char ** argv) {
478
 
479
  // whisper init
480
 
481
- struct whisper_context * ctx = whisper_init(params.model.c_str());
482
 
483
  if (ctx == nullptr) {
484
  fprintf(stderr, "error: failed to initialize whisper context\n");
 
478
 
479
  // whisper init
480
 
481
+ struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
482
 
483
  if (ctx == nullptr) {
484
  fprintf(stderr, "error: failed to initialize whisper context\n");
examples/stream.wasm/emscripten.cpp CHANGED
@@ -129,7 +129,7 @@ EMSCRIPTEN_BINDINGS(stream) {
129
  emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
130
  for (size_t i = 0; i < g_contexts.size(); ++i) {
131
  if (g_contexts[i] == nullptr) {
132
- g_contexts[i] = whisper_init(path_model.c_str());
133
  if (g_contexts[i] != nullptr) {
134
  g_running = true;
135
  if (g_worker.joinable()) {
 
129
  emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
130
  for (size_t i = 0; i < g_contexts.size(); ++i) {
131
  if (g_contexts[i] == nullptr) {
132
+ g_contexts[i] = whisper_init_from_file(path_model.c_str());
133
  if (g_contexts[i] != nullptr) {
134
  g_running = true;
135
  if (g_worker.joinable()) {
examples/stream/stream.cpp CHANGED
@@ -456,7 +456,7 @@ int main(int argc, char ** argv) {
456
  exit(0);
457
  }
458
 
459
- struct whisper_context * ctx = whisper_init(params.model.c_str());
460
 
461
  std::vector<float> pcmf32 (n_samples_30s, 0.0f);
462
  std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
 
456
  exit(0);
457
  }
458
 
459
+ struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
460
 
461
  std::vector<float> pcmf32 (n_samples_30s, 0.0f);
462
  std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
examples/talk.wasm/emscripten.cpp CHANGED
@@ -271,7 +271,7 @@ EMSCRIPTEN_BINDINGS(talk) {
271
  emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
272
  for (size_t i = 0; i < g_contexts.size(); ++i) {
273
  if (g_contexts[i] == nullptr) {
274
- g_contexts[i] = whisper_init(path_model.c_str());
275
  if (g_contexts[i] != nullptr) {
276
  g_running = true;
277
  if (g_worker.joinable()) {
 
271
  emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
272
  for (size_t i = 0; i < g_contexts.size(); ++i) {
273
  if (g_contexts[i] == nullptr) {
274
+ g_contexts[i] = whisper_init_from_file(path_model.c_str());
275
  if (g_contexts[i] != nullptr) {
276
  g_running = true;
277
  if (g_worker.joinable()) {
examples/talk/talk.cpp CHANGED
@@ -498,7 +498,7 @@ int main(int argc, char ** argv) {
498
 
499
  // whisper init
500
 
501
- struct whisper_context * ctx_wsp = whisper_init(params.model_wsp.c_str());
502
 
503
  // gpt init
504
 
 
498
 
499
  // whisper init
500
 
501
+ struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
502
 
503
  // gpt init
504
 
examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreenViewModel.kt CHANGED
@@ -64,16 +64,22 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
64
  private suspend fun copyAssets() = withContext(Dispatchers.IO) {
65
  modelsPath.mkdirs()
66
  samplesPath.mkdirs()
67
- application.copyData("models", modelsPath, ::printMessage)
68
  application.copyData("samples", samplesPath, ::printMessage)
69
  printMessage("All data copied to working directory.\n")
70
  }
71
 
72
  private suspend fun loadBaseModel() = withContext(Dispatchers.IO) {
73
  printMessage("Loading model...\n")
74
- val firstModel = modelsPath.listFiles()!!.first()
75
- whisperContext = WhisperContext.createContext(firstModel.absolutePath)
76
- printMessage("Loaded model ${firstModel.name}.\n")
 
 
 
 
 
 
77
  }
78
 
79
  fun transcribeSample() = viewModelScope.launch {
 
64
  private suspend fun copyAssets() = withContext(Dispatchers.IO) {
65
  modelsPath.mkdirs()
66
  samplesPath.mkdirs()
67
+ //application.copyData("models", modelsPath, ::printMessage)
68
  application.copyData("samples", samplesPath, ::printMessage)
69
  printMessage("All data copied to working directory.\n")
70
  }
71
 
72
  private suspend fun loadBaseModel() = withContext(Dispatchers.IO) {
73
  printMessage("Loading model...\n")
74
+ val models = application.assets.list("models/")
75
+ if (models != null) {
76
+ val inputstream = application.assets.open("models/" + models[0])
77
+ whisperContext = WhisperContext.createContextFromInputStream(inputstream)
78
+ printMessage("Loaded model ${models[0]}.\n")
79
+ }
80
+
81
+ //val firstModel = modelsPath.listFiles()!!.first()
82
+ //whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
83
  }
84
 
85
  fun transcribeSample() = viewModelScope.launch {
examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt CHANGED
@@ -4,6 +4,7 @@ import android.os.Build
4
  import android.util.Log
5
  import kotlinx.coroutines.*
6
  import java.io.File
 
7
  import java.util.concurrent.Executors
8
 
9
  private const val LOG_TAG = "LibWhisper"
@@ -39,13 +40,22 @@ class WhisperContext private constructor(private var ptr: Long) {
39
  }
40
 
41
  companion object {
42
- fun createContext(filePath: String): WhisperContext {
43
  val ptr = WhisperLib.initContext(filePath)
44
  if (ptr == 0L) {
45
  throw java.lang.RuntimeException("Couldn't create context with path $filePath")
46
  }
47
  return WhisperContext(ptr)
48
  }
 
 
 
 
 
 
 
 
 
49
  }
50
  }
51
 
@@ -76,6 +86,7 @@ private class WhisperLib {
76
  }
77
 
78
  // JNI methods
 
79
  external fun initContext(modelPath: String): Long
80
  external fun freeContext(contextPtr: Long)
81
  external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
 
4
  import android.util.Log
5
  import kotlinx.coroutines.*
6
  import java.io.File
7
+ import java.io.InputStream
8
  import java.util.concurrent.Executors
9
 
10
  private const val LOG_TAG = "LibWhisper"
 
40
  }
41
 
42
  companion object {
43
+ fun createContextFromFile(filePath: String): WhisperContext {
44
  val ptr = WhisperLib.initContext(filePath)
45
  if (ptr == 0L) {
46
  throw java.lang.RuntimeException("Couldn't create context with path $filePath")
47
  }
48
  return WhisperContext(ptr)
49
  }
50
+
51
+ fun createContextFromInputStream(stream: InputStream): WhisperContext {
52
+ val ptr = WhisperLib.initContextFromInputStream(stream)
53
+
54
+ if (ptr == 0L) {
55
+ throw java.lang.RuntimeException("Couldn't create context from input stream")
56
+ }
57
+ return WhisperContext(ptr)
58
+ }
59
  }
60
  }
61
 
 
86
  }
87
 
88
  // JNI methods
89
+ external fun initContextFromInputStream(inputStream: InputStream): Long
90
  external fun initContext(modelPath: String): Long
91
  external fun freeContext(contextPtr: Long)
92
  external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
examples/whisper.android/app/src/main/jni/whisper/jni.c CHANGED
@@ -2,6 +2,7 @@
2
  #include <android/log.h>
3
  #include <stdlib.h>
4
  #include <sys/sysinfo.h>
 
5
  #include "whisper.h"
6
 
7
  #define UNUSED(x) (void)(x)
@@ -17,13 +18,86 @@ static inline int max(int a, int b) {
17
  return (a > b) ? a : b;
18
  }
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  JNIEXPORT jlong JNICALL
21
  Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContext(
22
  JNIEnv *env, jobject thiz, jstring model_path_str) {
23
  UNUSED(thiz);
24
  struct whisper_context *context = NULL;
25
  const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
26
- context = whisper_init(model_path_chars);
27
  (*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
28
  return (jlong) context;
29
  }
 
2
  #include <android/log.h>
3
  #include <stdlib.h>
4
  #include <sys/sysinfo.h>
5
+ #include <string.h>
6
  #include "whisper.h"
7
 
8
  #define UNUSED(x) (void)(x)
 
18
  return (a > b) ? a : b;
19
  }
20
 
21
+ struct input_stream_context {
22
+ size_t offset;
23
+ JNIEnv * env;
24
+ jobject thiz;
25
+ jobject input_stream;
26
+
27
+ jmethodID mid_available;
28
+ jmethodID mid_read;
29
+ };
30
+
31
+ size_t inputStreamRead(void * ctx, void * output, size_t read_size) {
32
+ struct input_stream_context* is = (struct input_stream_context*)ctx;
33
+
34
+ jint avail_size = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
35
+ jint size_to_copy = read_size < avail_size ? (jint)read_size : avail_size;
36
+
37
+ jbyteArray byte_array = (*is->env)->NewByteArray(is->env, size_to_copy);
38
+
39
+ jint n_read = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_read, byte_array, 0, size_to_copy);
40
+
41
+ if (size_to_copy != read_size || size_to_copy != n_read) {
42
+ LOGI("Insufficient Read: Req=%zu, ToCopy=%d, Available=%d", read_size, size_to_copy, n_read);
43
+ }
44
+
45
+ jbyte* byte_array_elements = (*is->env)->GetByteArrayElements(is->env, byte_array, NULL);
46
+ memcpy(output, byte_array_elements, size_to_copy);
47
+ (*is->env)->ReleaseByteArrayElements(is->env, byte_array, byte_array_elements, JNI_ABORT);
48
+
49
+ (*is->env)->DeleteLocalRef(is->env, byte_array);
50
+
51
+ is->offset += size_to_copy;
52
+
53
+ return size_to_copy;
54
+ }
55
+ bool inputStreamEof(void * ctx) {
56
+ struct input_stream_context* is = (struct input_stream_context*)ctx;
57
+
58
+ jint result = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
59
+ return result <= 0;
60
+ }
61
+ void inputStreamClose(void * ctx) {
62
+
63
+ }
64
+
65
+ JNIEXPORT jlong JNICALL
66
+ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContextFromInputStream(
67
+ JNIEnv *env, jobject thiz, jobject input_stream) {
68
+ UNUSED(thiz);
69
+
70
+ struct whisper_context *context = NULL;
71
+ struct whisper_model_loader loader = {};
72
+ struct input_stream_context inp_ctx = {};
73
+
74
+ inp_ctx.offset = 0;
75
+ inp_ctx.env = env;
76
+ inp_ctx.thiz = thiz;
77
+ inp_ctx.input_stream = input_stream;
78
+
79
+ jclass cls = (*env)->GetObjectClass(env, input_stream);
80
+ inp_ctx.mid_available = (*env)->GetMethodID(env, cls, "available", "()I");
81
+ inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I");
82
+
83
+ loader.context = &inp_ctx;
84
+ loader.read = inputStreamRead;
85
+ loader.eof = inputStreamEof;
86
+ loader.close = inputStreamClose;
87
+
88
+ loader.eof(loader.context);
89
+
90
+ context = whisper_init(&loader);
91
+ return (jlong) context;
92
+ }
93
+
94
  JNIEXPORT jlong JNICALL
95
  Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContext(
96
  JNIEnv *env, jobject thiz, jstring model_path_str) {
97
  UNUSED(thiz);
98
  struct whisper_context *context = NULL;
99
  const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
100
+ context = whisper_init_from_file(model_path_chars);
101
  (*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
102
  return (jlong) context;
103
  }
examples/whisper.android/local.properties DELETED
@@ -1,10 +0,0 @@
1
- ## This file is automatically generated by Android Studio.
2
- # Do not modify this file -- YOUR CHANGES WILL BE ERASED!
3
- #
4
- # This file should *NOT* be checked into Version Control Systems,
5
- # as it contains information specific to your local configuration.
6
- #
7
- # Location of the SDK. This is only used by Gradle.
8
- # For customization when using a Version Control System, please read the
9
- # header note.
10
- sdk.dir=/Users/kevin/Library/Android/sdk
 
 
 
 
 
 
 
 
 
 
 
examples/whisper.objc/whisper.objc/ViewController.m CHANGED
@@ -61,7 +61,7 @@ void AudioInputCallback(void * inUserData,
61
  NSLog(@"Loading model from %@", modelPath);
62
 
63
  // create ggml context
64
- stateInp.ctx = whisper_init([modelPath UTF8String]);
65
 
66
  // check if the model was loaded successfully
67
  if (stateInp.ctx == NULL) {
 
61
  NSLog(@"Loading model from %@", modelPath);
62
 
63
  // create ggml context
64
+ stateInp.ctx = whisper_init_from_file([modelPath UTF8String]);
65
 
66
  // check if the model was loaded successfully
67
  if (stateInp.ctx == NULL) {
examples/whisper.swiftui/whisper.cpp.swift/LibWhisper.swift CHANGED
@@ -55,7 +55,7 @@ actor WhisperContext {
55
  }
56
 
57
  static func createContext(path: String) throws -> WhisperContext {
58
- let context = whisper_init(path)
59
  if let context {
60
  return WhisperContext(context: context)
61
  } else {
 
55
  }
56
 
57
  static func createContext(path: String) throws -> WhisperContext {
58
+ let context = whisper_init_from_file(path)
59
  if let context {
60
  return WhisperContext(context: context)
61
  } else {
examples/whisper.wasm/emscripten.cpp CHANGED
@@ -18,7 +18,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
18
 
19
  for (size_t i = 0; i < g_contexts.size(); ++i) {
20
  if (g_contexts[i] == nullptr) {
21
- g_contexts[i] = whisper_init(path_model.c_str());
22
  if (g_contexts[i] != nullptr) {
23
  return i + 1;
24
  } else {
 
18
 
19
  for (size_t i = 0; i < g_contexts.size(); ++i) {
20
  if (g_contexts[i] == nullptr) {
21
+ g_contexts[i] = whisper_init_from_file(path_model.c_str());
22
  if (g_contexts[i] != nullptr) {
23
  return i + 1;
24
  } else {
whisper.cpp CHANGED
@@ -437,8 +437,8 @@ struct whisper_context {
437
  };
438
 
439
  template<typename T>
440
- static void read_safe(std::ifstream& fin, T& dest) {
441
- fin.read((char*)& dest, sizeof(T));
442
  }
443
 
444
  // load the model from a ggml file
@@ -452,24 +452,18 @@ static void read_safe(std::ifstream& fin, T& dest) {
452
  //
453
  // see the convert-pt-to-ggml.py script for details
454
  //
455
- static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
456
- fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
457
 
458
  auto & model = wctx.model;
459
  auto & vocab = wctx.vocab;
460
 
461
- auto fin = std::ifstream(fname, std::ios::binary);
462
- if (!fin) {
463
- fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
464
- return false;
465
- }
466
-
467
  // verify magic
468
  {
469
  uint32_t magic;
470
- read_safe(fin, magic);
471
  if (magic != 0x67676d6c) {
472
- fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
473
  return false;
474
  }
475
  }
@@ -478,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
478
  {
479
  auto & hparams = model.hparams;
480
 
481
- read_safe(fin, hparams.n_vocab);
482
- read_safe(fin, hparams.n_audio_ctx);
483
- read_safe(fin, hparams.n_audio_state);
484
- read_safe(fin, hparams.n_audio_head);
485
- read_safe(fin, hparams.n_audio_layer);
486
- read_safe(fin, hparams.n_text_ctx);
487
- read_safe(fin, hparams.n_text_state);
488
- read_safe(fin, hparams.n_text_head);
489
- read_safe(fin, hparams.n_text_layer);
490
- read_safe(fin, hparams.n_mels);
491
- read_safe(fin, hparams.f16);
492
 
493
  assert(hparams.n_text_state == hparams.n_audio_state);
494
 
@@ -536,17 +530,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
536
  {
537
  auto & filters = wctx.model.filters;
538
 
539
- read_safe(fin, filters.n_mel);
540
- read_safe(fin, filters.n_fft);
541
 
542
  filters.data.resize(filters.n_mel * filters.n_fft);
543
- fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
544
  }
545
 
546
  // load vocab
547
  {
548
  int32_t n_vocab = 0;
549
- read_safe(fin, n_vocab);
550
 
551
  //if (n_vocab != model.hparams.n_vocab) {
552
  // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@@ -561,11 +555,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
561
 
562
  for (int i = 0; i < n_vocab; i++) {
563
  uint32_t len;
564
- read_safe(fin, len);
565
 
566
  if (len > 0) {
567
  tmp.resize(len);
568
- fin.read(&tmp[0], tmp.size()); // read to buffer
569
  word.assign(&tmp[0], tmp.size());
570
  } else {
571
  // seems like we have an empty-string token in multi-language models (i = 50256)
@@ -1017,24 +1011,24 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
1017
  int32_t length;
1018
  int32_t ftype;
1019
 
1020
- read_safe(fin, n_dims);
1021
- read_safe(fin, length);
1022
- read_safe(fin, ftype);
1023
 
1024
- if (fin.eof()) {
1025
  break;
1026
  }
1027
 
1028
  int32_t nelements = 1;
1029
  int32_t ne[3] = { 1, 1, 1 };
1030
  for (int i = 0; i < n_dims; ++i) {
1031
- read_safe(fin, ne[i]);
1032
  nelements *= ne[i];
1033
  }
1034
 
1035
  std::string name;
1036
  std::vector<char> tmp(length); // create a buffer
1037
- fin.read(&tmp[0], tmp.size()); // read to buffer
1038
  name.assign(&tmp[0], tmp.size());
1039
 
1040
  if (model.tensors.find(name) == model.tensors.end()) {
@@ -1062,7 +1056,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
1062
  return false;
1063
  }
1064
 
1065
- fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
1066
 
1067
  //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
1068
  total_size += ggml_nbytes(tensor);
@@ -1079,8 +1073,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
1079
  }
1080
  }
1081
 
1082
- fin.close();
1083
-
1084
  return true;
1085
  }
1086
 
@@ -2240,7 +2232,74 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
2240
  // interface implementation
2241
  //
2242
 
2243
- struct whisper_context * whisper_init(const char * path_model) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2244
  ggml_time_init();
2245
 
2246
  whisper_context * ctx = new whisper_context;
@@ -2249,14 +2308,17 @@ struct whisper_context * whisper_init(const char * path_model) {
2249
 
2250
  ctx->t_start_us = t_start_us;
2251
 
2252
- if (!whisper_model_load(path_model, *ctx)) {
2253
- fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
 
2254
  delete ctx;
2255
  return nullptr;
2256
  }
2257
 
2258
  ctx->t_load_us = ggml_time_us() - t_start_us;
2259
 
 
 
2260
  return ctx;
2261
  }
2262
 
 
437
  };
438
 
439
  template<typename T>
440
+ static void read_safe(whisper_model_loader * loader, T & dest) {
441
+ loader->read(loader->context, &dest, sizeof(T));
442
  }
443
 
444
  // load the model from a ggml file
 
452
  //
453
  // see the convert-pt-to-ggml.py script for details
454
  //
455
+ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
456
+ fprintf(stderr, "%s: loading model\n", __func__);
457
 
458
  auto & model = wctx.model;
459
  auto & vocab = wctx.vocab;
460
 
 
 
 
 
 
 
461
  // verify magic
462
  {
463
  uint32_t magic;
464
+ read_safe(loader, magic);
465
  if (magic != 0x67676d6c) {
466
+ fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
467
  return false;
468
  }
469
  }
 
472
  {
473
  auto & hparams = model.hparams;
474
 
475
+ read_safe(loader, hparams.n_vocab);
476
+ read_safe(loader, hparams.n_audio_ctx);
477
+ read_safe(loader, hparams.n_audio_state);
478
+ read_safe(loader, hparams.n_audio_head);
479
+ read_safe(loader, hparams.n_audio_layer);
480
+ read_safe(loader, hparams.n_text_ctx);
481
+ read_safe(loader, hparams.n_text_state);
482
+ read_safe(loader, hparams.n_text_head);
483
+ read_safe(loader, hparams.n_text_layer);
484
+ read_safe(loader, hparams.n_mels);
485
+ read_safe(loader, hparams.f16);
486
 
487
  assert(hparams.n_text_state == hparams.n_audio_state);
488
 
 
530
  {
531
  auto & filters = wctx.model.filters;
532
 
533
+ read_safe(loader, filters.n_mel);
534
+ read_safe(loader, filters.n_fft);
535
 
536
  filters.data.resize(filters.n_mel * filters.n_fft);
537
+ loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
538
  }
539
 
540
  // load vocab
541
  {
542
  int32_t n_vocab = 0;
543
+ read_safe(loader, n_vocab);
544
 
545
  //if (n_vocab != model.hparams.n_vocab) {
546
  // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
 
555
 
556
  for (int i = 0; i < n_vocab; i++) {
557
  uint32_t len;
558
+ read_safe(loader, len);
559
 
560
  if (len > 0) {
561
  tmp.resize(len);
562
+ loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
563
  word.assign(&tmp[0], tmp.size());
564
  } else {
565
  // seems like we have an empty-string token in multi-language models (i = 50256)
 
1011
  int32_t length;
1012
  int32_t ftype;
1013
 
1014
+ read_safe(loader, n_dims);
1015
+ read_safe(loader, length);
1016
+ read_safe(loader, ftype);
1017
 
1018
+ if (loader->eof(loader->context)) {
1019
  break;
1020
  }
1021
 
1022
  int32_t nelements = 1;
1023
  int32_t ne[3] = { 1, 1, 1 };
1024
  for (int i = 0; i < n_dims; ++i) {
1025
+ read_safe(loader, ne[i]);
1026
  nelements *= ne[i];
1027
  }
1028
 
1029
  std::string name;
1030
  std::vector<char> tmp(length); // create a buffer
1031
+ loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
1032
  name.assign(&tmp[0], tmp.size());
1033
 
1034
  if (model.tensors.find(name) == model.tensors.end()) {
 
1056
  return false;
1057
  }
1058
 
1059
+ loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
1060
 
1061
  //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
1062
  total_size += ggml_nbytes(tensor);
 
1073
  }
1074
  }
1075
 
 
 
1076
  return true;
1077
  }
1078
 
 
2232
  // interface implementation
2233
  //
2234
 
2235
+ struct whisper_context * whisper_init_from_file(const char * path_model) {
2236
+ whisper_model_loader loader = {};
2237
+
2238
+ fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
2239
+
2240
+ auto fin = std::ifstream(path_model, std::ios::binary);
2241
+ if (!fin) {
2242
+ fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
2243
+ return nullptr;
2244
+ }
2245
+
2246
+ loader.context = &fin;
2247
+ loader.read = [](void * ctx, void * output, size_t read_size) {
2248
+ std::ifstream * fin = (std::ifstream*)ctx;
2249
+ fin->read((char *)output, read_size);
2250
+ return read_size;
2251
+ };
2252
+
2253
+ loader.eof = [](void * ctx) {
2254
+ std::ifstream * fin = (std::ifstream*)ctx;
2255
+ return fin->eof();
2256
+ };
2257
+
2258
+ loader.close = [](void * ctx) {
2259
+ std::ifstream * fin = (std::ifstream*)ctx;
2260
+ fin->close();
2261
+ };
2262
+
2263
+ return whisper_init(&loader);
2264
+ }
2265
+
2266
+ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
2267
+ struct buf_context {
2268
+ uint8_t* buffer;
2269
+ size_t size;
2270
+ size_t current_offset;
2271
+ };
2272
+
2273
+ buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
2274
+ whisper_model_loader loader = {};
2275
+
2276
+ fprintf(stderr, "%s: loading model from buffer\n", __func__);
2277
+
2278
+ loader.context = &ctx;
2279
+
2280
+ loader.read = [](void * ctx, void * output, size_t read_size) {
2281
+ buf_context * buf = reinterpret_cast<buf_context *>(ctx);
2282
+
2283
+ size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
2284
+
2285
+ memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
2286
+ buf->current_offset += size_to_copy;
2287
+
2288
+ return size_to_copy;
2289
+ };
2290
+
2291
+ loader.eof = [](void * ctx) {
2292
+ buf_context * buf = reinterpret_cast<buf_context *>(ctx);
2293
+
2294
+ return buf->current_offset >= buf->size;
2295
+ };
2296
+
2297
+ loader.close = [](void * /*ctx*/) { };
2298
+
2299
+ return whisper_init(&loader);
2300
+ }
2301
+
2302
+ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
2303
  ggml_time_init();
2304
 
2305
  whisper_context * ctx = new whisper_context;
 
2308
 
2309
  ctx->t_start_us = t_start_us;
2310
 
2311
+ if (!whisper_model_load(loader, *ctx)) {
2312
+ loader->close(loader->context);
2313
+ fprintf(stderr, "%s: failed to load model\n", __func__);
2314
  delete ctx;
2315
  return nullptr;
2316
  }
2317
 
2318
  ctx->t_load_us = ggml_time_us() - t_start_us;
2319
 
2320
+ loader->close(loader->context);
2321
+
2322
  return ctx;
2323
  }
2324
 
whisper.h CHANGED
@@ -1,6 +1,7 @@
1
  #ifndef WHISPER_H
2
  #define WHISPER_H
3
 
 
4
  #include <stdint.h>
5
  #include <stdbool.h>
6
 
@@ -40,7 +41,7 @@ extern "C" {
40
  //
41
  // ...
42
  //
43
- // struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin");
44
  //
45
  // if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
46
  // fprintf(stderr, "failed to process audio\n");
@@ -84,9 +85,20 @@ extern "C" {
84
  float vlen; // voice length of the token
85
  } whisper_token_data;
86
 
87
- // Allocates all memory needed for the model and loads the model from the given file.
88
- // Returns NULL on failure.
89
- WHISPER_API struct whisper_context * whisper_init(const char * path_model);
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  // Frees all memory allocated by the model.
92
  WHISPER_API void whisper_free(struct whisper_context * ctx);
 
1
  #ifndef WHISPER_H
2
  #define WHISPER_H
3
 
4
+ #include <stddef.h>
5
  #include <stdint.h>
6
  #include <stdbool.h>
7
 
 
41
  //
42
  // ...
43
  //
44
+ // struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin");
45
  //
46
  // if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
47
  // fprintf(stderr, "failed to process audio\n");
 
85
  float vlen; // voice length of the token
86
  } whisper_token_data;
87
 
88
+ typedef struct whisper_model_loader {
89
+ void * context;
90
+
91
+ size_t (*read)(void * ctx, void * output, size_t read_size);
92
+ bool (*eof)(void * ctx);
93
+ void (*close)(void * ctx);
94
+ } whisper_model_loader;
95
+
96
+ // Various function to load a ggml whisper model.
97
+ // Allocates (almost) all memory needed for the model.
98
+ // Return NULL on failure
99
+ WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
100
+ WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
101
+ WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
102
 
103
  // Frees all memory allocated by the model.
104
  WHISPER_API void whisper_free(struct whisper_context * ctx);