ggerganov committed on
Commit c0943fb · 1 Parent(s): b3f1468

Initial C-style interface for whisper.cpp

Files changed (5)
  1. Makefile +6 -3
  2. main.cpp +130 -2222
  3. stream.cpp +132 -2215
  4. whisper.cpp +2221 -0
  5. whisper.h +133 -0
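
To make the intent of the refactor concrete, here is a minimal sketch of how the relocated routines can be driven once they live behind whisper.h. The signatures are taken from the code visible in the diff below (whisper_model_load, whisper_encode, whisper_decode); the contents of the new whisper.h itself (+133 lines) are not part of this excerpt, so treat this as an illustration under those assumptions rather than the committed API.

// Hypothetical driver, based only on the signatures visible in the deleted main.cpp below.
#include "whisper.h" // new header introduced by this commit

#include <vector>

int main() {
    whisper_model model;
    whisper_vocab vocab;

    // load hparams, mel filters, vocab and weights from a ggml model file
    if (!whisper_model_load("models/ggml-base.en.bin", model, vocab)) {
        return 1;
    }

    // encode a log mel spectrogram (the whisper_mel must be filled by the caller)
    whisper_mel mel;
    std::vector<float> features;
    whisper_encode(model, 4 /* n_threads */, 0 /* mel_offset */, mel, features);

    // decode one step: predict token probabilities given the prompt so far
    std::vector<whisper_vocab::id> prompt = { vocab.token_sot };
    std::vector<float> logits, probs;
    whisper_decode(model, 4 /* n_threads */, 0 /* n_past */, prompt, logits, probs);

    return 0;
}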
Makefile CHANGED
@@ -1,17 +1,20 @@
  CC_SDL=`sdl2-config --cflags --libs`

- main: ggml.o main.o
- g++ -pthread -o main ggml.o main.o
+ main: ggml.o whisper.o main.o
+ g++ -pthread -o main ggml.o whisper.o main.o
  ./main -h

  ggml.o: ggml.c ggml.h
  gcc -pthread -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c

+ whisper.o: whisper.cpp whisper.h
+ gcc -pthread -O3 -std=c++11 -c whisper.cpp
+
  main.o: main.cpp ggml.h
  g++ -pthread -O3 -std=c++11 -c main.cpp

  stream: stream.cpp
- g++ -pthread -O3 -std=c++11 -o stream stream.cpp ggml.o $(CC_SDL)
+ g++ -pthread -O3 -std=c++11 -o stream stream.cpp ggml.o whisper.o $(CC_SDL)

  # clean up the directory
  clean:
main.cpp CHANGED
@@ -1,2123 +1,117 @@
- #include "ggml.h"

- #define USE_FLASH_ATTN
- #define USE_FLASH_FF
-
- // third-party utilities
- // use your favorite implementations
- #define DR_WAV_IMPLEMENTATION
- #include "dr_wav.h"
-
- #include <algorithm>
- #include <cassert>
- #include <cmath>
- #include <cstdio>
- #include <cstring>
- #include <fstream>
- #include <map>
- #include <string>
- #include <thread>
- #include <vector>
-
- // available whisper models
- enum e_model {
- MODEL_UNKNOWN,
- MODEL_TINY,
- MODEL_BASE,
- MODEL_SMALL,
- MODEL_MEDIUM,
- MODEL_LARGE,
- };
-
- const std::map<std::string, std::pair<int, std::string>> g_lang = {
- { "en", { 0, "english", } },
- { "zh", { 1, "chinese", } },
- { "de", { 2, "german", } },
- { "es", { 3, "spanish", } },
- { "ru", { 4, "russian", } },
- { "ko", { 5, "korean", } },
- { "fr", { 6, "french", } },
- { "ja", { 7, "japanese", } },
- { "pt", { 8, "portuguese", } },
- { "tr", { 9, "turkish", } },
- { "pl", { 10, "polish", } },
- { "ca", { 11, "catalan", } },
- { "nl", { 12, "dutch", } },
- { "ar", { 13, "arabic", } },
- { "sv", { 14, "swedish", } },
- { "it", { 15, "italian", } },
- { "id", { 16, "indonesian", } },
- { "hi", { 17, "hindi", } },
- { "fi", { 18, "finnish", } },
- { "vi", { 19, "vietnamese", } },
- { "iw", { 20, "hebrew", } },
- { "uk", { 21, "ukrainian", } },
- { "el", { 22, "greek", } },
- { "ms", { 23, "malay", } },
- { "cs", { 24, "czech", } },
- { "ro", { 25, "romanian", } },
- { "da", { 26, "danish", } },
- { "hu", { 27, "hungarian", } },
- { "ta", { 28, "tamil", } },
- { "no", { 29, "norwegian", } },
- { "th", { 30, "thai", } },
- { "ur", { 31, "urdu", } },
- { "hr", { 32, "croatian", } },
- { "bg", { 33, "bulgarian", } },
- { "lt", { 34, "lithuanian", } },
- { "la", { 35, "latin", } },
- { "mi", { 36, "maori", } },
- { "ml", { 37, "malayalam", } },
- { "cy", { 38, "welsh", } },
- { "sk", { 39, "slovak", } },
- { "te", { 40, "telugu", } },
- { "fa", { 41, "persian", } },
- { "lv", { 42, "latvian", } },
- { "bn", { 43, "bengali", } },
- { "sr", { 44, "serbian", } },
- { "az", { 45, "azerbaijani", } },
- { "sl", { 46, "slovenian", } },
- { "kn", { 47, "kannada", } },
- { "et", { 48, "estonian", } },
- { "mk", { 49, "macedonian", } },
- { "br", { 50, "breton", } },
- { "eu", { 51, "basque", } },
- { "is", { 52, "icelandic", } },
- { "hy", { 53, "armenian", } },
- { "ne", { 54, "nepali", } },
- { "mn", { 55, "mongolian", } },
- { "bs", { 56, "bosnian", } },
- { "kk", { 57, "kazakh", } },
- { "sq", { 58, "albanian", } },
- { "sw", { 59, "swahili", } },
- { "gl", { 60, "galician", } },
- { "mr", { 61, "marathi", } },
- { "pa", { 62, "punjabi", } },
- { "si", { 63, "sinhala", } },
- { "km", { 64, "khmer", } },
- { "sn", { 65, "shona", } },
- { "yo", { 66, "yoruba", } },
- { "so", { 67, "somali", } },
- { "af", { 68, "afrikaans", } },
- { "oc", { 69, "occitan", } },
- { "ka", { 70, "georgian", } },
- { "be", { 71, "belarusian", } },
- { "tg", { 72, "tajik", } },
- { "sd", { 73, "sindhi", } },
- { "gu", { 74, "gujarati", } },
- { "am", { 75, "amharic", } },
- { "yi", { 76, "yiddish", } },
- { "lo", { 77, "lao", } },
- { "uz", { 78, "uzbek", } },
- { "fo", { 79, "faroese", } },
- { "ht", { 80, "haitian creole", } },
- { "ps", { 81, "pashto", } },
- { "tk", { 82, "turkmen", } },
- { "nn", { 83, "nynorsk", } },
- { "mt", { 84, "maltese", } },
- { "sa", { 85, "sanskrit", } },
- { "lb", { 86, "luxembourgish", } },
- { "my", { 87, "myanmar", } },
- { "bo", { 88, "tibetan", } },
- { "tl", { 89, "tagalog", } },
- { "mg", { 90, "malagasy", } },
- { "as", { 91, "assamese", } },
- { "tt", { 92, "tatar", } },
- { "haw", { 93, "hawaiian", } },
- { "ln", { 94, "lingala", } },
- { "ha", { 95, "hausa", } },
- { "ba", { 96, "bashkir", } },
- { "jw", { 97, "javanese", } },
- { "su", { 98, "sundanese", } },
- };
-
- const size_t MB = 1024*1024;
-
- const std::map<e_model, size_t> MEM_REQ_MODEL = {
- { MODEL_TINY, 86ull*MB },
- { MODEL_BASE, 165ull*MB },
- { MODEL_SMALL, 540ull*MB },
- { MODEL_MEDIUM, 1650ull*MB },
- { MODEL_LARGE, 3260ull*MB },
- };
-
- const std::map<e_model, size_t> MEM_REQ_ENCODE = {
- { MODEL_TINY, 80ull*MB },
- { MODEL_BASE, 128ull*MB },
- { MODEL_SMALL, 300ull*MB },
- { MODEL_MEDIUM, 680ull*MB },
- { MODEL_LARGE, 1100ull*MB },
- };
-
- const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
- { MODEL_TINY, 64ull*MB },
- { MODEL_BASE, 84ull*MB },
- { MODEL_SMALL, 128ull*MB },
- { MODEL_MEDIUM, 172ull*MB },
- { MODEL_LARGE, 216ull*MB },
- };
-
- const std::map<e_model, size_t> MEM_REQ_DECODE = {
- { MODEL_TINY, 94ull*MB },
- { MODEL_BASE, 96ull*MB },
- { MODEL_SMALL, 98ull*MB },
- { MODEL_MEDIUM, 100ull*MB },
- { MODEL_LARGE, 102ull*MB },
- };
-
- const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
- { MODEL_TINY, 32ull*MB },
- { MODEL_BASE, 44ull*MB },
- { MODEL_SMALL, 64ull*MB },
- { MODEL_MEDIUM, 84ull*MB },
- { MODEL_LARGE, 110ull*MB },
- };
-
- // the memory buffers used to store the model in memory and perform the inference computations
- std::vector<uint8_t> g_buf_model;
- std::vector<uint8_t> g_buf_compute;
- std::vector<uint8_t> g_buf_compute_layer;
-
- const int SAMPLE_RATE = 16000;
- const int N_FFT = 400;
- const int N_MEL = 80;
- const int HOP_LENGTH = 160;
- const int CHUNK_SIZE = 30; // seconds
-
- struct whisper_mel {
- int n_len;
- int n_mel;
-
- std::vector<float> data;
- };
-
- struct whisper_filters {
- int32_t n_mel;
- int32_t n_fft;
-
- std::vector<float> data;
- };
-
- struct whisper_vocab {
- using id = int32_t;
- using token = std::string;
-
- int n_vocab = 51864;
-
- std::map<token, id> token_to_id;
- std::map<id, token> id_to_token;
-
- id token_eot = 50256;
- id token_sot = 50257;
- id token_prev = 50360;
- id token_solm = 50361; // ??
- id token_not = 50362; // no timestamps
- id token_beg = 50363;
-
- // available tasks
- const id token_translate = 50358;
- const id token_transcribe = 50359;
-
- bool is_multilingual() const {
- return n_vocab == 51865;
- }
- };
-
- struct whisper_result {
- whisper_vocab::id id;
- int64_t t;
- };
-
- // command-line parameters
- struct whisper_params {
- int32_t seed = -1; // RNG seed, not used currently
- int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
-
- bool verbose = false;
- bool translate = false;
- bool print_special_tokens = false;
- bool no_timestamps = false;
-
- std::string language = "en";
- std::string model = "models/ggml-base.en.bin";
- std::string fname_inp = "samples/jfk.wav";
- };
-
- void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
-
- bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
- for (int i = 1; i < argc; i++) {
- std::string arg = argv[i];
-
- if (arg == "-s" || arg == "--seed") {
- params.seed = std::stoi(argv[++i]);
- } else if (arg == "-t" || arg == "--threads") {
- params.n_threads = std::stoi(argv[++i]);
- } else if (arg == "-v" || arg == "--verbose") {
- params.verbose = true;
- } else if (arg == "--translate") {
- params.translate = true;
- } else if (arg == "-l" || arg == "--language") {
- params.language = argv[++i];
- if (g_lang.find(params.language) == g_lang.end()) {
- fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
- whisper_print_usage(argc, argv, params);
- exit(0);
- }
- } else if (arg == "-ps" || arg == "--print_special") {
- params.print_special_tokens = true;
- } else if (arg == "-nt" || arg == "--no_timestamps") {
- params.no_timestamps = true;
- } else if (arg == "-m" || arg == "--model") {
- params.model = argv[++i];
- } else if (arg == "-f" || arg == "--file") {
- params.fname_inp = argv[++i];
- } else if (arg == "-h" || arg == "--help") {
- whisper_print_usage(argc, argv, params);
- exit(0);
- } else {
- fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
- whisper_print_usage(argc, argv, params);
- exit(0);
- }
- }
-
- return true;
- }
-
- void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
- fprintf(stderr, "\n");
- fprintf(stderr, "usage: %s [options]\n", argv[0]);
- fprintf(stderr, "\n");
- fprintf(stderr, "options:\n");
- fprintf(stderr, " -h, --help show this help message and exit\n");
- fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
- fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
- fprintf(stderr, " -v, --verbose verbose output\n");
- fprintf(stderr, " --translate translate from source language to english\n");
- fprintf(stderr, " -ps, --print_special print special tokens\n");
- fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
- fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
- fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
- fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str());
- fprintf(stderr, "\n");
- }
-
-
- // medium
- // hparams: {
- // 'n_mels': 80,
- // 'n_vocab': 51864,
- // 'n_audio_ctx': 1500,
- // 'n_audio_state': 1024,
- // 'n_audio_head': 16,
- // 'n_audio_layer': 24,
- // 'n_text_ctx': 448,
- // 'n_text_state': 1024,
- // 'n_text_head': 16,
- // 'n_text_layer': 24
- // }
- //
- // default hparams (Whisper tiny)
- struct whisper_hparams {
- int32_t n_vocab = 51864;
- int32_t n_audio_ctx = 1500;
- int32_t n_audio_state = 384;
- int32_t n_audio_head = 6;
- int32_t n_audio_layer = 4;
- int32_t n_text_ctx = 448;
- int32_t n_text_state = 384;
- int32_t n_text_head = 6;
- int32_t n_text_layer = 4;
- int32_t n_mels = 80;
- int32_t f16 = 1;
- };
-
- // audio encoding layer
- struct whisper_layer_encoder {
- // encoder.blocks.*.attn_ln
- struct ggml_tensor * attn_ln_0_w;
- struct ggml_tensor * attn_ln_0_b;
-
- // encoder.blocks.*.attn.out
- struct ggml_tensor * attn_ln_1_w;
- struct ggml_tensor * attn_ln_1_b;
-
- // encoder.blocks.*.attn.query
- struct ggml_tensor * attn_q_w;
- struct ggml_tensor * attn_q_b;
-
- // encoder.blocks.*.attn.key
- struct ggml_tensor * attn_k_w;
-
- // encoder.blocks.*.attn.value
- struct ggml_tensor * attn_v_w;
- struct ggml_tensor * attn_v_b;
-
- // encoder.blocks.*.mlp_ln
- struct ggml_tensor * mlp_ln_w;
- struct ggml_tensor * mlp_ln_b;
-
- // encoder.blocks.*.mlp.0
- struct ggml_tensor * mlp_0_w;
- struct ggml_tensor * mlp_0_b;
-
- // encoder.blocks.*.mlp.2
- struct ggml_tensor * mlp_1_w;
- struct ggml_tensor * mlp_1_b;
- };
-
- // token decoding layer
- struct whisper_layer_decoder {
- // decoder.blocks.*.attn_ln
- struct ggml_tensor * attn_ln_0_w;
- struct ggml_tensor * attn_ln_0_b;
-
- // decoder.blocks.*.attn.out
- struct ggml_tensor * attn_ln_1_w;
- struct ggml_tensor * attn_ln_1_b;
-
- // decoder.blocks.*.attn.query
- struct ggml_tensor * attn_q_w;
- struct ggml_tensor * attn_q_b;
-
- // decoder.blocks.*.attn.key
- struct ggml_tensor * attn_k_w;
-
- // decoder.blocks.*.attn.value
- struct ggml_tensor * attn_v_w;
- struct ggml_tensor * attn_v_b;
-
- // decoder.blocks.*.cross_attn_ln
- struct ggml_tensor * cross_attn_ln_0_w;
- struct ggml_tensor * cross_attn_ln_0_b;
-
- // decoder.blocks.*.cross_attn.out
- struct ggml_tensor * cross_attn_ln_1_w;
- struct ggml_tensor * cross_attn_ln_1_b;
-
- // decoder.blocks.*.cross_attn.query
- struct ggml_tensor * cross_attn_q_w;
- struct ggml_tensor * cross_attn_q_b;
-
- // decoder.blocks.*.cross_attn.key
- struct ggml_tensor * cross_attn_k_w;
-
- // decoder.blocks.*.cross_attn.value
- struct ggml_tensor * cross_attn_v_w;
- struct ggml_tensor * cross_attn_v_b;
-
- // decoder.blocks.*.mlp_ln
- struct ggml_tensor * mlp_ln_w;
- struct ggml_tensor * mlp_ln_b;
-
- // decoder.blocks.*.mlp.0
- struct ggml_tensor * mlp_0_w;
- struct ggml_tensor * mlp_0_b;
-
- // decoder.blocks.*.mlp.2
- struct ggml_tensor * mlp_1_w;
- struct ggml_tensor * mlp_1_b;
- };
-
- struct whisper_model {
- e_model type = MODEL_UNKNOWN;
-
- whisper_hparams hparams;
- whisper_filters filters;
-
- // encoder.positional_embedding
- struct ggml_tensor * e_pe;
-
- // encoder.conv1
- struct ggml_tensor * e_conv_1_w;
- struct ggml_tensor * e_conv_1_b;
-
- // encoder.conv2
- struct ggml_tensor * e_conv_2_w;
- struct ggml_tensor * e_conv_2_b;
-
- // encoder.ln_post
- struct ggml_tensor * e_ln_w;
- struct ggml_tensor * e_ln_b;
-
- // decoder.positional_embedding
- struct ggml_tensor * d_pe; // DD
-
- // decoder.token_embedding
- struct ggml_tensor * d_te; // DD
-
- // decoder.ln
- struct ggml_tensor * d_ln_w; // DD
- struct ggml_tensor * d_ln_b; // DD
-
- std::vector<whisper_layer_encoder> layers_encoder;
- std::vector<whisper_layer_decoder> layers_decoder;
-
- // key + value memory
- struct ggml_tensor * memory_k;
- struct ggml_tensor * memory_v;
-
- struct ggml_tensor * memory_cross_k;
- struct ggml_tensor * memory_cross_v;
-
- //
- struct ggml_context * ctx;
- std::map<std::string, struct ggml_tensor *> tensors;
- };
-
- // load the model from a ggml file
- //
- // file format:
- //
- // - hparams
- // - pre-computed mel filters
- // - vocab
- // - weights
- //
- // see the convert-pt-to-ggml.py script for details
- //
- bool whisper_model_load(const std::string & fname, whisper_model & model, whisper_vocab & vocab) {
- printf("%s: loading model from '%s'\n", __func__, fname.c_str());
-
- auto fin = std::ifstream(fname, std::ios::binary);
- if (!fin) {
- fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
- return false;
- }
-
- // verify magic
- {
- uint32_t magic;
- fin.read((char *) &magic, sizeof(magic));
- if (magic != 0x67676d6c) {
- fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
- return false;
- }
- }
-
- //load hparams
- {
- auto & hparams = model.hparams;
-
- fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
- fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
- fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
- fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
- fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
- fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
- fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
- fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
- fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
- fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
- fin.read((char *) &hparams.f16, sizeof(hparams.f16));
-
- assert(hparams.n_text_state == hparams.n_audio_state);
-
- if (hparams.n_audio_layer == 4) {
- model.type = e_model::MODEL_TINY;
- }
-
- if (hparams.n_audio_layer == 6) {
- model.type = e_model::MODEL_BASE;
- }
-
- if (hparams.n_audio_layer == 12) {
- model.type = e_model::MODEL_SMALL;
- }
-
- if (hparams.n_audio_layer == 24) {
- model.type = e_model::MODEL_MEDIUM;
- }
-
- if (hparams.n_audio_layer == 32) {
- model.type = e_model::MODEL_LARGE;
- }
-
- printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
- printf("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
- printf("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
- printf("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
- printf("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
- printf("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
- printf("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
- printf("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
- printf("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
- printf("%s: n_mels = %d\n", __func__, hparams.n_mels);
- printf("%s: f16 = %d\n", __func__, hparams.f16);
- printf("%s: type = %d\n", __func__, model.type);
-
- g_buf_model.resize(MEM_REQ_MODEL.at(model.type));
- g_buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
- g_buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
-
- // this is the total memory required to run the inference
- const size_t mem_required =
- g_buf_model.size() +
- g_buf_compute.size() +
- g_buf_compute_layer.size();
-
- printf("%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
- }
-
- // load mel filters
- {
- auto & filters = model.filters;
-
- fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
- fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
-
- filters.data.resize(filters.n_mel * filters.n_fft);
- fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
- }
-
- // load vocab
- {
- int32_t n_vocab = 0;
- fin.read((char *) &n_vocab, sizeof(n_vocab));
-
- //if (n_vocab != model.hparams.n_vocab) {
- // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
- // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
- // return false;
- //}
-
- std::string word;
- for (int i = 0; i < n_vocab; i++) {
- uint32_t len;
- fin.read((char *) &len, sizeof(len));
-
- word.resize(len);
- fin.read((char *) word.data(), len);
-
- vocab.token_to_id[word] = i;
- vocab.id_to_token[i] = word;
-
- //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
- }
-
- vocab.n_vocab = model.hparams.n_vocab;
- if (vocab.is_multilingual()) {
- vocab.token_eot++;
- vocab.token_sot++;
- vocab.token_prev++;
- vocab.token_solm++;
- vocab.token_not++;
- vocab.token_beg++;
- }
-
- if (n_vocab < model.hparams.n_vocab) {
- printf("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
- for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
- if (i > vocab.token_beg) {
- word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
- } else if (i == vocab.token_eot) {
- word = "[_EOT_]";
- } else if (i == vocab.token_sot) {
- word = "[_SOT_]";
- } else if (i == vocab.token_prev) {
- word = "[_PREV_]";
- } else if (i == vocab.token_not) {
- word = "[_NOT_]";
- } else if (i == vocab.token_beg) {
- word = "[_BEG_]";
- } else {
- word = "[_extra_token_" + std::to_string(i) + "]";
- }
- vocab.token_to_id[word] = i;
- vocab.id_to_token[i] = word;
- }
- }
- }
-
- // for the big tensors, we have the option to store the data in 16-bit floats
- // in order to save memory and also to speed up the computation
- const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
- auto & ctx = model.ctx;
-
- size_t ctx_size = 0;
-
- {
- const auto & hparams = model.hparams;
-
- const int n_vocab = hparams.n_vocab;
-
- const int n_audio_ctx = hparams.n_audio_ctx;
- const int n_audio_state = hparams.n_audio_state;
- const int n_audio_layer = hparams.n_audio_layer;
-
- const int n_text_ctx = hparams.n_text_ctx;
- const int n_text_state = hparams.n_text_state;
- const int n_text_layer = hparams.n_text_layer;
-
- const int n_mels = hparams.n_mels;
-
- // encoder
- {
- // TODO: F16 .. maybe not?
- ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
-
- ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
-
- ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
-
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
- }
-
- // decoder
- {
- // TODO: F16 .. maybe not?
- ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
-
- ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
-
- ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
- ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
- }
-
- // encoder layers
- {
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
-
- ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
- ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
-
- ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
-
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
-
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
-
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
-
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
-
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
- }
-
- // decoder layers
- {
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
-
- ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
- ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
-
- ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
-
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
-
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
-
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
-
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
-
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
- //
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
-
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
-
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
-
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
-
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
- }
-
- ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
- ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
-
- ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
- ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
-
- ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
-
- printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
- }
-
- // create the ggml context
- {
- struct ggml_init_params params = {
- .mem_size = g_buf_model.size(),
- .mem_buffer = g_buf_model.data(),
- };
-
- model.ctx = ggml_init(params);
- if (!model.ctx) {
- fprintf(stderr, "%s: ggml_init() failed\n", __func__);
- return false;
- }
- }
-
- // prepare memory for the weights
- {
- const auto & hparams = model.hparams;
-
- const int n_vocab = hparams.n_vocab;
-
- const int n_audio_ctx = hparams.n_audio_ctx;
- const int n_audio_state = hparams.n_audio_state;
- const int n_audio_layer = hparams.n_audio_layer;
-
- const int n_text_ctx = hparams.n_text_ctx;
- const int n_text_state = hparams.n_text_state;
- const int n_text_layer = hparams.n_text_layer;
-
- const int n_mels = hparams.n_mels;
-
- model.layers_encoder.resize(n_audio_layer);
- model.layers_decoder.resize(n_text_layer);
-
- // encoder
- {
- model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
-
- model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
- model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
-
- model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
- model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
-
- model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
- model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
- // map by name
- model.tensors["encoder.positional_embedding"] = model.e_pe;
-
- model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
- model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
-
- model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
- model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
-
- model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
- model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
-
- for (int i = 0; i < n_audio_layer; ++i) {
- auto & layer = model.layers_encoder[i];
-
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
-
- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-
- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
- // map by name
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
-
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
-
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
-
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
-
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
-
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
-
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
-
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
- }
- }
-
- // decoder
- {
- model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
-
- model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
-
- model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
- model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- // map by name
- model.tensors["decoder.positional_embedding"] = model.d_pe;
-
- model.tensors["decoder.token_embedding.weight"] = model.d_te;
-
- model.tensors["decoder.ln.weight"] = model.d_ln_w;
- model.tensors["decoder.ln.bias"] = model.d_ln_b;
-
- for (int i = 0; i < n_text_layer; ++i) {
- auto & layer = model.layers_decoder[i];
-
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
-
- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-
- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
- layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-
- layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- // map by name
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
- }
- }
- }
-
- // key + value memory
- {
- const auto & hparams = model.hparams;
-
- const int n_text_state = hparams.n_text_state;
- const int n_text_layer = hparams.n_text_layer;
- const int n_text_ctx = hparams.n_text_ctx;
-
- // key/value memory for the self-attention layer
- {
- const int n_mem = n_text_layer*n_text_ctx;
- const int n_elements = n_text_state*n_mem;
-
- model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- }
-
- // key/value memory for the cross-attention layer
- {
- const int n_audio_ctx = hparams.n_audio_ctx;
-
- const int n_mem = n_text_layer*n_audio_ctx;
- const int n_elements = n_text_state*n_mem;
-
- model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- }
-
- const size_t memory_size =
- ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
- ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
-
- printf("%s: memory size = %8.2f MB \n", __func__, memory_size/1024.0/1024.0);
- }
-
- // load weights
- {
- size_t total_size = 0;
-
- while (true) {
- int32_t n_dims;
- int32_t length;
- int32_t ftype;
-
- fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
- fin.read(reinterpret_cast<char *>(&length), sizeof(length));
- fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
-
- if (fin.eof()) {
- break;
- }
-
- int32_t nelements = 1;
- int32_t ne[3] = { 1, 1, 1 };
- for (int i = 0; i < n_dims; ++i) {
- fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
- nelements *= ne[i];
- }
-
- std::string name(length, 0);
- fin.read(&name[0], length);
-
- if (model.tensors.find(name.data()) == model.tensors.end()) {
- fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
- return false;
- }
-
- auto tensor = model.tensors[name.data()];
- if (ggml_nelements(tensor) != nelements) {
- fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
- return false;
- }
-
- if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
- fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
- __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
- return false;
- }
-
- const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
-
- if (nelements*bpe != ggml_nbytes(tensor)) {
- fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
- __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
- return false;
- }
-
- fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
-
- //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
- total_size += ggml_nbytes(tensor);
- }
-
- printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
- }
-
- fin.close();
-
- return true;
- }
-
- // evaluate the encoder
- //
- // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
- // part of the transformer model and returns the encoded features
- //
- // - model: the model
- // - n_threads: number of threads to use
- // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
- // - mel_inp: input mel spectrogram
- // - features: output encoded features
- //
- bool whisper_encode(
- const whisper_model & model,
- const int n_threads,
- const int mel_offset,
- const whisper_mel & mel_inp,
- std::vector<float> & features) {
- const auto & hparams = model.hparams;
-
- const int n_vocab = hparams.n_vocab;
-
- const int n_ctx = hparams.n_audio_ctx;
- const int n_state = hparams.n_audio_state;
- const int n_head = hparams.n_audio_head;
- const int n_layer = hparams.n_audio_layer;
-
- const int N = n_ctx;
-
- const int n_mels = hparams.n_mels;
- assert(mel_inp.n_mel == n_mels);
-
- struct ggml_init_params params = {
- .mem_size = g_buf_compute.size(),
- .mem_buffer = g_buf_compute.data(),
- };
-
- struct ggml_context * ctx0 = ggml_init(params);
-
- struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
- assert(mel->type == GGML_TYPE_F32);
- {
- float * dst = (float *) mel->data;
- memset(dst, 0, ggml_nbytes(mel));
-
- const int i0 = std::min(mel_offset, mel_inp.n_len);
- const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
-
- for (int j = 0; j < mel_inp.n_mel; ++j) {
- for (int i = i0; i < i1; ++i) {
- dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
- }
- }
- }
-
- struct ggml_tensor * cur;
-
- // convolution + gelu
- {
- cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0,
- model.e_conv_1_b,
- cur),
- cur);
-
- cur = ggml_gelu(ctx0, cur);
-
- cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0,
- model.e_conv_2_b,
- cur),
- cur);
-
- cur = ggml_gelu(ctx0, cur);
- }
-
- cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
-
- struct ggml_tensor * inpL = cur;
-
- for (int il = 0; il < n_layer; ++il) {
- const auto & layer = model.layers_encoder[il];
-
- // create separate context for each layer to reduce memory usage
-
- struct ggml_init_params paramsL = {
- .mem_size = g_buf_compute_layer.size(),
- .mem_buffer = g_buf_compute_layer.data(),
- };
-
- struct ggml_context * ctxL = ggml_init(paramsL);
-
- // norm
- {
- cur = ggml_norm(ctxL, inpL);
-
- // cur = ln_0_w*cur + ln_0_b
- cur = ggml_add(ctxL,
- ggml_mul(ctxL,
- ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
- cur),
- ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
- }
-
- // self-attention
- {
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
- layer.attn_q_w,
- cur);
-
- Qcur = ggml_add(ctxL,
- ggml_repeat(ctxL,
- layer.attn_q_b,
- Qcur),
- Qcur);
-
- //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
-
- // note: no bias for Key
- struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
- layer.attn_k_w,
- cur);
-
- //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
-
- struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
- layer.attn_v_w,
- cur);
-
- Vcur = ggml_add(ctxL,
- ggml_repeat(ctxL,
- layer.attn_v_b,
- Vcur),
- Vcur);
-
- // ------
-
- #ifdef USE_FLASH_ATTN
- struct ggml_tensor * Q =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
- Qcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * K =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
- Kcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * V =
- ggml_cpy(ctxL,
- ggml_permute(ctxL,
- ggml_reshape_3d(ctxL,
- Vcur,
- n_state/n_head, n_head, N),
- 1, 2, 0, 3),
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
- );
-
- struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
- #else
- struct ggml_tensor * Q =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
- Qcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * K =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
- Kcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
- 0, 2, 1, 3);
-
- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
-
- struct ggml_tensor * KQ_scaled =
- ggml_scale(ctxL,
- KQ,
- ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
- );
-
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);
-
- //struct ggml_tensor * V_trans =
- // ggml_permute(ctxL,
- // ggml_cpy(ctxL,
- // Vcur,
- // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
- // 1, 2, 0, 3);
-
- //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
-
- struct ggml_tensor * V =
- ggml_cpy(ctxL,
- ggml_permute(ctxL,
- ggml_reshape_3d(ctxL,
- Vcur,
- n_state/n_head, n_head, N),
- 0, 2, 1, 3),
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
- );
-
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
- #endif
-
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
-
- cur = ggml_cpy(ctxL,
- KQV_merged,
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
- }
-
- // projection
- {
- cur = ggml_mul_mat(ctxL,
- layer.attn_ln_1_w,
- cur);
-
- cur = ggml_add(ctxL,
- ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
- cur);
- }
-
- // add the input
- cur = ggml_add(ctxL, cur, inpL);
-
- struct ggml_tensor * inpFF = cur;
-
- // feed-forward network
- {
- // norm
- {
- cur = ggml_norm(ctxL, inpFF);
-
- // cur = mlp_ln_w*cur + mlp_ln_b
- cur = ggml_add(ctxL,
- ggml_mul(ctxL,
- ggml_repeat(ctxL, layer.mlp_ln_w, cur),
- cur),
- ggml_repeat(ctxL, layer.mlp_ln_b, cur));
- }
-
- #ifdef USE_FLASH_FF
- cur = ggml_flash_ff(ctxL,
- ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
- #else
- // fully connected
- cur = ggml_mul_mat(ctxL,
- layer.mlp_0_w,
- cur);
-
- cur = ggml_add(ctxL,
- ggml_repeat(ctxL, layer.mlp_0_b, cur),
- cur);
-
- // GELU activation
- cur = ggml_gelu(ctxL, cur);
-
- // projection
- cur = ggml_mul_mat(ctxL,
- layer.mlp_1_w,
- cur);
-
- cur = ggml_add(ctxL,
- ggml_repeat(ctxL, layer.mlp_1_b, cur),
- cur);
- #endif
- }
-
- // output from this layer
- struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
-
- {
- struct ggml_cgraph gf = { .n_threads = n_threads };
-
- ggml_build_forward_expand(&gf, inpO);
- ggml_graph_compute (ctxL, &gf);
-
- //ggml_graph_print(&gf);
- }
-
- // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
- // input for next layer (inpO -> inpL)
- memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
- inpL->op = GGML_OP_NONE;
- inpL->src0 = NULL;
- inpL->src1 = NULL;
-
- //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
-
- ggml_free(ctxL);
- }
-
- cur = inpL;
-
- // norm
- {
- cur = ggml_norm(ctx0, cur);
-
- // cur = ln_f_g*cur + ln_f_b
- cur = ggml_add(ctx0,
- ggml_mul(ctx0,
- ggml_repeat(ctx0, model.e_ln_w, cur),
- cur),
- ggml_repeat(ctx0, model.e_ln_b, cur));
- }
-
- // run the computation
- {
- struct ggml_cgraph gf = { .n_threads = n_threads };
-
- ggml_build_forward_expand(&gf, cur);
- ggml_graph_compute (ctx0, &gf);
-
- //ggml_graph_print(&gf);
- }
-
- // cur
- //{
- // printf("ne0 = %d\n", cur->ne[0]);
- // printf("ne1 = %d\n", cur->ne[1]);
- // for (int i = 0; i < 10; ++i) {
- // printf("%8.4f ", ((float *)(cur->data))[i]);
- // }
- // printf("... ");
- // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
- // printf("%8.4f ", ((float *)(cur->data))[i]);
- // }
- // printf("\n");
- //}
-
- // pre-compute cross-attention memory
- {
- struct ggml_cgraph gf = { .n_threads = n_threads };
-
- // TODO: hack to disconnect the encoded features from the previous graph
- cur->op = GGML_OP_NONE;
- cur->src0 = NULL;
- cur->src1 = NULL;
-
- for (int il = 0; il < model.hparams.n_text_layer; ++il) {
- auto & layer = model.layers_decoder[il];
-
- struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
- layer.cross_attn_k_w,
- cur);
-
- Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
-
- struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
- layer.cross_attn_v_w,
- cur);
-
- Vcross = ggml_add(ctx0,
- ggml_repeat(ctx0,
- layer.cross_attn_v_b,
- Vcross),
- Vcross);
-
- struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
- struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
-
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
- }
-
- ggml_graph_compute(ctx0, &gf);
- }
-
- ////////////////////////////////////////////////////////////////////////////
-
- // output the features
- assert(cur->type == GGML_TYPE_F32);
- features.resize(cur->ne[0]*cur->ne[1]);
- memcpy(features.data(), cur->data, features.size()*sizeof(float));
-
- //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0);
-
- ggml_free(ctx0);
-
- return true;
- }
-
- // evaluate the decoder
- //
- // given text prompt + audio features -> predicts the probabilities for the next token
- //
- // - model: the model
- // - n_threads: number of threads to use
- // - n_past: prompt length
- // - prompt: text prompt
- // - logits_out: output logits
- // - probs_out: output probabilities
- //
- bool whisper_decode(
- const whisper_model & model,
- const int n_threads,
- const int n_past,
- const std::vector<whisper_vocab::id> & prompt,
- std::vector<float> & logits_out,
- std::vector<float> & probs_out) {
- const auto & hparams = model.hparams;
-
- const int n_vocab = hparams.n_vocab;
-
- const int n_ctx = hparams.n_text_ctx;
- const int n_state = hparams.n_text_state;
- const int n_head = hparams.n_text_head;
- const int n_layer = hparams.n_text_layer;
-
- const int N = prompt.size();
- const int M = hparams.n_audio_ctx;
-
- struct ggml_init_params params = {
- .mem_size = g_buf_compute.size(),
- .mem_buffer = g_buf_compute.data(),
- };
-
- struct ggml_context * ctx0 = ggml_init(params);
-
- struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
- memcpy(embd->data, prompt.data(), N*ggml_element_size(embd));
-
- struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
- for (int i = 0; i < N; ++i) {
- ((int32_t *) position->data)[i] = n_past + i;
- }
-
- // token encoding + position encoding
- struct ggml_tensor * cur =
- ggml_add(ctx0,
- ggml_get_rows(ctx0, model.d_te, embd),
- ggml_get_rows(ctx0, model.d_pe, position));
-
- struct ggml_tensor * inpL = cur;
-
- for (int il = 0; il < n_layer; ++il) {
- const auto & layer = model.layers_decoder[il];
-
- struct ggml_init_params paramsL = {
- .mem_size = g_buf_compute_layer.size(),
- .mem_buffer = g_buf_compute_layer.data(),
- };
-
- struct ggml_context * ctxL = ggml_init(paramsL);
- struct ggml_cgraph gf = { .n_threads = n_threads };
-
- // norm
- {
- cur = ggml_norm(ctxL, inpL);
-
- // cur = ln_0_w*cur + ln_0_b
- cur = ggml_add(ctxL,
- ggml_mul(ctxL,
- ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
- cur),
- ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
- }
-
- // self-attention
- {
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
- layer.attn_q_w,
- cur);
-
- Qcur = ggml_add(ctxL,
- ggml_repeat(ctxL,
- layer.attn_q_b,
- Qcur),
- Qcur);
-
- Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
-
- // note: no bias for Key
- struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
- layer.attn_k_w,
- cur);
-
- Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
-
- struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
- layer.attn_v_w,
- cur);
-
- Vcur = ggml_add(ctxL,
- ggml_repeat(ctxL,
- layer.attn_v_b,
- Vcur),
- Vcur);
-
- // store key and value to memory
- {
- struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
- struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
-
- ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
- ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
- }
-
- // ------
-
- struct ggml_tensor * Q =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
- Qcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * K =
- ggml_permute(ctxL,
- ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
- n_state/n_head, n_head, n_past + N),
- 0, 2, 1, 3);
-
- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
-
- //struct ggml_tensor * KQ_scaled =
- // ggml_scale(ctxL,
- // KQ,
- // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
- // );
-
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
-
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
-
- struct ggml_tensor * V_trans =
- ggml_permute(ctxL,
- ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
- n_state/n_head, n_head, n_past + N),
- 1, 2, 0, 3);
-
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
-
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
-
- cur = ggml_cpy(ctxL,
- KQV_merged,
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
- }
-
- {
- cur = ggml_mul_mat(ctxL,
- layer.attn_ln_1_w,
- cur);
-
- cur = ggml_add(ctxL,
- ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
- cur);
- }
-
- // add the input
- struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL);
-
- // norm
- {
- cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here
-
- // cur = ln_0_w*cur + ln_0_b
- cur = ggml_add(ctxL,
- ggml_mul(ctxL,
- ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur),
- cur),
- ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur));
- }
-
- // cross-attention
- {
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
- layer.cross_attn_q_w,
- cur);
-
- Qcur = ggml_add(ctxL,
- ggml_repeat(ctxL,
- layer.cross_attn_q_b,
- Qcur),
- Qcur);
-
- Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
-
- // Kcross is already scaled
- struct ggml_tensor * Kcross =
- ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
- n_state/n_head, n_head, M);
-
- struct ggml_tensor * Vcross =
- ggml_reshape_3d(ctxL,
- ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
- n_state/n_head, n_head, M);
-
- // ------
-
- struct ggml_tensor * Q =
- ggml_permute(ctxL,
- ggml_cpy(ctxL,
- Qcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
-
- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
-
- //struct ggml_tensor * KQ_scaled =
- // ggml_scale(ctxL,
- // KQ,
- // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1688
- // );
1689
-
1690
- // no masking for cross-attention
1691
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
1692
-
1693
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
1694
-
1695
- struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
1696
-
1697
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1698
-
1699
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1700
-
1701
- // cur = KQV_merged.contiguous().view(n_state, N)
1702
- cur = ggml_cpy(ctxL,
1703
- KQV_merged,
1704
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1705
- }
1706
-
1707
- // projection
1708
- {
1709
- cur = ggml_mul_mat(ctxL,
1710
- layer.cross_attn_ln_1_w,
1711
- cur);
1712
-
1713
- cur = ggml_add(ctxL,
1714
- ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur),
1715
- cur);
1716
- }
1717
-
1718
- // add the input
1719
- cur = ggml_add(ctxL, cur, inpCA);
1720
-
1721
- struct ggml_tensor * inpFF = cur;
1722
-
1723
- // feed-forward network
1724
- {
1725
- // norm
1726
- {
1727
- cur = ggml_norm(ctxL, inpFF);
1728
-
1729
- // cur = mlp_ln_w*cur + mlp_ln_b
1730
- cur = ggml_add(ctxL,
1731
- ggml_mul(ctxL,
1732
- ggml_repeat(ctxL, layer.mlp_ln_w, cur),
1733
- cur),
1734
- ggml_repeat(ctxL, layer.mlp_ln_b, cur));
1735
- }
1736
-
1737
- // fully connected
1738
- cur = ggml_mul_mat(ctxL,
1739
- layer.mlp_0_w,
1740
- cur);
1741
-
1742
- cur = ggml_add(ctxL,
1743
- ggml_repeat(ctxL, layer.mlp_0_b, cur),
1744
- cur);
1745
-
1746
- // GELU activation
1747
- cur = ggml_gelu(ctxL, cur);
1748
-
1749
- // projection
1750
- cur = ggml_mul_mat(ctxL,
1751
- layer.mlp_1_w,
1752
- cur);
1753
-
1754
- cur = ggml_add(ctxL,
1755
- ggml_repeat(ctxL, layer.mlp_1_b, cur),
1756
- cur);
1757
- }
1758
-
1759
- // output from this layer
1760
- struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
1761
-
1762
- {
1763
- ggml_build_forward_expand(&gf, inpO);
1764
- ggml_graph_compute (ctxL, &gf);
1765
-
1766
- //ggml_graph_print(&gf);
1767
- }
1768
-
1769
- // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
1770
- // input for next layer (inpO -> inpL)
1771
- memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
1772
- inpL->op = GGML_OP_NONE;
1773
- inpL->src0 = NULL;
1774
- inpL->src1 = NULL;
1775
-
1776
- if (N > 1) {
1777
- //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
1778
- }
1779
-
1780
- ggml_free(ctxL);
1781
- }
1782
-
1783
- cur = inpL;
1784
-
1785
- // norm
1786
- {
1787
- cur = ggml_norm(ctx0, cur);
1788
-
1789
- cur = ggml_add(ctx0,
1790
- ggml_mul(ctx0,
1791
- ggml_repeat(ctx0, model.d_ln_w, cur),
1792
- cur),
1793
- ggml_repeat(ctx0, model.d_ln_b, cur));
1794
- }
1795
-
1796
- struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
1797
-
1798
- // logits -> probs
1799
- cur = ggml_dup(ctx0, logits);
1800
- cur = ggml_soft_max(ctx0, cur); // in-place
1801
-
1802
- // run the computation
1803
- {
1804
- struct ggml_cgraph gf = { .n_threads = n_threads };
1805
-
1806
- ggml_build_forward_expand(&gf, cur);
1807
- ggml_graph_compute (ctx0, &gf);
1808
- }
1809
-
1810
- logits_out.resize(N*n_vocab);
1811
- memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
1812
-
1813
- probs_out.resize(N*n_vocab);
1814
- memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
1815
-
1816
- if (N > 1) {
1817
- //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
1818
- //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
1819
- //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
1820
- }
1821
-
1822
- ggml_free(ctx0);
1823
-
1824
- return true;
1825
- }
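
A note on the self-attention KV cache above: each decoder layer owns a contiguous slab of n_ctx token slots in memory_k / memory_v, so with s = ggml_element_size(memory_k) the byte offset of the 1-D view that receives the new keys and values is

    \text{offset}(il, n\_past) = s \cdot n_{state} \cdot (il \cdot n_{ctx} + n\_past)

which is exactly the expression passed to ggml_view_1d when copying Kcur and Vcur into the cache.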
1826
-
1827
- // the most basic sampling scheme - select the top token
1828
- // TODO: beam search
1829
- // TODO: temperature
1830
- whisper_vocab::id whisper_sample_best(
1831
- const whisper_vocab & vocab,
1832
- const float * probs, bool need_timestamp) {
1833
- int n_logits = vocab.id_to_token.size();
1834
-
1835
- std::vector<std::pair<double, whisper_vocab::id>> probs_id;
1836
- probs_id.reserve(n_logits);
1837
-
1838
- for (int i = 0; i < n_logits; i++) {
1839
- probs_id.push_back(std::make_pair(probs[i], i));
1840
- }
1841
-
1842
- const int top_k = 4;
1843
-
1844
- // find the top K tokens
1845
- std::partial_sort(
1846
- probs_id.begin(),
1847
- probs_id.begin() + top_k, probs_id.end(),
1848
- [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
1849
- return a.first > b.first;
1850
- });
1851
-
1852
- probs_id.resize(top_k);
1853
-
1854
- //printf("\n");
1855
- //for (int i = 0; i < (int) probs_id.size(); i++) {
1856
- // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
1857
- //}
1858
-
1859
- if (need_timestamp) {
1860
- // at the end of the 30-second audio segment, we start giving preference to time tokens
1861
- for (int i = 0; i < top_k; i++) {
1862
- if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
1863
- return probs_id[i].second;
1864
- }
1865
- }
1866
- }
1867
-
1868
- int res = 0;
1869
- while ((probs_id[res].second == vocab.token_sot ||
1870
- probs_id[res].second == vocab.token_solm ||
1871
- probs_id[res].second == vocab.token_not) &&
1872
- res < (int) probs_id.size() - 1) {
1873
- res++;
1874
- }
1875
-
1876
- return probs_id[res].second;
1877
- }
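
The timestamp-preference rule above reads naturally in seconds: each timestamp token covers a 20 ms step, i.e. t = 0.02\,(id - \text{token\_beg}) s, so the cutoff token_beg + 1300 corresponds to 26 s into the 30-second segment, after which any top-4 timestamp token with at least 1% of the best probability is returned.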
1878
-
1879
- // samples only from the timestamps tokens
1880
- whisper_vocab::id whisper_sample_timestamp(
1881
- const whisper_vocab & vocab,
1882
- const float * probs) {
1883
- int n_logits = vocab.id_to_token.size();
1884
-
1885
- std::vector<std::pair<double, whisper_vocab::id>> probs_id;
1886
- probs_id.reserve(n_logits);
1887
-
1888
- for (int i = vocab.token_beg + 1; i < n_logits; i++) {
1889
- probs_id.push_back(std::make_pair(probs[i], i));
1890
- }
1891
-
1892
- const int top_k = 10;
1893
-
1894
- // find the top K tokens
1895
- std::partial_sort(
1896
- probs_id.begin(),
1897
- probs_id.begin() + top_k, probs_id.end(),
1898
- [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
1899
- return a.first > b.first;
1900
- });
1901
-
1902
- probs_id.resize(top_k);
1903
-
1904
- //printf("\n");
1905
- //for (int i = 0; i < (int) probs_id.size(); i++) {
1906
- // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
1907
- //}
1908
-
1909
- return probs_id[0].second;
1910
- }
1911
-
1912
- // naive Discrete Fourier Transform
1913
- // input is real-valued
1914
- // output is complex-valued
1915
- void dft(const std::vector<float> & in, std::vector<float> & out) {
1916
- int N = in.size();
1917
-
1918
- out.resize(N*2);
1919
-
1920
- for (int k = 0; k < N; k++) {
1921
- float re = 0;
1922
- float im = 0;
1923
 
1924
- for (int n = 0; n < N; n++) {
1925
- float angle = 2*M_PI*k*n/N;
1926
- re += in[n]*cos(angle);
1927
- im -= in[n]*sin(angle);
1928
- }
1929
 
1930
- out[k*2 + 0] = re;
1931
- out[k*2 + 1] = im;
1932
- }
1933
  }
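
In equation form, the loop above evaluates, for k = 0, \dots, N-1,

    X_k = \sum_{n=0}^{N-1} x_n \, e^{-2\pi i k n / N} = \sum_{n=0}^{N-1} x_n \left( \cos\tfrac{2\pi k n}{N} - i \sin\tfrac{2\pi k n}{N} \right)

with the real and imaginary parts accumulated into out[2k] and out[2k + 1], at O(N^2) cost.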
1934
 
1935
- // Cooley-Tukey FFT
1936
- // poor man's implementation - use something better
1937
- // input is real-valued
1938
- // output is complex-valued
1939
- void fft(const std::vector<float> & in, std::vector<float> & out) {
1940
- out.resize(in.size()*2);
1941
-
1942
- int N = in.size();
1943
-
1944
- if (N == 1) {
1945
- out[0] = in[0];
1946
- out[1] = 0;
1947
- return;
1948
- }
1949
-
1950
- if (N%2 == 1) {
1951
- dft(in, out);
1952
- return;
1953
- }
1954
-
1955
- std::vector<float> even;
1956
- std::vector<float> odd;
1957
-
1958
- for (int i = 0; i < N; i++) {
1959
- if (i % 2 == 0) {
1960
- even.push_back(in[i]);
1961
- } else {
1962
- odd.push_back(in[i]);
1963
- }
1964
- }
1965
-
1966
- std::vector<float> even_fft;
1967
- std::vector<float> odd_fft;
1968
-
1969
- fft(even, even_fft);
1970
- fft(odd, odd_fft);
1971
-
1972
- for (int k = 0; k < N/2; k++) {
1973
- float theta = 2*M_PI*k/N;
1974
-
1975
- float re = cos(theta);
1976
- float im = -sin(theta);
1977
-
1978
- float re_odd = odd_fft[2*k + 0];
1979
- float im_odd = odd_fft[2*k + 1];
1980
 
1981
- out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
1982
- out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
1983
 
1984
- out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
1985
- out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
1986
- }
1987
  }
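
The recombination loop is the standard radix-2 butterfly: with E and O the transforms of the even- and odd-indexed samples,

    X_k = E_k + e^{-2\pi i k / N} O_k, \qquad X_{k + N/2} = E_k - e^{-2\pi i k / N} O_k, \qquad k = 0, \dots, N/2 - 1,

which matches the (re, im) = (\cos\theta, -\sin\theta) twiddle factors above; odd-length inputs fall back to the naive DFT.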
1988
 
1989
- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
1990
- bool log_mel_spectrogram(
1991
- const std::vector<float> sf32,
1992
- const int sample_rate,
1993
- const int fft_size,
1994
- const int fft_step,
1995
- const int n_mel,
1996
- const int n_threads,
1997
- const whisper_filters & filters,
1998
- whisper_mel & mel) {
1999
- const int n_sample = sf32.size();
2000
- const float * samples = sf32.data();
2001
-
2002
- // Hanning window
2003
- std::vector<float> hann;
2004
- hann.resize(fft_size);
2005
- for (int i = 0; i < fft_size; i++) {
2006
- hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
2007
- }
2008
-
2009
- mel.n_mel = n_mel;
2010
- mel.n_len = (n_sample)/fft_step;
2011
- mel.data.resize(mel.n_mel*mel.n_len);
2012
-
2013
- const int n_fft = 1 + fft_size/2;
2014
-
2015
- printf("%s: n_sample = %d, n_len = %d\n", __func__, n_sample, mel.n_len);
2016
- printf("%s: recording length: %f s\n", __func__, (float) n_sample/sample_rate);
2017
-
2018
- std::vector<std::thread> workers(n_threads);
2019
- for (int iw = 0; iw < n_threads; ++iw) {
2020
- workers[iw] = std::thread([&](int ith) {
2021
- std::vector<float> fft_in;
2022
- fft_in.resize(fft_size);
2023
- for (int i = 0; i < fft_size; i++) {
2024
- fft_in[i] = 0.0;
2025
- }
2026
-
2027
- std::vector<float> fft_out;
2028
- fft_out.resize(2*fft_size);
2029
-
2030
- for (int i = ith; i < mel.n_len; i += n_threads) {
2031
- const int offset = i*fft_step;
2032
-
2033
- // apply Hanning window
2034
- for (int j = 0; j < fft_size; j++) {
2035
- if (offset + j < n_sample) {
2036
- fft_in[j] = hann[j]*samples[offset + j];
2037
- } else {
2038
- fft_in[j] = 0.0;
2039
- }
2040
- }
2041
 
2042
- // FFT -> mag^2
2043
- fft(fft_in, fft_out);
2044
 
2045
- for (int j = 0; j < fft_size; j++) {
2046
- fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
2047
- }
2048
- for (int j = 1; j < fft_size/2; j++) {
2049
- //if (i == 0) {
2050
- // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
2051
- //}
2052
- fft_out[j] += fft_out[fft_size - j];
2053
- }
2054
- if (i == 0) {
2055
- //for (int j = 0; j < fft_size; j++) {
2056
- // printf("%d: %e\n", j, fft_out[j]);
2057
- //}
2058
- }
2059
 
2060
- // mel spectrogram
2061
- for (int j = 0; j < mel.n_mel; j++) {
2062
- double sum = 0.0;
2063
 
2064
- for (int k = 0; k < n_fft; k++) {
2065
- sum += fft_out[k]*filters.data[j*n_fft + k];
2066
- }
2067
- if (sum < 1e-10) {
2068
- sum = 1e-10;
2069
- }
2070
 
2071
- sum = log10(sum);
2073
- mel.data[j*mel.n_len + i] = sum;
2074
- }
2075
  }
2076
- }, iw);
2077
- }
2078
-
2079
- for (int iw = 0; iw < n_threads; ++iw) {
2080
- workers[iw].join();
2081
- }
2082
-
2083
- // clamping and normalization
2084
- double mmax = -1e20;
2085
- for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
2086
- if (mel.data[i] > mmax) {
2087
- mmax = mel.data[i];
2088
- }
2089
- }
2090
- //printf("%s: max = %f\n", __func__, mmax);
2091
-
2092
- mmax -= 8.0;
2093
-
2094
- for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
2095
- if (mel.data[i] < mmax) {
2096
- mel.data[i] = mmax;
2097
  }
2098
-
2099
- mel.data[i] = (mel.data[i] + 4.0)/4.0;
2100
  }
2101
 
2102
  return true;
2103
  }
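
Following the referenced audio.py, each mel bin computed above is

    \text{mel}_{j,i} = \log_{10}\!\left( \max\!\left( 10^{-10}, \sum_k |S_i(k)|^2 \, H_j(k) \right) \right)

where S_i is the windowed FFT of frame i and H_j is the j-th mel filter; the result is then clamped to \max - 8 and rescaled as (x + 4)/4, the same dynamic-range compression Whisper applies before the encoder.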
2104
 
2105
- // 500 -> 00:05.000
2106
- // 6000 -> 01:00.000
2107
- std::string to_timestamp(int64_t t) {
2108
- int64_t sec = t/100;
2109
- int64_t msec = t - sec*100;
2110
- int64_t min = sec/60;
2111
- sec = sec - min*60;
2112
-
2113
- char buf[32];
2114
- snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
2115
-
2116
- return std::string(buf);
2117
  }
2118
 
2119
  int main(int argc, char ** argv) {
2120
- const int64_t t_main_start_us = ggml_time_us();
2121
 
2122
  whisper_params params;
2123
 
@@ -2129,31 +123,9 @@ int main(int argc, char ** argv) {
2129
  params.seed = time(NULL);
2130
  }
2131
 
2132
- // Model loading
2133
-
2134
- //printf("%s: seed = %d\n", __func__, params.seed);
2135
-
2136
- int64_t t_load_us = 0;
2137
- int64_t t_mel_us = 0;
2138
- int64_t t_sample_us = 0;
2139
- int64_t t_encode_us = 0;
2140
- int64_t t_decode_us = 0;
2141
-
2142
- whisper_vocab vocab;
2143
- whisper_model model;
2144
-
2145
- // load the model
2146
- {
2147
- const int64_t t_start_us = ggml_time_us();
2148
-
2149
- if (!whisper_model_load(params.model, model, vocab)) {
2150
- fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
2151
- whisper_print_usage(argc, argv, {});
2152
- return 1;
2153
- }
2154
 
2155
- t_load_us = ggml_time_us() - t_start_us;
2156
- }
2157
 
2158
  // WAV input
2159
  std::vector<float> pcmf32;
@@ -2201,19 +173,15 @@ int main(int argc, char ** argv) {
2201
  }
2202
 
2203
  // compute log mel spectrogram
2204
- whisper_mel mel_inp;
2205
- {
2206
- const int64_t t_start_us = ggml_time_us();
2207
-
2208
- log_mel_spectrogram(pcmf32, SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MEL, params.n_threads, model.filters, mel_inp);
2209
-
2210
- t_mel_us = ggml_time_us() - t_start_us;
2211
  }
2212
 
2213
  // print some info about the processing
2214
  {
2215
  printf("\n");
2216
- if (!vocab.is_multilingual()) {
2217
  if (params.language != "en" || params.translate) {
2218
  params.language = "en";
2219
  params.translate = false;
@@ -2222,23 +190,23 @@ int main(int argc, char ** argv) {
2222
  }
2223
  printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
2224
  __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
2225
- g_lang.at(params.language).second.c_str(),
2226
  params.translate ? "translate" : "transcribe",
2227
  params.no_timestamps ? 0 : 1);
2228
  printf("\n");
2229
  }
2230
 
2231
  // the accumulated text context so far
2232
- std::vector<whisper_vocab::id> prompt_past = { };
2233
 
2234
  // these tokens determine the task that will be performed
2235
- std::vector<whisper_vocab::id> prompt_init = { vocab.token_sot };
2236
- if (vocab.is_multilingual()) {
2237
- prompt_init.push_back(vocab.token_sot + 1 + g_lang.at(params.language).first);
2238
  if (params.translate) {
2239
- prompt_init.push_back(vocab.token_translate);
2240
  } else {
2241
- prompt_init.push_back(vocab.token_transcribe);
2242
  }
2243
  }
2244
 
@@ -2248,35 +216,25 @@ int main(int argc, char ** argv) {
2248
  // main loop
2249
  int seek = 0;
2250
  while (true) {
2251
- if (seek >= mel_inp.n_len) {
2252
  break;
2253
  }
2254
 
2255
  // encode audio features starting at offset seek
2256
- std::vector<float> features;
2257
- {
2258
- const int64_t t_start_us = ggml_time_us();
2259
-
2260
- if (!whisper_encode(model, params.n_threads, seek, mel_inp, features)) {
2261
- fprintf(stderr, "%s: failed to eval\n", __func__);
2262
- return 1;
2263
- }
2264
-
2265
- t_encode_us += ggml_time_us() - t_start_us;
2266
  }
2267
 
2268
- std::vector<float> probs;
2269
- std::vector<float> logits;
2270
-
2271
- std::vector<whisper_vocab::id> prompt;
2272
 
2273
  int n_past = 0;
2274
 
2275
  // if we have already generated some text, use it as a prompt to condition the next generation
2276
  if (prompt_past.size() > 0) {
2277
- int n_take = std::min(model.hparams.n_text_ctx/2, int(prompt_past.size()));
2278
 
2279
- prompt = { vocab.token_prev };
2280
  prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
2281
 
2282
  prompt_past.clear();
@@ -2287,7 +245,7 @@ int main(int argc, char ** argv) {
2287
 
2288
  bool done = false;
2289
  int seek_delta = 100*CHUNK_SIZE;
2290
- whisper_vocab::id last_id = 0;
2291
 
2292
  // print the prompt
2293
  //printf("\n\n");
@@ -2300,17 +258,10 @@ int main(int argc, char ** argv) {
2300
  int result_len = 0;
2301
  std::vector<whisper_result> result_cur;
2302
 
2303
- for (int i = 0; i < model.hparams.n_text_ctx/2 - 4; ++i) {
2304
- // decode
2305
- if (prompt.size() > 0) {
2306
- const int64_t t_start_us = ggml_time_us();
2307
-
2308
- if (!whisper_decode(model, params.n_threads, n_past, prompt, logits, probs)) {
2309
- fprintf(stderr, "%s: failed to eval\n", __func__);
2310
- return 1;
2311
- }
2312
-
2313
- t_decode_us += ggml_time_us() - t_start_us;
2314
  }
2315
 
2316
  n_past += prompt.size();
@@ -2324,37 +275,31 @@ int main(int argc, char ** argv) {
2324
  // feel free to experiment!
2325
  //
2326
  {
2327
- const int n_vocab = model.hparams.n_vocab;
2328
-
2329
- whisper_vocab::id id = 0;
2330
- whisper_vocab::id tid = vocab.token_beg;
2331
 
2332
- {
2333
- const int64_t t_start_sample_us = ggml_time_us();
2334
-
2335
- id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), result_len == 0);
2336
- if (i > 0) {
2337
- tid = whisper_sample_timestamp(vocab, probs.data() + (probs.size() - n_vocab));
2338
- }
2339
 
2340
- t_sample_us += ggml_time_us() - t_start_sample_us;
2341
  }
2342
 
2343
  // update sliding window
2344
- if (id > vocab.token_beg) {
2345
- seek_delta = 2*(id - vocab.token_beg);
2346
  result_len = i + 1;
2347
  }
2348
  last_id = id;
2349
 
2350
  // add it to the context
2351
  prompt.push_back(id);
2352
- result_cur.push_back({ id, seek + 2*(tid - vocab.token_beg) });
2353
 
2354
  //printf("%s: %s\n", __func__, vocab.id_to_token[id].c_str());
2355
 
2356
  // end of text token
2357
- if (id == vocab.token_eot) {
2358
  break;
2359
  }
2360
  }
@@ -2377,11 +322,11 @@ int main(int argc, char ** argv) {
2377
 
2378
  std::string text = "";
2379
  for (int i = 0; i < result_cur.size(); i++) {
2380
- if (params.print_special_tokens == false && result_cur[i].id >= vocab.token_eot) {
2381
  } else {
2382
- text += vocab.id_to_token[result_cur[i].id];
2383
  }
2384
- if (result_cur[i].id > vocab.token_beg) {
2385
  const auto t1 = result_cur[i].t;
2386
  if (!text.empty()) {
2387
  if (params.no_timestamps) {
@@ -2392,7 +337,7 @@ int main(int argc, char ** argv) {
2392
  }
2393
  }
2394
  text = "";
2395
- while (result_cur[i].id > vocab.token_beg && i < result_cur.size()) {
2396
  i++;
2397
  }
2398
  i--;
@@ -2408,45 +353,8 @@ int main(int argc, char ** argv) {
2408
  seek += seek_delta;
2409
  }
2410
 
2411
- // WIP: attempt for per-token timestamps
2412
- //if (!params.no_timestamps && result_all.size() > 0) {
2413
- // const int64_t dt = 500; // 5 second intervals
2414
-
2415
- // int i0 = 0;
2416
-
2417
- // int64_t t0 = result_all[0].t;
2418
- // int64_t t1 = t0;
2419
-
2420
- // printf("\n\n");
2421
- // for (int i = 0; i < result_all.size(); ++i) {
2422
- // printf("'%s' -> %lld\n", vocab.id_to_token[result_all[i].id].c_str(), result_all[i].t);
2423
- // if (result_all[i].t - t0 > dt) {
2424
- // t1 = result_all[i - 1].t;
2425
- // printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
2426
- // for (int j = i0; j < i; ++j) {
2427
- // printf("%s", vocab.id_to_token.at(result_all[j].id).c_str());
2428
- // }
2429
- // printf("\n");
2430
- // i0 = i;
2431
- // t0 = result_all[i].t;
2432
- // }
2433
- // }
2434
- //}
2435
-
2436
- // report timing
2437
- {
2438
- const int64_t t_main_end_us = ggml_time_us();
2439
-
2440
- printf("\n\n");
2441
- printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
2442
- printf("%s: mel time = %8.2f ms\n", __func__, t_mel_us/1000.0f);
2443
- printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
2444
- printf("%s: encode time = %8.2f ms / %.2f ms per layer\n", __func__, t_encode_us/1000.0f, t_encode_us/1000.0f/model.hparams.n_audio_layer);
2445
- printf("%s: decode time = %8.2f ms\n", __func__, t_decode_us/1000.0f);
2446
- printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
2447
- }
2448
-
2449
- ggml_free(model.ctx);
2450
 
2451
  return 0;
2452
  }
 
1
+ #include "whisper.h"
2
 
3
+ // third-party utilities
4
+ // use your favorite implementations
5
+ #define DR_WAV_IMPLEMENTATION
6
+ #include "dr_wav.h"
7
 
8
+ #include <algorithm>
+ #include <cassert>
+ #include <chrono>
9
+ #include <cstdio>
10
+ #include <string>
11
+ #include <thread>
12
+ #include <vector>
13
 
14
+ int64_t get_time_us() {
15
+ return std::chrono::duration_cast<std::chrono::microseconds>(
16
+ std::chrono::high_resolution_clock::now().time_since_epoch()).count();
17
  }
18
 
19
+ // 500 -> 00:05.000
20
+ // 6000 -> 01:00.000
21
+ std::string to_timestamp(int64_t t) {
22
+ int64_t sec = t/100;
23
+ int64_t msec = t - sec*100;
24
+ int64_t min = sec/60;
25
+ sec = sec - min*60;
26
 
27
+ char buf[32];
28
+ snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
29
 
30
+ return std::string(buf);
31
  }
32
 
33
+ struct whisper_result {
34
+ whisper_token id;
35
+ int64_t t;
36
+ };
37
 
38
+ // command-line parameters
39
+ struct whisper_params {
40
+ int32_t seed = -1; // RNG seed, not used currently
41
+ int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
42
 
43
+ bool verbose = false;
44
+ bool translate = false;
45
+ bool print_special_tokens = false;
46
+ bool no_timestamps = false;
47
 
48
+ std::string language = "en";
49
+ std::string model = "models/ggml-base.en.bin";
50
+ std::string fname_inp = "samples/jfk.wav";
51
+ };
52
 
53
+ void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
54
 
55
+ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
56
+ for (int i = 1; i < argc; i++) {
57
+ std::string arg = argv[i];
58
 
59
+ if (arg == "-s" || arg == "--seed") {
60
+ params.seed = std::stoi(argv[++i]);
61
+ } else if (arg == "-t" || arg == "--threads") {
62
+ params.n_threads = std::stoi(argv[++i]);
63
+ } else if (arg == "-v" || arg == "--verbose") {
64
+ params.verbose = true;
65
+ } else if (arg == "--translate") {
66
+ params.translate = true;
67
+ } else if (arg == "-l" || arg == "--language") {
68
+ params.language = argv[++i];
69
+ if (whisper_lang_id(params.language.c_str()) == -1) {
70
+ fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
71
+ whisper_print_usage(argc, argv, params);
72
+ exit(0);
73
  }
74
+ } else if (arg == "-ps" || arg == "--print_special") {
75
+ params.print_special_tokens = true;
76
+ } else if (arg == "-nt" || arg == "--no_timestamps") {
77
+ params.no_timestamps = true;
78
+ } else if (arg == "-m" || arg == "--model") {
79
+ params.model = argv[++i];
80
+ } else if (arg == "-f" || arg == "--file") {
81
+ params.fname_inp = argv[++i];
82
+ } else if (arg == "-h" || arg == "--help") {
83
+ whisper_print_usage(argc, argv, params);
84
+ exit(0);
85
+ } else {
86
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
87
+ whisper_print_usage(argc, argv, params);
88
+ exit(0);
89
  }
90
  }
91
 
92
  return true;
93
  }
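
With the defaults above, a typical invocation might look like ./main -m models/ggml-base.en.bin -f samples/jfk.wav -t 4 --translate (the model and sample paths shown are just the built-in defaults, not requirements).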
94
 
95
+ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
96
+ fprintf(stderr, "\n");
97
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
98
+ fprintf(stderr, "\n");
99
+ fprintf(stderr, "options:\n");
100
+ fprintf(stderr, " -h, --help show this help message and exit\n");
101
+ fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
102
+ fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
103
+ fprintf(stderr, " -v, --verbose verbose output\n");
104
+ fprintf(stderr, " --translate translate from source language to english\n");
105
+ fprintf(stderr, " -ps, --print_special print special tokens\n");
106
+ fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
107
+ fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
108
+ fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
109
+ fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str());
110
+ fprintf(stderr, "\n");
111
  }
112
 
113
  int main(int argc, char ** argv) {
114
+ const int64_t t_main_start_us = get_time_us();
115
 
116
  whisper_params params;
117
 
 
123
  params.seed = time(NULL);
124
  }
125
 
126
+ // whisper init
127
 
128
+ struct whisper_context * ctx = whisper_init(params.model.c_str());
 
129
 
130
  // WAV input
131
  std::vector<float> pcmf32;
 
173
  }
174
 
175
  // compute log mel spectrogram
176
+ if (whisper_pcm_to_mel(ctx, pcmf32.data(), pcmf32.size(), params.n_threads) != 0) {
177
+ fprintf(stderr, "%s: failed to compute log mel spectrogram\n", argv[0]);
178
+ return 6;
179
  }
180
 
181
  // print some info about the processing
182
  {
183
  printf("\n");
184
+ if (!whisper_is_multilingual(ctx)) {
185
  if (params.language != "en" || params.translate) {
186
  params.language = "en";
187
  params.translate = false;
 
190
  }
191
  printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
192
  __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
193
+ params.language.c_str(),
194
  params.translate ? "translate" : "transcribe",
195
  params.no_timestamps ? 0 : 1);
196
  printf("\n");
197
  }
198
 
199
  // the accumulated text context so far
200
+ std::vector<whisper_token> prompt_past = { };
201
 
202
  // these tokens determine the task that will be performed
203
+ std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
204
+ if (whisper_is_multilingual(ctx)) {
205
+ prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language.c_str()));
206
  if (params.translate) {
207
+ prompt_init.push_back(whisper_token_translate());
208
  } else {
209
+ prompt_init.push_back(whisper_token_transcribe());
210
  }
211
  }
212
 
 
216
  // main loop
217
  int seek = 0;
218
  while (true) {
219
+ if (seek >= whisper_n_len(ctx)) {
220
  break;
221
  }
222
 
223
  // encode audio features starting at offset seek
224
+ if (whisper_encode(ctx, seek, params.n_threads) != 0) {
225
+ fprintf(stderr, "%s: failed to encode\n", __func__);
226
+ return 7;
227
  }
228
 
229
+ std::vector<whisper_token> prompt;
230
 
231
  int n_past = 0;
232
 
233
  // if we have already generated some text, use it as a prompt to condition the next generation
234
  if (prompt_past.size() > 0) {
235
+ int n_take = std::min(whisper_n_text_ctx(ctx)/2, int(prompt_past.size()));
236
 
237
+ prompt = { whisper_token_prev(ctx) };
238
  prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
239
 
240
  prompt_past.clear();
 
245
 
246
  bool done = false;
247
  int seek_delta = 100*CHUNK_SIZE;
248
+ whisper_token last_id = 0;
249
 
250
  // print the prompt
251
  //printf("\n\n");
 
258
  int result_len = 0;
259
  std::vector<whisper_result> result_cur;
260
 
261
+ for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
262
+ if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
263
+ fprintf(stderr, "%s: failed to decode\n", __func__);
264
+ return 8;
265
  }
266
 
267
  n_past += prompt.size();
 
275
  // feel free to experiment!
276
  //
277
  {
278
+ const int n_vocab = whisper_n_vocab(ctx);
279
 
280
+ whisper_token id = 0;
281
+ whisper_token tid = whisper_token_beg(ctx);
282
 
283
+ id = whisper_sample_best(ctx, result_len == 0);
284
+ if (i > 0) {
285
+ tid = whisper_sample_timestamp(ctx);
286
  }
287
 
288
  // update sliding window
289
+ if (id > whisper_token_beg(ctx)) {
290
+ seek_delta = 2*(id - whisper_token_beg(ctx));
291
  result_len = i + 1;
292
  }
293
  last_id = id;
294
 
295
  // add it to the context
296
  prompt.push_back(id);
297
+ result_cur.push_back({ id, seek + 2*(tid - whisper_token_beg(ctx)) });
298
 
299
  //printf("%s: %s\n", __func__, vocab.id_to_token[id].c_str());
300
 
301
  // end of text token
302
+ if (id == whisper_token_eot(ctx)) {
303
  break;
304
  }
305
  }
 
322
 
323
  std::string text = "";
324
  for (int i = 0; i < result_cur.size(); i++) {
325
+ if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) {
326
  } else {
327
+ text += whisper_token_to_str(ctx, result_cur[i].id);
328
  }
329
+ if (result_cur[i].id > whisper_token_beg(ctx)) {
330
  const auto t1 = result_cur[i].t;
331
  if (!text.empty()) {
332
  if (params.no_timestamps) {
 
337
  }
338
  }
339
  text = "";
340
+ while (result_cur[i].id > whisper_token_beg(ctx) && i < result_cur.size()) {
341
  i++;
342
  }
343
  i--;
 
353
  seek += seek_delta;
354
  }
355
 
356
+ whisper_print_timings(ctx);
357
+ whisper_free(ctx);
358
 
359
  return 0;
360
  }
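
Taken together, main.cpp is now just a driver for the opaque whisper_context. The following is a minimal sketch of the same flow using only the whisper.h calls that appear above; the prompt bookkeeping is simplified and the 30-second windowing, language/task tokens and timestamp handling are omitted, so treat it as an illustration rather than the reference usage:

    #include "whisper.h"

    #include <cstdio>
    #include <vector>

    // greedily transcribe the first 30-second window of a 16 kHz float PCM buffer
    int transcribe_sketch(const char * model_path, const std::vector<float> & pcmf32, int n_threads) {
        struct whisper_context * ctx = whisper_init(model_path);

        // compute and store the log mel spectrogram inside the context
        if (whisper_pcm_to_mel(ctx, pcmf32.data(), pcmf32.size(), n_threads) != 0) return 1;

        // encode the audio features at offset 0
        if (whisper_encode(ctx, 0, n_threads) != 0) return 2;

        std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };

        int n_past = 0;
        for (int i = 0; i < whisper_n_text_ctx(ctx)/2; ++i) {
            if (whisper_decode(ctx, prompt.data(), (int) prompt.size(), n_past, n_threads) != 0) return 3;
            n_past += (int) prompt.size();
            prompt.clear();

            const whisper_token id = whisper_sample_best(ctx, false);
            if (id == whisper_token_eot(ctx)) break;

            printf("%s", whisper_token_to_str(ctx, id));
            prompt.push_back(id); // feed the sampled token back for the next step
        }

        whisper_print_timings(ctx);
        whisper_free(ctx);

        return 0;
    }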
stream.cpp CHANGED
@@ -2,2116 +2,119 @@
2
  //
3
  // A very quick-n-dirty implementation serving mainly as a proof of concept.
4
 
5
- #include "ggml.h"
6
 
7
- #define USE_FLASH_ATTN
8
- #define USE_FLASH_FF
9
-
10
- // third-party utilities
11
- // use your favorite implementations
12
- #define DR_WAV_IMPLEMENTATION
13
- #include "dr_wav.h"
14
-
15
- #include <SDL.h>
16
- #include <SDL_audio.h>
17
-
18
- #include <algorithm>
19
- #include <cassert>
20
- #include <cmath>
21
- #include <cstdio>
22
- #include <cstring>
23
- #include <fstream>
24
- #include <map>
25
- #include <string>
26
- #include <thread>
27
- #include <vector>
28
-
29
- // available whisper models
30
- enum e_model {
31
- MODEL_UNKNOWN,
32
- MODEL_TINY,
33
- MODEL_BASE,
34
- MODEL_SMALL,
35
- MODEL_MEDIUM,
36
- MODEL_LARGE,
37
- };
38
-
39
- const std::map<std::string, std::pair<int, std::string>> g_lang = {
40
- { "en", { 0, "english", } },
41
- { "zh", { 1, "chinese", } },
42
- { "de", { 2, "german", } },
43
- { "es", { 3, "spanish", } },
44
- { "ru", { 4, "russian", } },
45
- { "ko", { 5, "korean", } },
46
- { "fr", { 6, "french", } },
47
- { "ja", { 7, "japanese", } },
48
- { "pt", { 8, "portuguese", } },
49
- { "tr", { 9, "turkish", } },
50
- { "pl", { 10, "polish", } },
51
- { "ca", { 11, "catalan", } },
52
- { "nl", { 12, "dutch", } },
53
- { "ar", { 13, "arabic", } },
54
- { "sv", { 14, "swedish", } },
55
- { "it", { 15, "italian", } },
56
- { "id", { 16, "indonesian", } },
57
- { "hi", { 17, "hindi", } },
58
- { "fi", { 18, "finnish", } },
59
- { "vi", { 19, "vietnamese", } },
60
- { "iw", { 20, "hebrew", } },
61
- { "uk", { 21, "ukrainian", } },
62
- { "el", { 22, "greek", } },
63
- { "ms", { 23, "malay", } },
64
- { "cs", { 24, "czech", } },
65
- { "ro", { 25, "romanian", } },
66
- { "da", { 26, "danish", } },
67
- { "hu", { 27, "hungarian", } },
68
- { "ta", { 28, "tamil", } },
69
- { "no", { 29, "norwegian", } },
70
- { "th", { 30, "thai", } },
71
- { "ur", { 31, "urdu", } },
72
- { "hr", { 32, "croatian", } },
73
- { "bg", { 33, "bulgarian", } },
74
- { "lt", { 34, "lithuanian", } },
75
- { "la", { 35, "latin", } },
76
- { "mi", { 36, "maori", } },
77
- { "ml", { 37, "malayalam", } },
78
- { "cy", { 38, "welsh", } },
79
- { "sk", { 39, "slovak", } },
80
- { "te", { 40, "telugu", } },
81
- { "fa", { 41, "persian", } },
82
- { "lv", { 42, "latvian", } },
83
- { "bn", { 43, "bengali", } },
84
- { "sr", { 44, "serbian", } },
85
- { "az", { 45, "azerbaijani", } },
86
- { "sl", { 46, "slovenian", } },
87
- { "kn", { 47, "kannada", } },
88
- { "et", { 48, "estonian", } },
89
- { "mk", { 49, "macedonian", } },
90
- { "br", { 50, "breton", } },
91
- { "eu", { 51, "basque", } },
92
- { "is", { 52, "icelandic", } },
93
- { "hy", { 53, "armenian", } },
94
- { "ne", { 54, "nepali", } },
95
- { "mn", { 55, "mongolian", } },
96
- { "bs", { 56, "bosnian", } },
97
- { "kk", { 57, "kazakh", } },
98
- { "sq", { 58, "albanian", } },
99
- { "sw", { 59, "swahili", } },
100
- { "gl", { 60, "galician", } },
101
- { "mr", { 61, "marathi", } },
102
- { "pa", { 62, "punjabi", } },
103
- { "si", { 63, "sinhala", } },
104
- { "km", { 64, "khmer", } },
105
- { "sn", { 65, "shona", } },
106
- { "yo", { 66, "yoruba", } },
107
- { "so", { 67, "somali", } },
108
- { "af", { 68, "afrikaans", } },
109
- { "oc", { 69, "occitan", } },
110
- { "ka", { 70, "georgian", } },
111
- { "be", { 71, "belarusian", } },
112
- { "tg", { 72, "tajik", } },
113
- { "sd", { 73, "sindhi", } },
114
- { "gu", { 74, "gujarati", } },
115
- { "am", { 75, "amharic", } },
116
- { "yi", { 76, "yiddish", } },
117
- { "lo", { 77, "lao", } },
118
- { "uz", { 78, "uzbek", } },
119
- { "fo", { 79, "faroese", } },
120
- { "ht", { 80, "haitian creole", } },
121
- { "ps", { 81, "pashto", } },
122
- { "tk", { 82, "turkmen", } },
123
- { "nn", { 83, "nynorsk", } },
124
- { "mt", { 84, "maltese", } },
125
- { "sa", { 85, "sanskrit", } },
126
- { "lb", { 86, "luxembourgish", } },
127
- { "my", { 87, "myanmar", } },
128
- { "bo", { 88, "tibetan", } },
129
- { "tl", { 89, "tagalog", } },
130
- { "mg", { 90, "malagasy", } },
131
- { "as", { 91, "assamese", } },
132
- { "tt", { 92, "tatar", } },
133
- { "haw", { 93, "hawaiian", } },
134
- { "ln", { 94, "lingala", } },
135
- { "ha", { 95, "hausa", } },
136
- { "ba", { 96, "bashkir", } },
137
- { "jw", { 97, "javanese", } },
138
- { "su", { 98, "sundanese", } },
139
- };
140
-
141
- const size_t MB = 1024*1024;
142
-
143
- const std::map<e_model, size_t> MEM_REQ_MODEL = {
144
- { MODEL_TINY, 86ull*MB },
145
- { MODEL_BASE, 165ull*MB },
146
- { MODEL_SMALL, 540ull*MB },
147
- { MODEL_MEDIUM, 1650ull*MB },
148
- { MODEL_LARGE, 3260ull*MB },
149
- };
150
-
151
- const std::map<e_model, size_t> MEM_REQ_ENCODE = {
152
- { MODEL_TINY, 80ull*MB },
153
- { MODEL_BASE, 128ull*MB },
154
- { MODEL_SMALL, 300ull*MB },
155
- { MODEL_MEDIUM, 680ull*MB },
156
- { MODEL_LARGE, 1100ull*MB },
157
- };
158
-
159
- const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
160
- { MODEL_TINY, 64ull*MB },
161
- { MODEL_BASE, 84ull*MB },
162
- { MODEL_SMALL, 128ull*MB },
163
- { MODEL_MEDIUM, 172ull*MB },
164
- { MODEL_LARGE, 216ull*MB },
165
- };
166
-
167
- const std::map<e_model, size_t> MEM_REQ_DECODE = {
168
- { MODEL_TINY, 94ull*MB },
169
- { MODEL_BASE, 96ull*MB },
170
- { MODEL_SMALL, 98ull*MB },
171
- { MODEL_MEDIUM, 100ull*MB },
172
- { MODEL_LARGE, 102ull*MB },
173
- };
174
-
175
- const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
176
- { MODEL_TINY, 32ull*MB },
177
- { MODEL_BASE, 44ull*MB },
178
- { MODEL_SMALL, 64ull*MB },
179
- { MODEL_MEDIUM, 84ull*MB },
180
- { MODEL_LARGE, 110ull*MB },
181
- };
182
-
183
- // the memory buffers used to store the model in memory and perform the inference computations
184
- std::vector<uint8_t> g_buf_model;
185
- std::vector<uint8_t> g_buf_compute;
186
- std::vector<uint8_t> g_buf_compute_layer;
187
-
188
- const int SAMPLE_RATE = 16000;
189
- const int N_FFT = 400;
190
- const int N_MEL = 80;
191
- const int HOP_LENGTH = 160;
192
- const int CHUNK_SIZE = 30; // seconds
193
-
194
- struct whisper_mel {
195
- int n_len;
196
- int n_mel;
197
-
198
- std::vector<float> data;
199
- };
200
-
201
- struct whisper_filters {
202
- int32_t n_mel;
203
- int32_t n_fft;
204
-
205
- std::vector<float> data;
206
- };
207
-
208
- struct whisper_vocab {
209
- using id = int32_t;
210
- using token = std::string;
211
-
212
- int n_vocab = 51864;
213
-
214
- std::map<token, id> token_to_id;
215
- std::map<id, token> id_to_token;
216
-
217
- id token_eot = 50256;
218
- id token_sot = 50257;
219
- id token_prev = 50360;
220
- id token_solm = 50361; // ??
221
- id token_not = 50362; // no timestamps
222
- id token_beg = 50363;
223
-
224
- // available tasks
225
- const id token_translate = 50358;
226
- const id token_transcribe = 50359;
227
-
228
- bool is_multilingual() const {
229
- return n_vocab == 51865;
230
- }
231
- };
232
-
233
- struct whisper_result {
234
- whisper_vocab::id id;
235
- int64_t t;
236
- };
237
-
238
- // command-line parameters
239
- struct whisper_params {
240
- int32_t seed = -1; // RNG seed, not used currently
241
- int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
242
-
243
- bool verbose = false;
244
- bool translate = false;
245
- bool print_special_tokens = false;
246
- bool no_timestamps = true;
247
-
248
- std::string language = "en";
249
- std::string model = "models/ggml-base.en.bin";
250
- std::string fname_inp = "samples/jfk.wav";
251
- };
252
-
253
- void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
254
-
255
- bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
256
- for (int i = 1; i < argc; i++) {
257
- std::string arg = argv[i];
258
-
259
- if (arg == "-s" || arg == "--seed") {
260
- params.seed = std::stoi(argv[++i]);
261
- } else if (arg == "-t" || arg == "--threads") {
262
- params.n_threads = std::stoi(argv[++i]);
263
- } else if (arg == "-v" || arg == "--verbose") {
264
- params.verbose = true;
265
- } else if (arg == "--translate") {
266
- params.translate = true;
267
- } else if (arg == "-l" || arg == "--language") {
268
- params.language = argv[++i];
269
- if (g_lang.find(params.language) == g_lang.end()) {
270
- fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
271
- whisper_print_usage(argc, argv, params);
272
- exit(0);
273
- }
274
- } else if (arg == "-ps" || arg == "--print_special") {
275
- params.print_special_tokens = true;
276
- } else if (arg == "-nt" || arg == "--no_timestamps") {
277
- params.no_timestamps = true;
278
- } else if (arg == "-m" || arg == "--model") {
279
- params.model = argv[++i];
280
- } else if (arg == "-f" || arg == "--file") {
281
- params.fname_inp = argv[++i];
282
- } else if (arg == "-h" || arg == "--help") {
283
- whisper_print_usage(argc, argv, params);
284
- exit(0);
285
- } else {
286
- fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
287
- whisper_print_usage(argc, argv, params);
288
- exit(0);
289
- }
290
- }
291
-
292
- return true;
293
- }
294
-
295
- void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
296
- fprintf(stderr, "\n");
297
- fprintf(stderr, "usage: %s [options]\n", argv[0]);
298
- fprintf(stderr, "\n");
299
- fprintf(stderr, "options:\n");
300
- fprintf(stderr, " -h, --help show this help message and exit\n");
301
- fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
302
- fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
303
- fprintf(stderr, " -v, --verbose verbose output\n");
304
- fprintf(stderr, " --translate translate from source language to english\n");
305
- fprintf(stderr, " -ps, --print_special print special tokens\n");
306
- fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
307
- fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
308
- fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
309
- fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str());
310
- fprintf(stderr, "\n");
311
- }
312
-
313
-
314
- // medium
315
- // hparams: {
316
- // 'n_mels': 80,
317
- // 'n_vocab': 51864,
318
- // 'n_audio_ctx': 1500,
319
- // 'n_audio_state': 1024,
320
- // 'n_audio_head': 16,
321
- // 'n_audio_layer': 24,
322
- // 'n_text_ctx': 448,
323
- // 'n_text_state': 1024,
324
- // 'n_text_head': 16,
325
- // 'n_text_layer': 24
326
- // }
327
- //
328
- // default hparams (Whisper tiny)
329
- struct whisper_hparams {
330
- int32_t n_vocab = 51864;
331
- int32_t n_audio_ctx = 1500;
332
- int32_t n_audio_state = 384;
333
- int32_t n_audio_head = 6;
334
- int32_t n_audio_layer = 4;
335
- int32_t n_text_ctx = 448;
336
- int32_t n_text_state = 384;
337
- int32_t n_text_head = 6;
338
- int32_t n_text_layer = 4;
339
- int32_t n_mels = 80;
340
- int32_t f16 = 1;
341
- };
342
-
343
- // audio encoding layer
344
- struct whisper_layer_encoder {
345
- // encoder.blocks.*.attn_ln
346
- struct ggml_tensor * attn_ln_0_w;
347
- struct ggml_tensor * attn_ln_0_b;
348
-
349
- // encoder.blocks.*.attn.out
350
- struct ggml_tensor * attn_ln_1_w;
351
- struct ggml_tensor * attn_ln_1_b;
352
-
353
- // encoder.blocks.*.attn.query
354
- struct ggml_tensor * attn_q_w;
355
- struct ggml_tensor * attn_q_b;
356
-
357
- // encoder.blocks.*.attn.key
358
- struct ggml_tensor * attn_k_w;
359
-
360
- // encoder.blocks.*.attn.value
361
- struct ggml_tensor * attn_v_w;
362
- struct ggml_tensor * attn_v_b;
363
-
364
- // encoder.blocks.*.mlp_ln
365
- struct ggml_tensor * mlp_ln_w;
366
- struct ggml_tensor * mlp_ln_b;
367
-
368
- // encoder.blocks.*.mlp.0
369
- struct ggml_tensor * mlp_0_w;
370
- struct ggml_tensor * mlp_0_b;
371
-
372
- // encoder.blocks.*.mlp.2
373
- struct ggml_tensor * mlp_1_w;
374
- struct ggml_tensor * mlp_1_b;
375
- };
376
-
377
- // token decoding layer
378
- struct whisper_layer_decoder {
379
- // decoder.blocks.*.attn_ln
380
- struct ggml_tensor * attn_ln_0_w;
381
- struct ggml_tensor * attn_ln_0_b;
382
-
383
- // decoder.blocks.*.attn.out
384
- struct ggml_tensor * attn_ln_1_w;
385
- struct ggml_tensor * attn_ln_1_b;
386
-
387
- // decoder.blocks.*.attn.query
388
- struct ggml_tensor * attn_q_w;
389
- struct ggml_tensor * attn_q_b;
390
-
391
- // decoder.blocks.*.attn.key
392
- struct ggml_tensor * attn_k_w;
393
-
394
- // decoder.blocks.*.attn.value
395
- struct ggml_tensor * attn_v_w;
396
- struct ggml_tensor * attn_v_b;
397
-
398
- // decoder.blocks.*.cross_attn_ln
399
- struct ggml_tensor * cross_attn_ln_0_w;
400
- struct ggml_tensor * cross_attn_ln_0_b;
401
-
402
- // decoder.blocks.*.cross_attn.out
403
- struct ggml_tensor * cross_attn_ln_1_w;
404
- struct ggml_tensor * cross_attn_ln_1_b;
405
-
406
- // decoder.blocks.*.cross_attn.query
407
- struct ggml_tensor * cross_attn_q_w;
408
- struct ggml_tensor * cross_attn_q_b;
409
-
410
- // decoder.blocks.*.cross_attn.key
411
- struct ggml_tensor * cross_attn_k_w;
412
-
413
- // decoder.blocks.*.cross_attn.value
414
- struct ggml_tensor * cross_attn_v_w;
415
- struct ggml_tensor * cross_attn_v_b;
416
-
417
- // decoder.blocks.*.mlp_ln
418
- struct ggml_tensor * mlp_ln_w;
419
- struct ggml_tensor * mlp_ln_b;
420
-
421
- // decoder.blocks.*.mlp.0
422
- struct ggml_tensor * mlp_0_w;
423
- struct ggml_tensor * mlp_0_b;
424
-
425
- // decoder.blocks.*.mlp.2
426
- struct ggml_tensor * mlp_1_w;
427
- struct ggml_tensor * mlp_1_b;
428
- };
429
-
430
- struct whisper_model {
431
- e_model type = MODEL_UNKNOWN;
432
-
433
- whisper_hparams hparams;
434
- whisper_filters filters;
435
-
436
- // encoder.positional_embedding
437
- struct ggml_tensor * e_pe;
438
-
439
- // encoder.conv1
440
- struct ggml_tensor * e_conv_1_w;
441
- struct ggml_tensor * e_conv_1_b;
442
-
443
- // encoder.conv2
444
- struct ggml_tensor * e_conv_2_w;
445
- struct ggml_tensor * e_conv_2_b;
446
-
447
- // encoder.ln_post
448
- struct ggml_tensor * e_ln_w;
449
- struct ggml_tensor * e_ln_b;
450
-
451
- // decoder.positional_embedding
452
- struct ggml_tensor * d_pe; // DD
453
-
454
- // decoder.token_embedding
455
- struct ggml_tensor * d_te; // DD
456
-
457
- // decoder.ln
458
- struct ggml_tensor * d_ln_w; // DD
459
- struct ggml_tensor * d_ln_b; // DD
460
-
461
- std::vector<whisper_layer_encoder> layers_encoder;
462
- std::vector<whisper_layer_decoder> layers_decoder;
463
-
464
- // key + value memory
465
- struct ggml_tensor * memory_k;
466
- struct ggml_tensor * memory_v;
467
-
468
- struct ggml_tensor * memory_cross_k;
469
- struct ggml_tensor * memory_cross_v;
470
-
471
- //
472
- struct ggml_context * ctx;
473
- std::map<std::string, struct ggml_tensor *> tensors;
474
- };
475
-
476
- // load the model from a ggml file
477
- //
478
- // file format:
479
- //
480
- // - hparams
481
- // - pre-computed mel filters
482
- // - vocab
483
- // - weights
484
- //
485
- // see the convert-pt-to-ggml.py script for details
486
- //
487
- bool whisper_model_load(const std::string & fname, whisper_model & model, whisper_vocab & vocab) {
488
- printf("%s: loading model from '%s'\n", __func__, fname.c_str());
489
-
490
- auto fin = std::ifstream(fname, std::ios::binary);
491
- if (!fin) {
492
- fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
493
- return false;
494
- }
495
-
496
- // verify magic
497
- {
498
- uint32_t magic;
499
- fin.read((char *) &magic, sizeof(magic));
500
- if (magic != 0x67676d6c) {
501
- fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
502
- return false;
503
- }
504
- }
505
-
506
- //load hparams
507
- {
508
- auto & hparams = model.hparams;
509
-
510
- fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
511
- fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
512
- fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
513
- fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
514
- fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
515
- fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
516
- fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
517
- fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
518
- fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
519
- fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
520
- fin.read((char *) &hparams.f16, sizeof(hparams.f16));
521
-
522
- assert(hparams.n_text_state == hparams.n_audio_state);
523
-
524
- if (hparams.n_audio_layer == 4) {
525
- model.type = e_model::MODEL_TINY;
526
- }
527
-
528
- if (hparams.n_audio_layer == 6) {
529
- model.type = e_model::MODEL_BASE;
530
- }
531
-
532
- if (hparams.n_audio_layer == 12) {
533
- model.type = e_model::MODEL_SMALL;
534
- }
535
-
536
- if (hparams.n_audio_layer == 24) {
537
- model.type = e_model::MODEL_MEDIUM;
538
- }
539
-
540
- if (hparams.n_audio_layer == 32) {
541
- model.type = e_model::MODEL_LARGE;
542
- }
543
-
544
- printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
545
- printf("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
546
- printf("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
547
- printf("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
548
- printf("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
549
- printf("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
550
- printf("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
551
- printf("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
552
- printf("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
553
- printf("%s: n_mels = %d\n", __func__, hparams.n_mels);
554
- printf("%s: f16 = %d\n", __func__, hparams.f16);
555
- printf("%s: type = %d\n", __func__, model.type);
556
-
557
- g_buf_model.resize(MEM_REQ_MODEL.at(model.type));
558
- g_buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
559
- g_buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
560
-
561
- // this is the total memory required to run the inference
562
- const size_t mem_required =
563
- g_buf_model.size() +
564
- g_buf_compute.size() +
565
- g_buf_compute_layer.size();
566
-
567
- printf("%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
568
- }
569
-
570
- // load mel filters
571
- {
572
- auto & filters = model.filters;
573
-
574
- fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
575
- fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
576
-
577
- filters.data.resize(filters.n_mel * filters.n_fft);
578
- fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
579
- }
580
-
581
- // load vocab
582
- {
583
- int32_t n_vocab = 0;
584
- fin.read((char *) &n_vocab, sizeof(n_vocab));
585
-
586
- //if (n_vocab != model.hparams.n_vocab) {
587
- // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
588
- // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
589
- // return false;
590
- //}
591
-
592
- std::string word;
593
- for (int i = 0; i < n_vocab; i++) {
594
- uint32_t len;
595
- fin.read((char *) &len, sizeof(len));
596
-
597
- word.resize(len);
598
- fin.read((char *) word.data(), len);
599
-
600
- vocab.token_to_id[word] = i;
601
- vocab.id_to_token[i] = word;
602
-
603
- //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
604
- }
605
-
606
- vocab.n_vocab = model.hparams.n_vocab;
607
- if (vocab.is_multilingual()) {
608
- vocab.token_eot++;
609
- vocab.token_sot++;
610
- vocab.token_prev++;
611
- vocab.token_solm++;
612
- vocab.token_not++;
613
- vocab.token_beg++;
614
- }
615
-
616
- if (n_vocab < model.hparams.n_vocab) {
617
- printf("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
618
- for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
619
- if (i > vocab.token_beg) {
620
- word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
621
- } else if (i == vocab.token_eot) {
622
- word = "[_EOT_]";
623
- } else if (i == vocab.token_sot) {
624
- word = "[_SOT_]";
625
- } else if (i == vocab.token_prev) {
626
- word = "[_PREV_]";
627
- } else if (i == vocab.token_not) {
628
- word = "[_NOT_]";
629
- } else if (i == vocab.token_beg) {
630
- word = "[_BEG_]";
631
- } else {
632
- word = "[_extra_token_" + std::to_string(i) + "]";
633
- }
634
- vocab.token_to_id[word] = i;
635
- vocab.id_to_token[i] = word;
636
- }
637
- }
638
- }
639
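- // note: the token-id shift in the is_multilingual() branch above reflects the fact that
- // the multilingual tokenizer carries one extra token, so every special-token id moves up
- // by one. a hedged sketch of what the check presumably looks like (the exact vocab sizes,
- // 51864 vs 51865, are an assumption based on the OpenAI Whisper tokenizers, not taken
- // from this file):
- //
- //   bool is_multilingual(int n_vocab) {
- //       return n_vocab == 51865; // English-only models use 51864
- //   }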
-
640
- // for the big tensors, we have the option to store the data in 16-bit floats
641
- // in order to save memory and also to speed up the computation
642
- const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
643
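- // to make the f16 option concrete, with assumed tiny-model sizes (n_vocab = 51864,
- // n_text_state = 384 - illustrative assumptions, not values read from this loader),
- // the token embedding alone is:
- //
- //   f32: 51864 * 384 * 4 bytes ~ 75.97 MB
- //   f16: 51864 * 384 * 2 bytes ~ 37.99 MB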
-
644
- auto & ctx = model.ctx;
645
-
646
- size_t ctx_size = 0;
647
-
648
- {
649
- const auto & hparams = model.hparams;
650
-
651
- const int n_vocab = hparams.n_vocab;
652
-
653
- const int n_audio_ctx = hparams.n_audio_ctx;
654
- const int n_audio_state = hparams.n_audio_state;
655
- const int n_audio_layer = hparams.n_audio_layer;
656
-
657
- const int n_text_ctx = hparams.n_text_ctx;
658
- const int n_text_state = hparams.n_text_state;
659
- const int n_text_layer = hparams.n_text_layer;
660
-
661
- const int n_mels = hparams.n_mels;
662
-
663
- // encoder
664
- {
665
- // TODO: F16 .. maybe not?
666
- ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
667
-
668
- ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
669
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
670
-
671
- ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
672
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
673
-
674
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
675
- ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
676
- }
677
-
678
- // decoder
679
- {
680
- // TODO: F16 .. maybe not?
681
- ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
682
-
683
- ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
684
-
685
- ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
686
- ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
687
- }
688
-
689
- // encoder layers
690
- {
691
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
692
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
693
-
694
- ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
695
- ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
696
-
697
- ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
698
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
699
-
700
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
701
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
702
-
703
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
704
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
705
-
706
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
707
-
708
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
709
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
710
-
711
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
712
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
713
- }
714
-
715
- // decoder layers
716
- {
717
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
718
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
719
-
720
- ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
721
- ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
722
-
723
- ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
724
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
725
-
726
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
727
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
728
-
729
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
730
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
731
-
732
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
733
-
734
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
735
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
736
-
737
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
738
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
739
- //
740
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
741
- ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
742
-
743
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
744
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
745
-
746
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
747
-
748
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
749
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
750
-
751
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
752
- ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
753
- }
754
-
755
- ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
756
- ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
757
-
758
- ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
759
- ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
760
-
761
- ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
762
-
763
- printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
764
- }
765
-
766
- // create the ggml context
767
- {
768
- struct ggml_init_params params = {
769
- .mem_size = g_buf_model.size(),
770
- .mem_buffer = g_buf_model.data(),
771
- };
772
-
773
- model.ctx = ggml_init(params);
774
- if (!model.ctx) {
775
- fprintf(stderr, "%s: ggml_init() failed\n", __func__);
776
- return false;
777
- }
778
- }
779
-
780
- // prepare memory for the weights
781
- {
782
- const auto & hparams = model.hparams;
783
-
784
- const int n_vocab = hparams.n_vocab;
785
-
786
- const int n_audio_ctx = hparams.n_audio_ctx;
787
- const int n_audio_state = hparams.n_audio_state;
788
- const int n_audio_layer = hparams.n_audio_layer;
789
-
790
- const int n_text_ctx = hparams.n_text_ctx;
791
- const int n_text_state = hparams.n_text_state;
792
- const int n_text_layer = hparams.n_text_layer;
793
-
794
- const int n_mels = hparams.n_mels;
795
-
796
- model.layers_encoder.resize(n_audio_layer);
797
- model.layers_decoder.resize(n_text_layer);
798
-
799
- // encoder
800
- {
801
- model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
802
-
803
- model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
804
- model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
805
-
806
- model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
807
- model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
808
-
809
- model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
810
- model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
811
-
812
- // map by name
813
- model.tensors["encoder.positional_embedding"] = model.e_pe;
814
-
815
- model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
816
- model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
817
-
818
- model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
819
- model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
820
-
821
- model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
822
- model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
823
-
824
- for (int i = 0; i < n_audio_layer; ++i) {
825
- auto & layer = model.layers_encoder[i];
826
-
827
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
828
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
829
-
830
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
831
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
832
-
833
- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
834
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
835
-
836
- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
837
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
838
-
839
- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
840
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
841
-
842
- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
843
-
844
- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
845
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
846
-
847
- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
848
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
849
-
850
- // map by name
851
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
852
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
853
-
854
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
855
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
856
-
857
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
858
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
859
-
860
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
861
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
862
-
863
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
864
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
865
-
866
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
867
-
868
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
869
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
870
-
871
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
872
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
873
- }
874
- }
875
-
876
- // decoder
877
- {
878
- model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
879
-
880
- model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
881
-
882
- model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
883
- model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
884
-
885
- // map by name
886
- model.tensors["decoder.positional_embedding"] = model.d_pe;
887
-
888
- model.tensors["decoder.token_embedding.weight"] = model.d_te;
889
-
890
- model.tensors["decoder.ln.weight"] = model.d_ln_w;
891
- model.tensors["decoder.ln.bias"] = model.d_ln_b;
892
-
893
- for (int i = 0; i < n_text_layer; ++i) {
894
- auto & layer = model.layers_decoder[i];
895
-
896
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
897
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
898
-
899
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
900
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
901
-
902
- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
903
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
904
-
905
- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
906
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
907
-
908
- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
909
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
910
-
911
- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
912
-
913
- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
914
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
915
-
916
- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
917
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
918
-
919
- layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
920
- layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
921
-
922
- layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
923
- layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
924
-
925
- layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
926
-
927
- layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
928
- layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
929
-
930
- layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
931
- layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
932
-
933
- // map by name
934
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
935
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
936
-
937
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
938
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
939
-
940
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
941
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
942
-
943
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
944
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
945
-
946
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
947
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
948
-
949
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
950
-
951
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
952
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
953
-
954
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
955
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
956
-
957
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
958
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
959
-
960
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
961
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
962
-
963
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
964
-
965
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
966
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
967
-
968
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
969
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
970
- }
971
- }
972
- }
973
-
974
- // key + value memory
975
- {
976
- const auto & hparams = model.hparams;
977
-
978
- const int n_text_state = hparams.n_text_state;
979
- const int n_text_layer = hparams.n_text_layer;
980
- const int n_text_ctx = hparams.n_text_ctx;
981
-
982
- // key/value memory for the self-attention layer
983
- {
984
- const int n_mem = n_text_layer*n_text_ctx;
985
- const int n_elements = n_text_state*n_mem;
986
-
987
- model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
988
- model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
989
- }
990
-
991
- // key/value memory for the cross-attention layer
992
- {
993
- const int n_audio_ctx = hparams.n_audio_ctx;
994
-
995
- const int n_mem = n_text_layer*n_audio_ctx;
996
- const int n_elements = n_text_state*n_mem;
997
-
998
- model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
999
- model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
1000
- }
1001
-
1002
- const size_t memory_size =
1003
- ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
1004
- ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
1005
-
1006
- printf("%s: memory size = %8.2f MB \n", __func__, memory_size/1024.0/1024.0);
1007
- }
1008
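- // as a sanity check on the sizing above, with assumed base-model hyperparameters
- // (n_text_layer = 6, n_text_state = 512, n_text_ctx = 448, n_audio_ctx = 1500 - an
- // assumption, not read from this file; F16 elements are 2 bytes):
- //
- //   self-attention  KV: 2 * 6 * 448  * 512 * 2 bytes ~  5.25 MB (memory_k + memory_v)
- //   cross-attention KV: 2 * 6 * 1500 * 512 * 2 bytes ~ 17.58 MB (memory_cross_k + memory_cross_v)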
-
1009
- // load weights
1010
- {
1011
- size_t total_size = 0;
1012
-
1013
- while (true) {
1014
- int32_t n_dims;
1015
- int32_t length;
1016
- int32_t ftype;
1017
-
1018
- fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
1019
- fin.read(reinterpret_cast<char *>(&length), sizeof(length));
1020
- fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
1021
-
1022
- if (fin.eof()) {
1023
- break;
1024
- }
1025
-
1026
- int32_t nelements = 1;
1027
- int32_t ne[3] = { 1, 1, 1 };
1028
- for (int i = 0; i < n_dims; ++i) {
1029
- fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
1030
- nelements *= ne[i];
1031
- }
1032
-
1033
- std::string name(length, 0);
1034
- fin.read(&name[0], length);
1035
-
1036
- if (model.tensors.find(name.data()) == model.tensors.end()) {
1037
- fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
1038
- return false;
1039
- }
1040
-
1041
- auto tensor = model.tensors[name.data()];
1042
- if (ggml_nelements(tensor) != nelements) {
1043
- fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
1044
- return false;
1045
- }
1046
-
1047
- if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
1048
- fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
1049
- __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
1050
- return false;
1051
- }
1052
-
1053
- const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
1054
-
1055
- if (nelements*bpe != ggml_nbytes(tensor)) {
1056
- fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
1057
- __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
1058
- return false;
1059
- }
1060
-
1061
- fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
1062
-
1063
- //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
1064
- total_size += ggml_nbytes(tensor);
1065
- }
1066
-
1067
- printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
1068
- }
1069
-
1070
- fin.close();
1071
-
1072
- return true;
1073
- }
1074
-
1075
- // evaluate the encoder
1076
- //
1077
- // given an audio recording (more specifically, its log mel spectrogram), runs the forward pass of the encoder
1078
- // part of the transformer model and returns the encoded features
1079
- //
1080
- // - model: the model
1081
- // - n_threads: number of threads to use
1082
- // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1083
- // - mel_inp: input mel spectrogram
1084
- // - features: output encoded features
1085
- //
1086
- bool whisper_encode(
1087
- const whisper_model & model,
1088
- const int n_threads,
1089
- const int mel_offset,
1090
- const whisper_mel & mel_inp,
1091
- std::vector<float> & features) {
1092
- const auto & hparams = model.hparams;
1093
-
1094
- const int n_vocab = hparams.n_vocab;
1095
-
1096
- const int n_ctx = hparams.n_audio_ctx;
1097
- const int n_state = hparams.n_audio_state;
1098
- const int n_head = hparams.n_audio_head;
1099
- const int n_layer = hparams.n_audio_layer;
1100
-
1101
- const int N = n_ctx;
1102
-
1103
- const int n_mels = hparams.n_mels;
1104
- assert(mel_inp.n_mel == n_mels);
1105
-
1106
- struct ggml_init_params params = {
1107
- .mem_size = g_buf_compute.size(),
1108
- .mem_buffer = g_buf_compute.data(),
1109
- };
1110
-
1111
- struct ggml_context * ctx0 = ggml_init(params);
1112
-
1113
- struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
1114
- assert(mel->type == GGML_TYPE_F32);
1115
- {
1116
- float * dst = (float *) mel->data;
1117
- memset(dst, 0, ggml_nbytes(mel));
1118
-
1119
- const int i0 = std::min(mel_offset, mel_inp.n_len);
1120
- const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
1121
-
1122
- for (int j = 0; j < mel_inp.n_mel; ++j) {
1123
- for (int i = i0; i < i1; ++i) {
1124
- dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
1125
- }
1126
- }
1127
- }
1128
-
1129
- struct ggml_tensor * cur;
1130
-
1131
- // convolution + gelu
1132
- {
1133
- cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
1134
- cur = ggml_add(ctx0,
1135
- ggml_repeat(ctx0,
1136
- model.e_conv_1_b,
1137
- cur),
1138
- cur);
1139
-
1140
- cur = ggml_gelu(ctx0, cur);
1141
-
1142
- cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
1143
- cur = ggml_add(ctx0,
1144
- ggml_repeat(ctx0,
1145
- model.e_conv_2_b,
1146
- cur),
1147
- cur);
1148
-
1149
- cur = ggml_gelu(ctx0, cur);
1150
- }
1151
-
1152
- cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
1153
-
1154
- struct ggml_tensor * inpL = cur;
1155
-
1156
- for (int il = 0; il < n_layer; ++il) {
1157
- const auto & layer = model.layers_encoder[il];
1158
-
1159
- // create separate context for each layer to reduce memory usage
1160
-
1161
- struct ggml_init_params paramsL = {
1162
- .mem_size = g_buf_compute_layer.size(),
1163
- .mem_buffer = g_buf_compute_layer.data(),
1164
- };
1165
-
1166
- struct ggml_context * ctxL = ggml_init(paramsL);
1167
-
1168
- // norm
1169
- {
1170
- cur = ggml_norm(ctxL, inpL);
1171
-
1172
- // cur = ln_0_w*cur + ln_0_b
1173
- cur = ggml_add(ctxL,
1174
- ggml_mul(ctxL,
1175
- ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
1176
- cur),
1177
- ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
1178
- }
1179
-
1180
- // self-attention
1181
- {
1182
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
1183
- layer.attn_q_w,
1184
- cur);
1185
-
1186
- Qcur = ggml_add(ctxL,
1187
- ggml_repeat(ctxL,
1188
- layer.attn_q_b,
1189
- Qcur),
1190
- Qcur);
1191
-
1192
- //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1193
-
1194
- // note: no bias for Key
1195
- struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
1196
- layer.attn_k_w,
1197
- cur);
1198
-
1199
- //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1200
-
1201
- struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
1202
- layer.attn_v_w,
1203
- cur);
1204
-
1205
- Vcur = ggml_add(ctxL,
1206
- ggml_repeat(ctxL,
1207
- layer.attn_v_b,
1208
- Vcur),
1209
- Vcur);
1210
-
1211
- // ------
1212
-
1213
- #ifdef USE_FLASH_ATTN
1214
- struct ggml_tensor * Q =
1215
- ggml_permute(ctxL,
1216
- ggml_cpy(ctxL,
1217
- Qcur,
1218
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1219
- 0, 2, 1, 3);
1220
-
1221
- struct ggml_tensor * K =
1222
- ggml_permute(ctxL,
1223
- ggml_cpy(ctxL,
1224
- Kcur,
1225
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1226
- 0, 2, 1, 3);
1227
-
1228
- struct ggml_tensor * V =
1229
- ggml_cpy(ctxL,
1230
- ggml_permute(ctxL,
1231
- ggml_reshape_3d(ctxL,
1232
- Vcur,
1233
- n_state/n_head, n_head, N),
1234
- 1, 2, 0, 3),
1235
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
1236
- );
1237
-
1238
- struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
1239
- #else
1240
- struct ggml_tensor * Q =
1241
- ggml_permute(ctxL,
1242
- ggml_cpy(ctxL,
1243
- Qcur,
1244
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1245
- 0, 2, 1, 3);
1246
-
1247
- struct ggml_tensor * K =
1248
- ggml_permute(ctxL,
1249
- ggml_cpy(ctxL,
1250
- Kcur,
1251
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1252
- 0, 2, 1, 3);
1253
-
1254
- // K * Q
1255
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1256
-
1257
- struct ggml_tensor * KQ_scaled =
1258
- ggml_scale(ctxL,
1259
- KQ,
1260
- ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1261
- );
1262
-
1263
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);
1264
-
1265
- //struct ggml_tensor * V_trans =
1266
- // ggml_permute(ctxL,
1267
- // ggml_cpy(ctxL,
1268
- // Vcur,
1269
- // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1270
- // 1, 2, 0, 3);
1271
-
1272
- //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1273
-
1274
- struct ggml_tensor * V =
1275
- ggml_cpy(ctxL,
1276
- ggml_permute(ctxL,
1277
- ggml_reshape_3d(ctxL,
1278
- Vcur,
1279
- n_state/n_head, n_head, N),
1280
- 0, 2, 1, 3),
1281
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
1282
- );
1283
-
1284
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
1285
- #endif
1286
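- // note: both branches above compute standard scaled dot-product attention,
- //
- //   Attention(Q, K, V) = softmax(Q*K^T / sqrt(d)) * V,   with d = n_state/n_head
- //
- // the flash-attention path fuses these steps into a single ggml op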
-
1287
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1288
-
1289
- cur = ggml_cpy(ctxL,
1290
- KQV_merged,
1291
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1292
- }
1293
-
1294
- // projection
1295
- {
1296
- cur = ggml_mul_mat(ctxL,
1297
- layer.attn_ln_1_w,
1298
- cur);
1299
-
1300
- cur = ggml_add(ctxL,
1301
- ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
1302
- cur);
1303
- }
1304
-
1305
- // add the input
1306
- cur = ggml_add(ctxL, cur, inpL);
1307
-
1308
- struct ggml_tensor * inpFF = cur;
1309
-
1310
- // feed-forward network
1311
- {
1312
- // norm
1313
- {
1314
- cur = ggml_norm(ctxL, inpFF);
1315
-
1316
- // cur = mlp_ln_w*cur + mlp_ln_b
1317
- cur = ggml_add(ctxL,
1318
- ggml_mul(ctxL,
1319
- ggml_repeat(ctxL, layer.mlp_ln_w, cur),
1320
- cur),
1321
- ggml_repeat(ctxL, layer.mlp_ln_b, cur));
1322
- }
1323
-
1324
- #ifdef USE_FLASH_FF
1325
- cur = ggml_flash_ff(ctxL,
1326
- ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
1327
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1328
- #else
1329
- // fully connected
1330
- cur = ggml_mul_mat(ctxL,
1331
- layer.mlp_0_w,
1332
- cur);
1333
-
1334
- cur = ggml_add(ctxL,
1335
- ggml_repeat(ctxL, layer.mlp_0_b, cur),
1336
- cur);
1337
-
1338
- // GELU activation
1339
- cur = ggml_gelu(ctxL, cur);
1340
-
1341
- // projection
1342
- cur = ggml_mul_mat(ctxL,
1343
- layer.mlp_1_w,
1344
- cur);
1345
-
1346
- cur = ggml_add(ctxL,
1347
- ggml_repeat(ctxL, layer.mlp_1_b, cur),
1348
- cur);
1349
- #endif
1350
- }
1351
-
1352
- // output from this layer
1353
- struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
1354
-
1355
- {
1356
- struct ggml_cgraph gf = { .n_threads = n_threads };
1357
-
1358
- ggml_build_forward_expand(&gf, inpO);
1359
- ggml_graph_compute (ctxL, &gf);
1360
-
1361
- //ggml_graph_print(&gf);
1362
- }
1363
-
1364
- // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
1365
- // input for next layer (inpO -> inpL)
1366
- memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
1367
- inpL->op = GGML_OP_NONE;
1368
- inpL->src0 = NULL;
1369
- inpL->src1 = NULL;
1370
-
1371
- //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
1372
-
1373
- ggml_free(ctxL);
1374
- }
1375
-
1376
- cur = inpL;
1377
-
1378
- // norm
1379
- {
1380
- cur = ggml_norm(ctx0, cur);
1381
-
1382
- // cur = ln_f_g*cur + ln_f_b
1383
- cur = ggml_add(ctx0,
1384
- ggml_mul(ctx0,
1385
- ggml_repeat(ctx0, model.e_ln_w, cur),
1386
- cur),
1387
- ggml_repeat(ctx0, model.e_ln_b, cur));
1388
- }
1389
-
1390
- // run the computation
1391
- {
1392
- struct ggml_cgraph gf = { .n_threads = n_threads };
1393
-
1394
- ggml_build_forward_expand(&gf, cur);
1395
- ggml_graph_compute (ctx0, &gf);
1396
-
1397
- //ggml_graph_print(&gf);
1398
- }
1399
-
1400
- // cur
1401
- //{
1402
- // printf("ne0 = %d\n", cur->ne[0]);
1403
- // printf("ne1 = %d\n", cur->ne[1]);
1404
- // for (int i = 0; i < 10; ++i) {
1405
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1406
- // }
1407
- // printf("... ");
1408
- // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1409
- // printf("%8.4f ", ((float *)(cur->data))[i]);
1410
- // }
1411
- // printf("\n");
1412
- //}
1413
-
1414
- // pre-compute cross-attention memory
1415
- {
1416
- struct ggml_cgraph gf = { .n_threads = n_threads };
1417
-
1418
- // TODO: hack to disconnect the encoded features from the previous graph
1419
- cur->op = GGML_OP_NONE;
1420
- cur->src0 = NULL;
1421
- cur->src1 = NULL;
1422
-
1423
- for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1424
- auto & layer = model.layers_decoder[il];
1425
-
1426
- struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
1427
- layer.cross_attn_k_w,
1428
- cur);
1429
-
1430
- Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1431
-
1432
- struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
1433
- layer.cross_attn_v_w,
1434
- cur);
1435
-
1436
- Vcross = ggml_add(ctx0,
1437
- ggml_repeat(ctx0,
1438
- layer.cross_attn_v_b,
1439
- Vcross),
1440
- Vcross);
1441
-
1442
- struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
1443
- struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
1444
-
1445
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
1446
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
1447
- }
1448
-
1449
- ggml_graph_compute(ctx0, &gf);
1450
- }
1451
-
1452
- ////////////////////////////////////////////////////////////////////////////
1453
-
1454
- // output the features
1455
- assert(cur->type == GGML_TYPE_F32);
1456
- features.resize(cur->ne[0]*cur->ne[1]);
1457
- memcpy(features.data(), cur->data, features.size()*sizeof(float));
1458
-
1459
- //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0);
1460
-
1461
- ggml_free(ctx0);
1462
-
1463
- return true;
1464
- }
1465
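A minimal usage sketch for this entry point, assuming the model was loaded with whisper_model_load() and mel_inp was computed by log_mel_spectrogram() (illustrative only, not part of the original file):

// sketch: encode the first 30-second window of the spectrogram
void example_encode(const whisper_model & model, const whisper_mel & mel_inp) {
    std::vector<float> features;
    if (!whisper_encode(model, 4 /*n_threads*/, 0 /*mel_offset*/, mel_inp, features)) {
        fprintf(stderr, "failed to encode\n");
        return;
    }
    // on success, features holds n_audio_ctx x n_audio_state floats
}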
-
1466
- // evaluate the decoder
1467
- //
1468
- // given text prompt + audio features -> predicts the probabilities for the next token
1469
- //
1470
- // - model: the model
1471
- // - n_threads: number of threads to use
1472
- // - n_past: number of tokens already processed (offset into the KV memory)
1473
- // - prompt: text prompt
1474
- // - logits_out: output logits
1475
- // - probs_out: output probabilities
1476
- //
1477
- bool whisper_decode(
1478
- const whisper_model & model,
1479
- const int n_threads,
1480
- const int n_past,
1481
- const std::vector<whisper_vocab::id> & prompt,
1482
- std::vector<float> & logits_out,
1483
- std::vector<float> & probs_out) {
1484
- const auto & hparams = model.hparams;
1485
-
1486
- const int n_vocab = hparams.n_vocab;
1487
-
1488
- const int n_ctx = hparams.n_text_ctx;
1489
- const int n_state = hparams.n_text_state;
1490
- const int n_head = hparams.n_text_head;
1491
- const int n_layer = hparams.n_text_layer;
1492
-
1493
- const int N = prompt.size();
1494
- const int M = hparams.n_audio_ctx;
1495
-
1496
- struct ggml_init_params params = {
1497
- .mem_size = g_buf_compute.size(),
1498
- .mem_buffer = g_buf_compute.data(),
1499
- };
1500
-
1501
- struct ggml_context * ctx0 = ggml_init(params);
1502
-
1503
- struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1504
- memcpy(embd->data, prompt.data(), N*ggml_element_size(embd));
1505
-
1506
- struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1507
- for (int i = 0; i < N; ++i) {
1508
- ((int32_t *) position->data)[i] = n_past + i;
1509
- }
1510
-
1511
- // token encoding + position encoding
1512
- struct ggml_tensor * cur =
1513
- ggml_add(ctx0,
1514
- ggml_get_rows(ctx0, model.d_te, embd),
1515
- ggml_get_rows(ctx0, model.d_pe, position));
1516
-
1517
- struct ggml_tensor * inpL = cur;
1518
-
1519
- for (int il = 0; il < n_layer; ++il) {
1520
- const auto & layer = model.layers_decoder[il];
1521
-
1522
- struct ggml_init_params paramsL = {
1523
- .mem_size = g_buf_compute_layer.size(),
1524
- .mem_buffer = g_buf_compute_layer.data(),
1525
- };
1526
-
1527
- struct ggml_context * ctxL = ggml_init(paramsL);
1528
- struct ggml_cgraph gf = { .n_threads = n_threads };
1529
-
1530
- // norm
1531
- {
1532
- cur = ggml_norm(ctxL, inpL);
1533
-
1534
- // cur = ln_0_w*cur + ln_0_b
1535
- cur = ggml_add(ctxL,
1536
- ggml_mul(ctxL,
1537
- ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
1538
- cur),
1539
- ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
1540
- }
1541
-
1542
- // self-attention
1543
- {
1544
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
1545
- layer.attn_q_w,
1546
- cur);
1547
-
1548
- Qcur = ggml_add(ctxL,
1549
- ggml_repeat(ctxL,
1550
- layer.attn_q_b,
1551
- Qcur),
1552
- Qcur);
1553
-
1554
- Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1555
-
1556
- // note: no bias for Key
1557
- struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
1558
- layer.attn_k_w,
1559
- cur);
1560
-
1561
- Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1562
-
1563
- struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
1564
- layer.attn_v_w,
1565
- cur);
1566
-
1567
- Vcur = ggml_add(ctxL,
1568
- ggml_repeat(ctxL,
1569
- layer.attn_v_b,
1570
- Vcur),
1571
- Vcur);
1572
-
1573
- // store key and value to memory
1574
- {
1575
- struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
1576
- struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
1577
-
1578
- ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
1579
- ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
1580
- }
1581
-
1582
- // ------
1583
-
1584
- struct ggml_tensor * Q =
1585
- ggml_permute(ctxL,
1586
- ggml_cpy(ctxL,
1587
- Qcur,
1588
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1589
- 0, 2, 1, 3);
1590
-
1591
- struct ggml_tensor * K =
1592
- ggml_permute(ctxL,
1593
- ggml_reshape_3d(ctxL,
1594
- ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
1595
- n_state/n_head, n_head, n_past + N),
1596
- 0, 2, 1, 3);
1597
-
1598
- // K * Q
1599
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1600
-
1601
- //struct ggml_tensor * KQ_scaled =
1602
- // ggml_scale(ctxL,
1603
- // KQ,
1604
- // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1605
- // );
1606
-
1607
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
1608
-
1609
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
1610
-
1611
- struct ggml_tensor * V_trans =
1612
- ggml_permute(ctxL,
1613
- ggml_reshape_3d(ctxL,
1614
- ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
1615
- n_state/n_head, n_head, n_past + N),
1616
- 1, 2, 0, 3);
1617
-
1618
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1619
-
1620
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1621
-
1622
- cur = ggml_cpy(ctxL,
1623
- KQV_merged,
1624
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1625
- }
1626
-
1627
- {
1628
- cur = ggml_mul_mat(ctxL,
1629
- layer.attn_ln_1_w,
1630
- cur);
1631
-
1632
- cur = ggml_add(ctxL,
1633
- ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
1634
- cur);
1635
- }
1636
-
1637
- // add the input
1638
- struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL);
1639
-
1640
- // norm
1641
- {
1642
- cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here
1643
-
1644
- // cur = ln_0_w*cur + ln_0_b
1645
- cur = ggml_add(ctxL,
1646
- ggml_mul(ctxL,
1647
- ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur),
1648
- cur),
1649
- ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur));
1650
- }
1651
-
1652
- // cross-attention
1653
- {
1654
- struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
1655
- layer.cross_attn_q_w,
1656
- cur);
1657
-
1658
- Qcur = ggml_add(ctxL,
1659
- ggml_repeat(ctxL,
1660
- layer.cross_attn_q_b,
1661
- Qcur),
1662
- Qcur);
1663
-
1664
- Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1665
-
1666
- // Kcross is already scaled
1667
- struct ggml_tensor * Kcross =
1668
- ggml_reshape_3d(ctxL,
1669
- ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
1670
- n_state/n_head, n_head, M);
1671
-
1672
- struct ggml_tensor * Vcross =
1673
- ggml_reshape_3d(ctxL,
1674
- ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
1675
- n_state/n_head, n_head, M);
1676
-
1677
- // ------
1678
-
1679
- struct ggml_tensor * Q =
1680
- ggml_permute(ctxL,
1681
- ggml_cpy(ctxL,
1682
- Qcur,
1683
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1684
- 0, 2, 1, 3);
1685
-
1686
- struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
1687
-
1688
- // K * Q
1689
- struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1690
-
1691
- //struct ggml_tensor * KQ_scaled =
1692
- // ggml_scale(ctxL,
1693
- // KQ,
1694
- // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1695
- // );
1696
-
1697
- // no masking for cross-attention
1698
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
1699
-
1700
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
1701
-
1702
- struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
1703
-
1704
- struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1705
-
1706
- struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1707
-
1708
- // cur = KQV_merged.contiguous().view(n_state, N)
1709
- cur = ggml_cpy(ctxL,
1710
- KQV_merged,
1711
- ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1712
- }
1713
-
1714
- // projection
1715
- {
1716
- cur = ggml_mul_mat(ctxL,
1717
- layer.cross_attn_ln_1_w,
1718
- cur);
1719
-
1720
- cur = ggml_add(ctxL,
1721
- ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur),
1722
- cur);
1723
- }
1724
-
1725
- // add the input
1726
- cur = ggml_add(ctxL, cur, inpCA);
1727
-
1728
- struct ggml_tensor * inpFF = cur;
1729
-
1730
- // feed-forward network
1731
- {
1732
- // norm
1733
- {
1734
- cur = ggml_norm(ctxL, inpFF);
1735
-
1736
- // cur = mlp_ln_w*cur + mlp_ln_b
1737
- cur = ggml_add(ctxL,
1738
- ggml_mul(ctxL,
1739
- ggml_repeat(ctxL, layer.mlp_ln_w, cur),
1740
- cur),
1741
- ggml_repeat(ctxL, layer.mlp_ln_b, cur));
1742
- }
1743
-
1744
- // fully connected
1745
- cur = ggml_mul_mat(ctxL,
1746
- layer.mlp_0_w,
1747
- cur);
1748
-
1749
- cur = ggml_add(ctxL,
1750
- ggml_repeat(ctxL, layer.mlp_0_b, cur),
1751
- cur);
1752
-
1753
- // GELU activation
1754
- cur = ggml_gelu(ctxL, cur);
1755
-
1756
- // projection
1757
- cur = ggml_mul_mat(ctxL,
1758
- layer.mlp_1_w,
1759
- cur);
1760
-
1761
- cur = ggml_add(ctxL,
1762
- ggml_repeat(ctxL, layer.mlp_1_b, cur),
1763
- cur);
1764
- }
1765
-
1766
- // output from this layer
1767
- struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
1768
-
1769
- {
1770
- ggml_build_forward_expand(&gf, inpO);
1771
- ggml_graph_compute (ctxL, &gf);
1772
-
1773
- //ggml_graph_print(&gf);
1774
- }
1775
-
1776
- // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
1777
- // input for next layer (inpO -> inpL)
1778
- memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
1779
- inpL->op = GGML_OP_NONE;
1780
- inpL->src0 = NULL;
1781
- inpL->src1 = NULL;
1782
-
1783
- if (N > 1) {
1784
- //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
1785
- }
1786
-
1787
- ggml_free(ctxL);
1788
- }
1789
-
1790
- cur = inpL;
1791
-
1792
- // norm
1793
- {
1794
- cur = ggml_norm(ctx0, cur);
1795
-
1796
- cur = ggml_add(ctx0,
1797
- ggml_mul(ctx0,
1798
- ggml_repeat(ctx0, model.d_ln_w, cur),
1799
- cur),
1800
- ggml_repeat(ctx0, model.d_ln_b, cur));
1801
- }
1802
-
1803
- struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
1804
-
1805
- // logits -> probs
1806
- cur = ggml_dup(ctx0, logits);
1807
- cur = ggml_soft_max(ctx0, cur); // in-place
1808
-
1809
- // run the computation
1810
- {
1811
- struct ggml_cgraph gf = { .n_threads = n_threads };
1812
-
1813
- ggml_build_forward_expand(&gf, cur);
1814
- ggml_graph_compute (ctx0, &gf);
1815
- }
1816
-
1817
- logits_out.resize(N*n_vocab);
1818
- memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
1819
-
1820
- probs_out.resize(N*n_vocab);
1821
- memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
1822
-
1823
- if (N > 1) {
1824
- //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
1825
- //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
1826
- //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
1827
- }
1828
-
1829
- ggml_free(ctx0);
1830
-
1831
- return true;
1832
- }
1833
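A minimal greedy-decoding sketch built on this function, assuming whisper_encode() has already populated the cross-attention memory; whisper_sample_best() is the sampler defined just below (illustrative only, not part of the original file):

// sketch: greedy decoding until end-of-text
void example_greedy_decode(const whisper_model & model, const whisper_vocab & vocab) {
    std::vector<whisper_vocab::id> prompt = { vocab.token_sot };
    std::vector<float> logits, probs;

    int n_past = 0;
    while (true) {
        if (!whisper_decode(model, 4 /*n_threads*/, n_past, prompt, logits, probs)) {
            break; // decode failed
        }
        n_past += prompt.size();

        const int n_vocab = model.hparams.n_vocab;
        const whisper_vocab::id id =
            whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), false);

        if (id == vocab.token_eot) {
            break; // end of text
        }

        prompt = { id }; // only the new token; the KV memory holds the rest
    }
}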
-
1834
- // the most basic sampling scheme - select the top token
1835
- // TODO: beam search
1836
- // TODO: temperature
1837
- whisper_vocab::id whisper_sample_best(
1838
- const whisper_vocab & vocab,
1839
- const float * probs, bool need_timestamp) {
1840
- int n_logits = vocab.id_to_token.size();
1841
-
1842
- std::vector<std::pair<double, whisper_vocab::id>> probs_id;
1843
- probs_id.reserve(n_logits);
1844
-
1845
- for (int i = 0; i < n_logits; i++) {
1846
- probs_id.push_back(std::make_pair(probs[i], i));
1847
- }
1848
-
1849
- const int top_k = 4;
1850
-
1851
- // find the top K tokens
1852
- std::partial_sort(
1853
- probs_id.begin(),
1854
- probs_id.begin() + top_k, probs_id.end(),
1855
- [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
1856
- return a.first > b.first;
1857
- });
1858
-
1859
- probs_id.resize(top_k);
1860
-
1861
- //printf("\n");
1862
- //for (int i = 0; i < (int) probs_id.size(); i++) {
1863
- // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
1864
- //}
1865
-
1866
- if (need_timestamp) {
1867
- // at the end of the 30-second audio segment, we start giving preference to time tokens
1868
- for (int i = 0; i < top_k; i++) {
1869
- if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
1870
- return probs_id[i].second;
1871
- }
1872
- }
1873
- }
1874
-
1875
- int res = 0;
1876
- while ((probs_id[res].second == vocab.token_sot ||
1877
- probs_id[res].second == vocab.token_solm ||
1878
- probs_id[res].second == vocab.token_not) &&
1879
- res < (int) probs_id.size() - 1) {
1880
- res++;
1881
- }
1882
-
1883
- return probs_id[res].second;
1884
- }
1885
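One way the temperature TODO above could be addressed, sketched over the same probabilities this sampler receives (illustrative only, not part of the original file; requires <random>):

// sketch: sample a token id from probs after temperature re-scaling;
// temp -> 0 approaches argmax, temp = 1 samples from probs unchanged
int whisper_sample_temperature(const std::vector<float> & probs, float temp, std::mt19937 & rng) {
    // raising p to the power 1/T is equivalent to dividing the logits by T
    std::vector<double> w(probs.size());
    for (size_t i = 0; i < probs.size(); ++i) {
        w[i] = std::pow((double) probs[i], 1.0/temp);
    }
    std::discrete_distribution<int> dist(w.begin(), w.end());
    return dist(rng);
}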
-
1886
- // samples only from the timestamps tokens
1887
- whisper_vocab::id whisper_sample_timestamp(
1888
- const whisper_vocab & vocab,
1889
- const float * probs) {
1890
- int n_logits = vocab.id_to_token.size();
1891
-
1892
- std::vector<std::pair<double, whisper_vocab::id>> probs_id;
1893
- probs_id.reserve(n_logits);
1894
-
1895
- for (int i = vocab.token_beg + 1; i < n_logits; i++) {
1896
- probs_id.push_back(std::make_pair(probs[i], i));
1897
- }
1898
-
1899
- const int top_k = 10;
1900
-
1901
- // find the top K tokens
1902
- std::partial_sort(
1903
- probs_id.begin(),
1904
- probs_id.begin() + top_k, probs_id.end(),
1905
- [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
1906
- return a.first > b.first;
1907
- });
1908
-
1909
- probs_id.resize(top_k);
1910
-
1911
- //printf("\n");
1912
- //for (int i = 0; i < (int) probs_id.size(); i++) {
1913
- // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
1914
- //}
1915
-
1916
- return probs_id[0].second;
1917
- }
1918
-
1919
- // naive Discrete Fourier Transform
1920
- // input is real-valued
1921
- // output is complex-valued
1922
- void dft(const std::vector<float> & in, std::vector<float> & out) {
1923
- int N = in.size();
1924
-
1925
- out.resize(N*2);
1926
 
1927
- for (int k = 0; k < N; k++) {
1928
- float re = 0;
1929
- float im = 0;
1930
 
1931
- for (int n = 0; n < N; n++) {
1932
- float angle = 2*M_PI*k*n/N;
1933
- re += in[n]*cos(angle);
1934
- im -= in[n]*sin(angle);
1935
- }
1936
 
1937
- out[k*2 + 0] = re;
1938
- out[k*2 + 1] = im;
1939
- }
1940
  }
1941
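- // for reference, the loop above computes the textbook DFT of the real input x:
- //
- //   X[k] = sum_{n=0}^{N-1} x[n] * exp(-2*pi*i*k*n/N)
- //        = sum_n x[n] * (cos(2*pi*k*n/N) - i*sin(2*pi*k*n/N)),   k = 0..N-1
- //
- // which is exactly the re/im accumulation in the inner loop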
 
1942
- // Cooley-Tukey FFT
1943
- // poor man's implementation - use something better
1944
- // input is real-valued
1945
- // output is complex-valued
1946
- void fft(const std::vector<float> & in, std::vector<float> & out) {
1947
- out.resize(in.size()*2);
1948
-
1949
- int N = in.size();
1950
-
1951
- if (N == 1) {
1952
- out[0] = in[0];
1953
- out[1] = 0;
1954
- return;
1955
- }
1956
-
1957
- if (N%2 == 1) {
1958
- dft(in, out);
1959
- return;
1960
- }
1961
-
1962
- std::vector<float> even;
1963
- std::vector<float> odd;
1964
-
1965
- for (int i = 0; i < N; i++) {
1966
- if (i % 2 == 0) {
1967
- even.push_back(in[i]);
1968
- } else {
1969
- odd.push_back(in[i]);
1970
- }
1971
- }
1972
-
1973
- std::vector<float> even_fft;
1974
- std::vector<float> odd_fft;
1975
-
1976
- fft(even, even_fft);
1977
- fft(odd, odd_fft);
1978
-
1979
- for (int k = 0; k < N/2; k++) {
1980
- float theta = 2*M_PI*k/N;
1981
-
1982
- float re = cos(theta);
1983
- float im = -sin(theta);
1984
-
1985
- float re_odd = odd_fft[2*k + 0];
1986
- float im_odd = odd_fft[2*k + 1];
1987
 
1988
- out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
1989
- out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
1990
 
1991
- out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
1992
- out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
1993
- }
1994
  }
1995
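- // the recombination loop above is the standard radix-2 butterfly: with E and O the
- // FFTs of the even- and odd-indexed samples,
- //
- //   X[k]       = E[k] + exp(-2*pi*i*k/N) * O[k]
- //   X[k + N/2] = E[k] - exp(-2*pi*i*k/N) * O[k],   k = 0..N/2-1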
 
1996
- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
1997
- bool log_mel_spectrogram(
1998
- const std::vector<float> sf32,
1999
- const int sample_rate,
2000
- const int fft_size,
2001
- const int fft_step,
2002
- const int n_mel,
2003
- const int n_threads,
2004
- const whisper_filters & filters,
2005
- whisper_mel & mel) {
2006
- const int n_sample = sf32.size();
2007
- const float * samples = sf32.data();
2008
-
2009
- // Hanning window
2010
- std::vector<float> hann;
2011
- hann.resize(fft_size);
2012
- for (int i = 0; i < fft_size; i++) {
2013
- hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
2014
- }
2015
-
2016
- mel.n_mel = n_mel;
2017
- mel.n_len = (n_sample)/fft_step;
2018
- mel.data.resize(mel.n_mel*mel.n_len);
2019
-
2020
- const int n_fft = 1 + fft_size/2;
2021
-
2022
- //printf("%s: n_sample = %d, n_len = %d\n", __func__, n_sample, mel.n_len);
2023
- //printf("%s: recording length: %f s\n", __func__, (float) n_sample/sample_rate);
2024
-
2025
- std::vector<std::thread> workers(n_threads);
2026
- for (int iw = 0; iw < n_threads; ++iw) {
2027
- workers[iw] = std::thread([&](int ith) {
2028
- std::vector<float> fft_in;
2029
- fft_in.resize(fft_size);
2030
- for (int i = 0; i < fft_size; i++) {
2031
- fft_in[i] = 0.0;
2032
- }
2033
-
2034
- std::vector<float> fft_out;
2035
- fft_out.resize(2*fft_size);
2036
-
2037
- for (int i = ith; i < mel.n_len; i += n_threads) {
2038
- const int offset = i*fft_step;
2039
-
2040
- // apply Hanning window
2041
- for (int j = 0; j < fft_size; j++) {
2042
- if (offset + j < n_sample) {
2043
- fft_in[j] = hann[j]*samples[offset + j];
2044
- } else {
2045
- fft_in[j] = 0.0;
2046
- }
2047
- }
2048
 
2049
- // FFT -> mag^2
2050
- fft(fft_in, fft_out);
2051
 
2052
- for (int j = 0; j < fft_size; j++) {
2053
- fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
2054
- }
2055
- for (int j = 1; j < fft_size/2; j++) {
2056
- fft_out[j] += fft_out[fft_size - j];
2057
- }
2058
 
2059
- // mel spectrogram
2060
- for (int j = 0; j < mel.n_mel; j++) {
2061
- double sum = 0.0;
2062
 
2063
- for (int k = 0; k < n_fft; k++) {
2064
- sum += fft_out[k]*filters.data[j*n_fft + k];
2065
- }
2066
- if (sum < 1e-10) {
2067
- sum = 1e-10;
2068
- }
2069
 
2070
- sum = log10(sum);
2071
 
2072
- mel.data[j*mel.n_len + i] = sum;
2073
- }
2074
  }
2075
- }, iw);
2076
- }
2077
-
2078
- for (int iw = 0; iw < n_threads; ++iw) {
2079
- workers[iw].join();
2080
- }
2081
-
2082
- // clamping and normalization
2083
- double mmax = -1e20;
2084
- for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
2085
- if (mel.data[i] > mmax) {
2086
- mmax = mel.data[i];
2087
- }
2088
- }
2089
-
2090
- mmax -= 8.0;
2091
-
2092
- for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
2093
- if (mel.data[i] < mmax) {
2094
- mel.data[i] = mmax;
2095
  }
2096
-
2097
- mel.data[i] = (mel.data[i] + 4.0)/4.0;
2098
  }
2099
 
2100
  return true;
2101
  }
2102
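For the constants this function receives from main() below (SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MEL), the standard Whisper configuration is presumably 16000 Hz, 400-sample windows hopped every 160 samples, and 80 mel bins - i.e. 25 ms windows every 10 ms, so one spectrogram frame covers 10 ms of audio. A call sketch under those assumed values (illustrative only, not part of the original file):

// sketch: assumed constants matching the call in main()
bool example_mel(const std::vector<float> & pcmf32, const whisper_filters & filters, whisper_mel & mel) {
    // 16000 Hz, fft_size 400, fft_step 160, 80 mel bins, 4 threads - assumed values
    return log_mel_spectrogram(pcmf32, 16000, 400, 160, 80, 4, filters, mel);
}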
 
2103
- // 500 -> 00:05.000
2104
- // 6000 -> 01:00.000
2105
- std::string to_timestamp(int64_t t) {
2106
- int64_t sec = t/100;
2107
- int64_t msec = (t - sec*100)*10; // t is in units of 10 ms
2108
- int64_t min = sec/60;
2109
- sec = sec - min*60;
2110
-
2111
- char buf[32];
2112
- snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
2113
-
2114
- return std::string(buf);
2115
  }
2116
 
2117
  //
@@ -2183,7 +186,7 @@ bool audio_sdl_init(const int capture_id) {
2183
  ///////////////////////////
2184
 
2185
  int main(int argc, char ** argv) {
2186
- const int64_t t_main_start_us = ggml_time_us();
2187
 
2188
  whisper_params params;
2189
 
@@ -2202,31 +205,9 @@ int main(int argc, char ** argv) {
2202
  return 1;
2203
  }
2204
 
2205
- // model loading
2206
 
2207
- //printf("%s: seed = %d\n", __func__, params.seed);
2208
-
2209
- int64_t t_load_us = 0;
2210
- int64_t t_mel_us = 0;
2211
- int64_t t_sample_us = 0;
2212
- int64_t t_encode_us = 0;
2213
- int64_t t_decode_us = 0;
2214
-
2215
- whisper_vocab vocab;
2216
- whisper_model model;
2217
-
2218
- // load the model
2219
- {
2220
- const int64_t t_start_us = ggml_time_us();
2221
-
2222
- if (!whisper_model_load(params.model, model, vocab)) {
2223
- fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
2224
- whisper_print_usage(argc, argv, {});
2225
- return 1;
2226
- }
2227
-
2228
- t_load_us = ggml_time_us() - t_start_us;
2229
- }
2230
 
2231
  const int n_samples_30s = 30*SAMPLE_RATE;
2232
  std::vector<float> pcmf32(n_samples_30s, 0.0f);
@@ -2235,7 +216,7 @@ int main(int argc, char ** argv) {
2235
  // print some info about the processing
2236
  {
2237
  printf("\n");
2238
- if (!vocab.is_multilingual()) {
2239
  if (params.language != "en" || params.translate) {
2240
  params.language = "en";
2241
  params.translate = false;
@@ -2244,7 +225,7 @@ int main(int argc, char ** argv) {
2244
  }
2245
  printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
2246
  __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
2247
- g_lang.at(params.language).second.c_str(),
2248
  params.translate ? "translate" : "transcribe",
2249
  params.no_timestamps ? 0 : 1);
2250
  printf("\n");
@@ -2291,26 +272,22 @@ int main(int argc, char ** argv) {
2291
  pcmf32_old = pcmf32;
2292
 
2293
  // compute log mel spectrogram
2294
- whisper_mel mel_inp;
2295
- {
2296
- const int64_t t_start_us = ggml_time_us();
2297
-
2298
- log_mel_spectrogram(pcmf32, SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MEL, params.n_threads, model.filters, mel_inp);
2299
-
2300
- t_mel_us = ggml_time_us() - t_start_us;
2301
  }
2302
 
2303
  // the accumulated text context so far
2304
- std::vector<whisper_vocab::id> prompt_past = { };
2305
 
2306
  // these tokens determine the task that will be performed
2307
- std::vector<whisper_vocab::id> prompt_init = { vocab.token_sot };
2308
- if (vocab.is_multilingual()) {
2309
- prompt_init.push_back(vocab.token_sot + 1 + g_lang.at(params.language).first);
2310
  if (params.translate) {
2311
- prompt_init.push_back(vocab.token_translate);
2312
  } else {
2313
- prompt_init.push_back(vocab.token_transcribe);
2314
  }
2315
  }
2316
 
@@ -2320,35 +297,25 @@ int main(int argc, char ** argv) {
2320
  // main loop
2321
  int seek = 0;
2322
  while (true) {
2323
- if (seek >= mel_inp.n_len) {
2324
  break;
2325
  }
2326
 
2327
  // encode audio features starting at offset seek
2328
- std::vector<float> features;
2329
- {
2330
- const int64_t t_start_us = ggml_time_us();
2331
-
2332
- if (!whisper_encode(model, params.n_threads, seek, mel_inp, features)) {
2333
- fprintf(stderr, "%s: failed to eval\n", __func__);
2334
- return 1;
2335
- }
2336
-
2337
- t_encode_us += ggml_time_us() - t_start_us;
2338
  }
2339
 
2340
- std::vector<float> probs;
2341
- std::vector<float> logits;
2342
-
2343
- std::vector<whisper_vocab::id> prompt;
2344
 
2345
  int n_past = 0;
2346
 
2347
  // if we have already generated some text, use it as a prompt to condition the next generation
2348
  if (prompt_past.size() > 0) {
2349
- int n_take = std::min(model.hparams.n_text_ctx/2, int(prompt_past.size()));
2350
 
2351
- prompt = { vocab.token_prev };
2352
  prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
2353
 
2354
  prompt_past.clear();
@@ -2359,7 +326,7 @@ int main(int argc, char ** argv) {
2359
 
2360
  bool done = false;
2361
  int seek_delta = 100*CHUNK_SIZE;
2362
- whisper_vocab::id last_id = 0;
2363
 
2364
  // print the prompt
2365
  //printf("\n\n");
@@ -2372,17 +339,10 @@ int main(int argc, char ** argv) {
2372
  int result_len = 0;
2373
  std::vector<whisper_result> result_cur;
2374
 
2375
- for (int i = 0; i < model.hparams.n_text_ctx/2 - 4; ++i) {
2376
- // decode
2377
- if (prompt.size() > 0) {
2378
- const int64_t t_start_us = ggml_time_us();
2379
-
2380
- if (!whisper_decode(model, params.n_threads, n_past, prompt, logits, probs)) {
2381
- fprintf(stderr, "%s: failed to eval\n", __func__);
2382
- return 1;
2383
- }
2384
-
2385
- t_decode_us += ggml_time_us() - t_start_us;
2386
  }
2387
 
2388
  n_past += prompt.size();
@@ -2396,37 +356,31 @@ int main(int argc, char ** argv) {
2396
  // feel free to experiment!
2397
  //
2398
  {
2399
- const int n_vocab = model.hparams.n_vocab;
2400
-
2401
- whisper_vocab::id id = 0;
2402
- whisper_vocab::id tid = vocab.token_beg;
2403
-
2404
- {
2405
- const int64_t t_start_sample_us = ggml_time_us();
2406
 
2407
- id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), result_len == 0);
2408
- if (i > 0) {
2409
- tid = whisper_sample_timestamp(vocab, probs.data() + (probs.size() - n_vocab));
2410
- }
2411
 
2412
- t_sample_us += ggml_time_us() - t_start_sample_us;
2413
  }
2414
 
2415
  // update sliding window
2416
- if (id > vocab.token_beg) {
2417
- seek_delta = 2*(id - vocab.token_beg);
2418
  result_len = i + 1;
2419
  }
2420
  last_id = id;
2421
 
2422
  // add it to the context
2423
  prompt.push_back(id);
2424
- result_cur.push_back({ id, seek + 2*(tid - vocab.token_beg) });
2425
 
2426
  //printf("%s: %s\n", __func__, vocab.id_to_token[id].c_str());
2427
 
2428
  // end of text token
2429
- if (id == vocab.token_eot) {
2430
  break;
2431
  }
2432
  }
@@ -2449,11 +403,11 @@ int main(int argc, char ** argv) {
2449
 
2450
  std::string text = "";
2451
  for (int i = 0; i < result_cur.size(); i++) {
2452
- if (params.print_special_tokens == false && result_cur[i].id >= vocab.token_eot) {
2453
  } else {
2454
- text += vocab.id_to_token[result_cur[i].id];
2455
  }
2456
- if (result_cur[i].id > vocab.token_beg) {
2457
  const auto t1 = result_cur[i].t;
2458
  if (!text.empty()) {
2459
  if (params.no_timestamps) {
@@ -2464,7 +418,7 @@ int main(int argc, char ** argv) {
2464
  }
2465
  }
2466
  text = "";
2467
- while (result_cur[i].id > vocab.token_beg && i < result_cur.size()) {
2468
  i++;
2469
  }
2470
  i--;
@@ -2481,45 +435,8 @@ int main(int argc, char ** argv) {
2481
  }
2482
  }
2483
 
2484
- // WIP: attempt for per-token timestamps
2485
- //if (!params.no_timestamps && result_all.size() > 0) {
2486
- // const int64_t dt = 500; // 5 second intervals
2487
-
2488
- // int i0 = 0;
2489
-
2490
- // int64_t t0 = result_all[0].t;
2491
- // int64_t t1 = t0;
2492
-
2493
- // printf("\n\n");
2494
- // for (int i = 0; i < result_all.size(); ++i) {
2495
- // printf("'%s' -> %lld\n", vocab.id_to_token[result_all[i].id].c_str(), result_all[i].t);
2496
- // if (result_all[i].t - t0 > dt) {
2497
- // t1 = result_all[i - 1].t;
2498
- // printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
2499
- // for (int j = i0; j < i; ++j) {
2500
- // printf("%s", vocab.id_to_token.at(result_all[j].id).c_str());
2501
- // }
2502
- // printf("\n");
2503
- // i0 = i;
2504
- // t0 = result_all[i].t;
2505
- // }
2506
- // }
2507
- //}
2508
-
2509
- // report timing
2510
- {
2511
- const int64_t t_main_end_us = ggml_time_us();
2512
-
2513
- printf("\n\n");
2514
- printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
2515
- printf("%s: mel time = %8.2f ms\n", __func__, t_mel_us/1000.0f);
2516
- printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
2517
- printf("%s: encode time = %8.2f ms / %.2f ms per layer\n", __func__, t_encode_us/1000.0f, t_encode_us/1000.0f/model.hparams.n_audio_layer);
2518
- printf("%s: decode time = %8.2f ms\n", __func__, t_decode_us/1000.0f);
2519
- printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
2520
- }
2521
-
2522
- ggml_free(model.ctx);
2523
 
2524
  return 0;
2525
  }
 
2
  //
3
  // A very quick-n-dirty implementation serving mainly as a proof of concept.
4
 
5
+ #include "whisper.h"
6
 
7
+ // third-party utilities
8
+ // use your favorite implementations
9
+ #define DR_WAV_IMPLEMENTATION
10
+ #include "dr_wav.h"
11
 
12
+ #include <SDL.h>
13
+ #include <SDL_audio.h>
14
 
15
+ #include <cassert>
16
+ #include <cstdio>
17
+ #include <string>
18
+ #include <thread>
19
+ #include <vector>
20
 
21
+ int64_t get_time_us() {
22
+ return std::chrono::duration_cast<std::chrono::microseconds>(
23
+ std::chrono::high_resolution_clock::now().time_since_epoch()).count();
24
  }
25
 
26
+ // 500 -> 00:05.000
27
+ // 6000 -> 01:00.000
28
+ std::string to_timestamp(int64_t t) {
29
+ int64_t sec = t/100;
30
+ int64_t msec = t - sec*100;
31
+ int64_t min = sec/60;
32
+ sec = sec - min*60;
33
 
34
+ char buf[32];
35
+ snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int) min, (int) sec, (int) msec);
36
 
37
+ return std::string(buf);
38
  }
39
 
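A note on units, inferred from this function and the sampling loop further down: t counts 10 ms steps (t/100 yields whole seconds, matching the 10 ms hop of the log mel spectrogram), while Whisper's timestamp tokens advance in 20 ms steps, hence the 2*(tid - whisper_token_beg(ctx)) conversion when results are recorded. For example:

    // to_timestamp(500)  -> "00:05.000"   (500  * 10 ms = 5 s)
    // to_timestamp(6000) -> "01:00.000"   (6000 * 10 ms = 60 s)

For values that are not whole seconds, the last field prints hundredths under a %03d format, so it reads as milliseconds only when the fractional part is zero.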
40
+ struct whisper_result {
41
+ whisper_token id;
42
+ int64_t t;
43
+ };
44
 
45
+ // command-line parameters
46
+ struct whisper_params {
47
+ int32_t seed = -1; // RNG seed, not used currently
48
+ int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
49
 
50
+ bool verbose = false;
51
+ bool translate = false;
52
+ bool print_special_tokens = false;
53
+ bool no_timestamps = true;
54
 
55
+ std::string language = "en";
56
+ std::string model = "models/ggml-base.en.bin";
57
+ std::string fname_inp = "samples/jfk.wav";
58
+ };
59
 
60
+ void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
61
 
62
+ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
63
+ for (int i = 1; i < argc; i++) {
64
+ std::string arg = argv[i];
65
 
66
+ if (arg == "-s" || arg == "--seed") {
67
+ params.seed = std::stoi(argv[++i]);
68
+ } else if (arg == "-t" || arg == "--threads") {
69
+ params.n_threads = std::stoi(argv[++i]);
70
+ } else if (arg == "-v" || arg == "--verbose") {
71
+ params.verbose = true;
72
+ } else if (arg == "--translate") {
73
+ params.translate = true;
74
+ } else if (arg == "-l" || arg == "--language") {
75
+ params.language = argv[++i];
76
+ if (whisper_lang_id(params.language.c_str()) == -1) {
77
+ fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
78
+ whisper_print_usage(argc, argv, params);
79
+ exit(0);
80
  }
81
+ } else if (arg == "-ps" || arg == "--print_special") {
82
+ params.print_special_tokens = true;
83
+ } else if (arg == "-nt" || arg == "--no_timestamps") {
84
+ params.no_timestamps = true;
85
+ } else if (arg == "-m" || arg == "--model") {
86
+ params.model = argv[++i];
87
+ } else if (arg == "-f" || arg == "--file") {
88
+ params.fname_inp = argv[++i];
89
+ } else if (arg == "-h" || arg == "--help") {
90
+ whisper_print_usage(argc, argv, params);
91
+ exit(0);
92
+ } else {
93
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
94
+ whisper_print_usage(argc, argv, params);
95
+ exit(0);
96
  }
97
  }
98
 
99
  return true;
100
  }
101
 
102
+ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) {
103
+ fprintf(stderr, "\n");
104
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
105
+ fprintf(stderr, "\n");
106
+ fprintf(stderr, "options:\n");
107
+ fprintf(stderr, " -h, --help show this help message and exit\n");
108
+ fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
109
+ fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
110
+ fprintf(stderr, " -v, --verbose verbose output\n");
111
+ fprintf(stderr, " --translate translate from source language to english\n");
112
+ fprintf(stderr, " -ps, --print_special print special tokens\n");
113
+ fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
114
+ fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
115
+ fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
116
+ fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str());
117
+ fprintf(stderr, "\n");
118
  }
119
 
120
  //
 
186
  ///////////////////////////
187
 
188
  int main(int argc, char ** argv) {
189
+ const int64_t t_main_start_us = get_time_us();
190
 
191
  whisper_params params;
192
 
 
205
  return 1;
206
  }
207
 
208
+ // whisper init
209
 
210
+ struct whisper_context * ctx = whisper_init(params.model.c_str());
211
 
212
  const int n_samples_30s = 30*SAMPLE_RATE;
213
  std::vector<float> pcmf32(n_samples_30s, 0.0f);
 
216
  // print some info about the processing
217
  {
218
  printf("\n");
219
+ if (!whisper_is_multilingual(ctx)) {
220
  if (params.language != "en" || params.translate) {
221
  params.language = "en";
222
  params.translate = false;
 
225
  }
226
  printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
227
  __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
228
+ params.language.c_str(),
229
  params.translate ? "translate" : "transcribe",
230
  params.no_timestamps ? 0 : 1);
231
  printf("\n");
 
272
  pcmf32_old = pcmf32;
273
 
274
  // compute log mel spectrogram
275
+ if (whisper_pcm_to_mel(ctx, pcmf32.data(), pcmf32.size(), params.n_threads) != 0) {
276
+ fprintf(stderr, "%s: failed to compute log mel spectrogram\n", argv[0]);
277
+ return 6;
278
  }
279
 
280
  // the accumulated text context so far
281
+ std::vector<whisper_token> prompt_past = { };
282
 
283
  // these tokens determine the task that will be performed
284
+ std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
285
+ if (whisper_is_multilingual(ctx)) {
286
+ prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language.c_str()));
287
  if (params.translate) {
288
+ prompt_init.push_back(whisper_token_translate());
289
  } else {
290
+ prompt_init.push_back(whisper_token_transcribe());
291
  }
292
  }
293
 
 
297
  // main loop
298
  int seek = 0;
299
  while (true) {
300
+ if (seek >= whisper_n_len(ctx)) {
301
  break;
302
  }
303
 
304
  // encode audio features starting at offset seek
305
+ if (whisper_encode(ctx, seek, params.n_threads) != 0) {
306
+ fprintf(stderr, "%s: failed to encode\n", __func__);
307
+ return 7;
308
  }
309
 
310
+ std::vector<whisper_token> prompt;
311
 
312
  int n_past = 0;
313
 
314
  // if we have already generated some text, use it as a prompt to condition the next generation
315
  if (prompt_past.size() > 0) {
316
+ int n_take = std::min(whisper_n_text_ctx(ctx)/2, int(prompt_past.size()));
317
 
318
+ prompt = { whisper_token_prev(ctx) };
319
  prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
320
 
321
  prompt_past.clear();
 
326
 
327
  bool done = false;
328
  int seek_delta = 100*CHUNK_SIZE;
329
+ whisper_token last_id = 0;
330
 
331
  // print the prompt
332
  //printf("\n\n");
 
339
  int result_len = 0;
340
  std::vector<whisper_result> result_cur;
341
 
342
+ for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
343
+ if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
344
+ fprintf(stderr, "%s: failed to decode\n", __func__);
345
+ return 8;
346
  }
347
 
348
  n_past += prompt.size();
 
356
  // feel free to experiment!
357
  //
358
  {
359
+ const int n_vocab = whisper_n_vocab(ctx);
360
 
361
+ whisper_token id = 0;
362
+ whisper_token tid = whisper_token_beg(ctx);
363
 
364
+ id = whisper_sample_best(ctx, result_len == 0);
365
+ if (i > 0) {
366
+ tid = whisper_sample_timestamp(ctx);
367
  }
368
 
369
  // update sliding window
370
+ if (id > whisper_token_beg(ctx)) {
371
+ seek_delta = 2*(id - whisper_token_beg(ctx));
372
  result_len = i + 1;
373
  }
374
  last_id = id;
375
 
376
  // add it to the context
377
  prompt.push_back(id);
378
+ result_cur.push_back({ id, seek + 2*(tid - whisper_token_beg(ctx)) });
379
 
380
  //printf("%s: %s\n", __func__, vocab.id_to_token[id].c_str());
381
 
382
  // end of text token
383
+ if (id == whisper_token_eot(ctx)) {
384
  break;
385
  }
386
  }
 
403
 
404
  std::string text = "";
405
  for (int i = 0; i < result_cur.size(); i++) {
406
+ if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) {
407
  } else {
408
+ text += whisper_token_to_str(ctx, result_cur[i].id);
409
  }
410
+ if (result_cur[i].id > whisper_token_beg(ctx)) {
411
  const auto t1 = result_cur[i].t;
412
  if (!text.empty()) {
413
  if (params.no_timestamps) {
 
418
  }
419
  }
420
  text = "";
421
+ while (result_cur[i].id > whisper_token_beg(ctx) && i < result_cur.size()) {
422
  i++;
423
  }
424
  i--;
 
435
  }
436
  }
437
 
438
+ whisper_print_timings(ctx);
439
+ whisper_free(ctx);
440
 
441
  return 0;
442
  }
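
Taken together, the calls above exercise the whole of the new C-style interface. Below is a minimal sketch of the same flow, assembled only from the functions visible in this diff; it is a simplified illustration (greedy sampling, a single encode window, no sliding-window or timestamp handling), not the code from the commit:

    #include "whisper.h"

    #include <cstdio>
    #include <vector>

    int transcribe(const char * fname_model, const std::vector<float> & pcmf32, int n_threads) {
        struct whisper_context * ctx = whisper_init(fname_model);
        if (ctx == nullptr) { // assumption: whisper_init returns nullptr on failure
            return 1;
        }

        // convert the PCM samples to a log mel spectrogram, stored inside the context
        if (whisper_pcm_to_mel(ctx, pcmf32.data(), pcmf32.size(), n_threads) != 0) {
            whisper_free(ctx);
            return 1;
        }

        // encode the audio features at offset 0
        if (whisper_encode(ctx, 0, n_threads) != 0) {
            whisper_free(ctx);
            return 1;
        }

        // greedy decoding, starting from the "start of transcript" token
        std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };

        int n_past = 0;
        while (true) {
            if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, n_threads) != 0) {
                whisper_free(ctx);
                return 1;
            }
            n_past += prompt.size();

            const whisper_token id = whisper_sample_best(ctx, false);
            if (id == whisper_token_eot(ctx)) {
                break;
            }

            printf("%s", whisper_token_to_str(ctx, id));
            prompt = { id }; // feed back only the new token; n_past carries the history
        }

        whisper_print_timings(ctx);
        whisper_free(ctx);

        return 0;
    }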
whisper.cpp ADDED
@@ -0,0 +1,2221 @@
1
+ #include "whisper.h"
2
+
3
+ #include "ggml.h"
4
+
5
+ #include <algorithm>
6
+ #include <cassert>
7
+ #include <cmath>
8
+ #include <cstdio>
9
+ #include <cstring>
10
+ #include <fstream>
11
+ #include <map>
12
+ #include <string>
13
+ #include <thread>
14
+ #include <vector>
15
+
16
+ #define USE_FLASH_ATTN
17
+ #define USE_FLASH_FF
18
+
19
+ // available whisper models
20
+ enum e_model {
21
+ MODEL_UNKNOWN,
22
+ MODEL_TINY,
23
+ MODEL_BASE,
24
+ MODEL_SMALL,
25
+ MODEL_MEDIUM,
26
+ MODEL_LARGE,
27
+ };
28
+
29
+ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
30
+ { "en", { 0, "english", } },
31
+ { "zh", { 1, "chinese", } },
32
+ { "de", { 2, "german", } },
33
+ { "es", { 3, "spanish", } },
34
+ { "ru", { 4, "russian", } },
35
+ { "ko", { 5, "korean", } },
36
+ { "fr", { 6, "french", } },
37
+ { "ja", { 7, "japanese", } },
38
+ { "pt", { 8, "portuguese", } },
39
+ { "tr", { 9, "turkish", } },
40
+ { "pl", { 10, "polish", } },
41
+ { "ca", { 11, "catalan", } },
42
+ { "nl", { 12, "dutch", } },
43
+ { "ar", { 13, "arabic", } },
44
+ { "sv", { 14, "swedish", } },
45
+ { "it", { 15, "italian", } },
46
+ { "id", { 16, "indonesian", } },
47
+ { "hi", { 17, "hindi", } },
48
+ { "fi", { 18, "finnish", } },
49
+ { "vi", { 19, "vietnamese", } },
50
+ { "iw", { 20, "hebrew", } },
51
+ { "uk", { 21, "ukrainian", } },
52
+ { "el", { 22, "greek", } },
53
+ { "ms", { 23, "malay", } },
54
+ { "cs", { 24, "czech", } },
55
+ { "ro", { 25, "romanian", } },
56
+ { "da", { 26, "danish", } },
57
+ { "hu", { 27, "hungarian", } },
58
+ { "ta", { 28, "tamil", } },
59
+ { "no", { 29, "norwegian", } },
60
+ { "th", { 30, "thai", } },
61
+ { "ur", { 31, "urdu", } },
62
+ { "hr", { 32, "croatian", } },
63
+ { "bg", { 33, "bulgarian", } },
64
+ { "lt", { 34, "lithuanian", } },
65
+ { "la", { 35, "latin", } },
66
+ { "mi", { 36, "maori", } },
67
+ { "ml", { 37, "malayalam", } },
68
+ { "cy", { 38, "welsh", } },
69
+ { "sk", { 39, "slovak", } },
70
+ { "te", { 40, "telugu", } },
71
+ { "fa", { 41, "persian", } },
72
+ { "lv", { 42, "latvian", } },
73
+ { "bn", { 43, "bengali", } },
74
+ { "sr", { 44, "serbian", } },
75
+ { "az", { 45, "azerbaijani", } },
76
+ { "sl", { 46, "slovenian", } },
77
+ { "kn", { 47, "kannada", } },
78
+ { "et", { 48, "estonian", } },
79
+ { "mk", { 49, "macedonian", } },
80
+ { "br", { 50, "breton", } },
81
+ { "eu", { 51, "basque", } },
82
+ { "is", { 52, "icelandic", } },
83
+ { "hy", { 53, "armenian", } },
84
+ { "ne", { 54, "nepali", } },
85
+ { "mn", { 55, "mongolian", } },
86
+ { "bs", { 56, "bosnian", } },
87
+ { "kk", { 57, "kazakh", } },
88
+ { "sq", { 58, "albanian", } },
89
+ { "sw", { 59, "swahili", } },
90
+ { "gl", { 60, "galician", } },
91
+ { "mr", { 61, "marathi", } },
92
+ { "pa", { 62, "punjabi", } },
93
+ { "si", { 63, "sinhala", } },
94
+ { "km", { 64, "khmer", } },
95
+ { "sn", { 65, "shona", } },
96
+ { "yo", { 66, "yoruba", } },
97
+ { "so", { 67, "somali", } },
98
+ { "af", { 68, "afrikaans", } },
99
+ { "oc", { 69, "occitan", } },
100
+ { "ka", { 70, "georgian", } },
101
+ { "be", { 71, "belarusian", } },
102
+ { "tg", { 72, "tajik", } },
103
+ { "sd", { 73, "sindhi", } },
104
+ { "gu", { 74, "gujarati", } },
105
+ { "am", { 75, "amharic", } },
106
+ { "yi", { 76, "yiddish", } },
107
+ { "lo", { 77, "lao", } },
108
+ { "uz", { 78, "uzbek", } },
109
+ { "fo", { 79, "faroese", } },
110
+ { "ht", { 80, "haitian creole", } },
111
+ { "ps", { 81, "pashto", } },
112
+ { "tk", { 82, "turkmen", } },
113
+ { "nn", { 83, "nynorsk", } },
114
+ { "mt", { 84, "maltese", } },
115
+ { "sa", { 85, "sanskrit", } },
116
+ { "lb", { 86, "luxembourgish", } },
117
+ { "my", { 87, "myanmar", } },
118
+ { "bo", { 88, "tibetan", } },
119
+ { "tl", { 89, "tagalog", } },
120
+ { "mg", { 90, "malagasy", } },
121
+ { "as", { 91, "assamese", } },
122
+ { "tt", { 92, "tatar", } },
123
+ { "haw", { 93, "hawaiian", } },
124
+ { "ln", { 94, "lingala", } },
125
+ { "ha", { 95, "hausa", } },
126
+ { "ba", { 96, "bashkir", } },
127
+ { "jw", { 97, "javanese", } },
128
+ { "su", { 98, "sundanese", } },
129
+ };
130
+
131
+ static const size_t MB = 1024*1024;
132
+
133
+ static const std::map<e_model, size_t> MEM_REQ_MODEL = {
134
+ { MODEL_TINY, 86ull*MB },
135
+ { MODEL_BASE, 165ull*MB },
136
+ { MODEL_SMALL, 540ull*MB },
137
+ { MODEL_MEDIUM, 1650ull*MB },
138
+ { MODEL_LARGE, 3260ull*MB },
139
+ };
140
+
141
+ static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
142
+ { MODEL_TINY, 80ull*MB },
143
+ { MODEL_BASE, 128ull*MB },
144
+ { MODEL_SMALL, 300ull*MB },
145
+ { MODEL_MEDIUM, 680ull*MB },
146
+ { MODEL_LARGE, 1100ull*MB },
147
+ };
148
+
149
+ static const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
150
+ { MODEL_TINY, 64ull*MB },
151
+ { MODEL_BASE, 84ull*MB },
152
+ { MODEL_SMALL, 128ull*MB },
153
+ { MODEL_MEDIUM, 172ull*MB },
154
+ { MODEL_LARGE, 216ull*MB },
155
+ };
156
+
157
+ static const std::map<e_model, size_t> MEM_REQ_DECODE = {
158
+ { MODEL_TINY, 94ull*MB },
159
+ { MODEL_BASE, 96ull*MB },
160
+ { MODEL_SMALL, 98ull*MB },
161
+ { MODEL_MEDIUM, 100ull*MB },
162
+ { MODEL_LARGE, 102ull*MB },
163
+ };
164
+
165
+ static const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
166
+ { MODEL_TINY, 32ull*MB },
167
+ { MODEL_BASE, 44ull*MB },
168
+ { MODEL_SMALL, 64ull*MB },
169
+ { MODEL_MEDIUM, 84ull*MB },
170
+ { MODEL_LARGE, 110ull*MB },
171
+ };
172
+
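These five tables feed the three work buffers that whisper_model_load resizes further down (buf_model, buf_compute, buf_compute_layer). A hypothetical helper, not part of the commit, that mirrors that sizing logic to give the total host memory for a model type:

    // hypothetical convenience, mirroring the buffer sizing in whisper_model_load
    static size_t whisper_mem_required(e_model type) {
        return MEM_REQ_MODEL.at(type)
             + std::max(MEM_REQ_ENCODE.at(type),       MEM_REQ_DECODE.at(type))
             + std::max(MEM_REQ_ENCODE_LAYER.at(type), MEM_REQ_DECODE_LAYER.at(type));
    }

For MODEL_BASE this evaluates to 165 + max(128, 96) + max(84, 44) = 377 MB, the same sum the mem_required printout in the loader reports.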
173
+ struct whisper_mel {
174
+ int n_len;
175
+ int n_mel;
176
+
177
+ std::vector<float> data;
178
+ };
179
+
180
+ struct whisper_filters {
181
+ int32_t n_mel;
182
+ int32_t n_fft;
183
+
184
+ std::vector<float> data;
185
+ };
186
+
187
+ struct whisper_vocab {
188
+ using id = int32_t;
189
+ using token = std::string;
190
+
191
+ int n_vocab = 51864;
192
+
193
+ std::map<token, id> token_to_id;
194
+ std::map<id, token> id_to_token;
195
+
196
+ id token_eot = 50256;
197
+ id token_sot = 50257;
198
+ id token_prev = 50360;
199
+ id token_solm = 50361; // ??
200
+ id token_not = 50362; // no timestamps
201
+ id token_beg = 50363;
202
+
203
+ // available tasks
204
+ static const id token_translate = 50358;
205
+ static const id token_transcribe = 50359;
206
+
207
+ bool is_multilingual() const {
208
+ return n_vocab == 51865;
209
+ }
210
+ };
211
+
212
+ struct whisper_result {
213
+ whisper_vocab::id id;
214
+ int64_t t;
215
+ };
216
+
217
+ // medium
218
+ // hparams: {
219
+ // 'n_mels': 80,
220
+ // 'n_vocab': 51864,
221
+ // 'n_audio_ctx': 1500,
222
+ // 'n_audio_state': 1024,
223
+ // 'n_audio_head': 16,
224
+ // 'n_audio_layer': 24,
225
+ // 'n_text_ctx': 448,
226
+ // 'n_text_state': 1024,
227
+ // 'n_text_head': 16,
228
+ // 'n_text_layer': 24
229
+ // }
230
+ //
231
+ // default hparams (Whisper tiny)
232
+ struct whisper_hparams {
233
+ int32_t n_vocab = 51864;
234
+ int32_t n_audio_ctx = 1500;
235
+ int32_t n_audio_state = 384;
236
+ int32_t n_audio_head = 6;
237
+ int32_t n_audio_layer = 4;
238
+ int32_t n_text_ctx = 448;
239
+ int32_t n_text_state = 384;
240
+ int32_t n_text_head = 6;
241
+ int32_t n_text_layer = 4;
242
+ int32_t n_mels = 80;
243
+ int32_t f16 = 1;
244
+ };
245
+
246
+ // audio encoding layer
247
+ struct whisper_layer_encoder {
248
+ // encoder.blocks.*.attn_ln
249
+ struct ggml_tensor * attn_ln_0_w;
250
+ struct ggml_tensor * attn_ln_0_b;
251
+
252
+ // encoder.blocks.*.attn.out
253
+ struct ggml_tensor * attn_ln_1_w;
254
+ struct ggml_tensor * attn_ln_1_b;
255
+
256
+ // encoder.blocks.*.attn.query
257
+ struct ggml_tensor * attn_q_w;
258
+ struct ggml_tensor * attn_q_b;
259
+
260
+ // encoder.blocks.*.attn.key
261
+ struct ggml_tensor * attn_k_w;
262
+
263
+ // encoder.blocks.*.attn.value
264
+ struct ggml_tensor * attn_v_w;
265
+ struct ggml_tensor * attn_v_b;
266
+
267
+ // encoder.blocks.*.mlp_ln
268
+ struct ggml_tensor * mlp_ln_w;
269
+ struct ggml_tensor * mlp_ln_b;
270
+
271
+ // encoder.blocks.*.mlp.0
272
+ struct ggml_tensor * mlp_0_w;
273
+ struct ggml_tensor * mlp_0_b;
274
+
275
+ // encoder.blocks.*.mlp.2
276
+ struct ggml_tensor * mlp_1_w;
277
+ struct ggml_tensor * mlp_1_b;
278
+ };
279
+
280
+ // token decoding layer
281
+ struct whisper_layer_decoder {
282
+ // decoder.blocks.*.attn_ln
283
+ struct ggml_tensor * attn_ln_0_w;
284
+ struct ggml_tensor * attn_ln_0_b;
285
+
286
+ // decoder.blocks.*.attn.out
287
+ struct ggml_tensor * attn_ln_1_w;
288
+ struct ggml_tensor * attn_ln_1_b;
289
+
290
+ // decoder.blocks.*.attn.query
291
+ struct ggml_tensor * attn_q_w;
292
+ struct ggml_tensor * attn_q_b;
293
+
294
+ // decoder.blocks.*.attn.key
295
+ struct ggml_tensor * attn_k_w;
296
+
297
+ // decoder.blocks.*.attn.value
298
+ struct ggml_tensor * attn_v_w;
299
+ struct ggml_tensor * attn_v_b;
300
+
301
+ // decoder.blocks.*.cross_attn_ln
302
+ struct ggml_tensor * cross_attn_ln_0_w;
303
+ struct ggml_tensor * cross_attn_ln_0_b;
304
+
305
+ // decoder.blocks.*.cross_attn.out
306
+ struct ggml_tensor * cross_attn_ln_1_w;
307
+ struct ggml_tensor * cross_attn_ln_1_b;
308
+
309
+ // decoder.blocks.*.cross_attn.query
310
+ struct ggml_tensor * cross_attn_q_w;
311
+ struct ggml_tensor * cross_attn_q_b;
312
+
313
+ // decoder.blocks.*.cross_attn.key
314
+ struct ggml_tensor * cross_attn_k_w;
315
+
316
+ // decoder.blocks.*.cross_attn.value
317
+ struct ggml_tensor * cross_attn_v_w;
318
+ struct ggml_tensor * cross_attn_v_b;
319
+
320
+ // decoder.blocks.*.mlp_ln
321
+ struct ggml_tensor * mlp_ln_w;
322
+ struct ggml_tensor * mlp_ln_b;
323
+
324
+ // decoder.blocks.*.mlp.0
325
+ struct ggml_tensor * mlp_0_w;
326
+ struct ggml_tensor * mlp_0_b;
327
+
328
+ // decoder.blocks.*.mlp.2
329
+ struct ggml_tensor * mlp_1_w;
330
+ struct ggml_tensor * mlp_1_b;
331
+ };
332
+
333
+ struct whisper_model {
334
+ e_model type = MODEL_UNKNOWN;
335
+
336
+ whisper_hparams hparams;
337
+ whisper_filters filters;
338
+
339
+ // encoder.positional_embedding
340
+ struct ggml_tensor * e_pe;
341
+
342
+ // encoder.conv1
343
+ struct ggml_tensor * e_conv_1_w;
344
+ struct ggml_tensor * e_conv_1_b;
345
+
346
+ // encoder.conv2
347
+ struct ggml_tensor * e_conv_2_w;
348
+ struct ggml_tensor * e_conv_2_b;
349
+
350
+ // encoder.ln_post
351
+ struct ggml_tensor * e_ln_w;
352
+ struct ggml_tensor * e_ln_b;
353
+
354
+ // decoder.positional_embedding
355
+ struct ggml_tensor * d_pe; // DD
356
+
357
+ // decoder.token_embedding
358
+ struct ggml_tensor * d_te; // DD
359
+
360
+ // decoder.ln
361
+ struct ggml_tensor * d_ln_w; // DD
362
+ struct ggml_tensor * d_ln_b; // DD
363
+
364
+ std::vector<whisper_layer_encoder> layers_encoder;
365
+ std::vector<whisper_layer_decoder> layers_decoder;
366
+
367
+ // key + value memory
368
+ struct ggml_tensor * memory_k;
369
+ struct ggml_tensor * memory_v;
370
+
371
+ struct ggml_tensor * memory_cross_k;
372
+ struct ggml_tensor * memory_cross_v;
373
+
374
+ //
375
+ struct ggml_context * ctx;
376
+ std::map<std::string, struct ggml_tensor *> tensors;
377
+ };
378
+
379
+ struct whisper_context {
380
+ int64_t t_load_us = 0;
381
+ int64_t t_mel_us = 0;
382
+ int64_t t_sample_us = 0;
383
+ int64_t t_encode_us = 0;
384
+ int64_t t_decode_us = 0;
385
+ int64_t t_start_us = 0;
386
+
387
+ std::vector<uint8_t> buf_model;
388
+ std::vector<uint8_t> buf_compute;
389
+ std::vector<uint8_t> buf_compute_layer;
390
+
391
+ whisper_model model;
392
+ whisper_vocab vocab;
393
+
394
+ whisper_mel mel;
395
+
396
+ std::vector<float> probs;
397
+ std::vector<float> logits;
398
+ };
399
+
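whisper_context is the state behind the opaque pointer that whisper_init returns to the callers above: the mel spectrogram, the probs/logits scratch vectors that used to be locals in the old main loop, and the timing counters reported by whisper_print_timings now all live here. A sketch of the split this implies (an assumption about whisper.h, whose contents are not shown in this section):

    // whisper.h   (public):  struct whisper_context;           // opaque forward declaration
    // whisper.cpp (here):    struct whisper_context { ... };   // full definition, private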
400
+ // load the model from a ggml file
401
+ //
402
+ // file format:
403
+ //
404
+ // - hparams
405
+ // - pre-computed mel filters
406
+ // - vocab
407
+ // - weights
408
+ //
409
+ // see the convert-pt-to-ggml.py script for details
410
+ //
411
+ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
412
+ printf("%s: loading model from '%s'\n", __func__, fname.c_str());
413
+
414
+ auto & model = wctx.model;
415
+ auto & vocab = wctx.vocab;
416
+
417
+ auto fin = std::ifstream(fname, std::ios::binary);
418
+ if (!fin) {
419
+ fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
420
+ return false;
421
+ }
422
+
423
+ // verify magic
424
+ {
425
+ uint32_t magic;
426
+ fin.read((char *) &magic, sizeof(magic));
427
+ if (magic != 0x67676d6c) {
428
+ fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
429
+ return false;
430
+ }
431
+ }
432
+
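As an aside, the magic value checked above is the container name in ASCII:

    // 0x67 0x67 0x6d 0x6c  ==  'g' 'g' 'm' 'l'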
433
+ // load hparams
434
+ {
435
+ auto & hparams = model.hparams;
436
+
437
+ fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
438
+ fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
439
+ fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
440
+ fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
441
+ fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
442
+ fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
443
+ fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
444
+ fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
445
+ fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
446
+ fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
447
+ fin.read((char *) &hparams.f16, sizeof(hparams.f16));
448
+
449
+ assert(hparams.n_text_state == hparams.n_audio_state);
450
+
451
+ if (hparams.n_audio_layer == 4) {
452
+ model.type = e_model::MODEL_TINY;
453
+ }
454
+
455
+ if (hparams.n_audio_layer == 6) {
456
+ model.type = e_model::MODEL_BASE;
457
+ }
458
+
459
+ if (hparams.n_audio_layer == 12) {
460
+ model.type = e_model::MODEL_SMALL;
461
+ }
462
+
463
+ if (hparams.n_audio_layer == 24) {
464
+ model.type = e_model::MODEL_MEDIUM;
465
+ }
466
+
467
+ if (hparams.n_audio_layer == 32) {
468
+ model.type = e_model::MODEL_LARGE;
469
+ }
470
+
471
+ printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
472
+ printf("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
473
+ printf("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
474
+ printf("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
475
+ printf("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
476
+ printf("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
477
+ printf("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
478
+ printf("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
479
+ printf("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
480
+ printf("%s: n_mels = %d\n", __func__, hparams.n_mels);
481
+ printf("%s: f16 = %d\n", __func__, hparams.f16);
482
+ printf("%s: type = %d\n", __func__, model.type);
483
+
484
+ wctx.buf_model.resize(MEM_REQ_MODEL.at(model.type));
485
+ wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
486
+ wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
487
+
488
+ // this is the total memory required to run the inference
489
+ const size_t mem_required =
490
+ wctx.buf_model.size() +
491
+ wctx.buf_compute.size() +
492
+ wctx.buf_compute_layer.size();
493
+
494
+ printf("%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
495
+ }
496
+
497
+ // load mel filters
498
+ {
499
+ auto & filters = wctx.model.filters;
500
+
501
+ fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
502
+ fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
503
+
504
+ filters.data.resize(filters.n_mel * filters.n_fft);
505
+ fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
506
+ }
507
+
508
+ // load vocab
509
+ {
510
+ int32_t n_vocab = 0;
511
+ fin.read((char *) &n_vocab, sizeof(n_vocab));
512
+
513
+ //if (n_vocab != model.hparams.n_vocab) {
514
+ // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
515
+ // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
516
+ // return false;
517
+ //}
518
+
519
+ std::string word;
520
+ for (int i = 0; i < n_vocab; i++) {
521
+ uint32_t len;
522
+ fin.read((char *) &len, sizeof(len));
523
+
524
+ word.resize(len);
525
+ fin.read((char *) word.data(), len);
526
+
527
+ vocab.token_to_id[word] = i;
528
+ vocab.id_to_token[i] = word;
529
+
530
+ //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
531
+ }
532
+
533
+ vocab.n_vocab = model.hparams.n_vocab;
534
+ if (vocab.is_multilingual()) {
535
+ vocab.token_eot++;
536
+ vocab.token_sot++;
537
+ vocab.token_prev++;
538
+ vocab.token_solm++;
539
+ vocab.token_not++;
540
+ vocab.token_beg++;
541
+ }
542
+
543
+ if (n_vocab < model.hparams.n_vocab) {
544
+ printf("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
545
+ for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
546
+ if (i > vocab.token_beg) {
547
+ word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
548
+ } else if (i == vocab.token_eot) {
549
+ word = "[_EOT_]";
550
+ } else if (i == vocab.token_sot) {
551
+ word = "[_SOT_]";
552
+ } else if (i == vocab.token_prev) {
553
+ word = "[_PREV_]";
554
+ } else if (i == vocab.token_not) {
555
+ word = "[_NOT_]";
556
+ } else if (i == vocab.token_beg) {
557
+ word = "[_BEG_]";
558
+ } else {
559
+ word = "[_extra_token_" + std::to_string(i) + "]";
560
+ }
561
+ vocab.token_to_id[word] = i;
562
+ vocab.id_to_token[i] = word;
563
+ }
564
+ }
565
+ }
566
+
567
+ // for the big tensors, we have the option to store the data in 16-bit floats
568
+ // in order to save memory and also to speed up the computation
569
+ const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
570
+
571
+
572
+ size_t ctx_size = 0;
573
+
574
+ {
575
+ const auto & hparams = model.hparams;
576
+
577
+ const int n_vocab = hparams.n_vocab;
578
+
579
+ const int n_audio_ctx = hparams.n_audio_ctx;
580
+ const int n_audio_state = hparams.n_audio_state;
581
+ const int n_audio_layer = hparams.n_audio_layer;
582
+
583
+ const int n_text_ctx = hparams.n_text_ctx;
584
+ const int n_text_state = hparams.n_text_state;
585
+ const int n_text_layer = hparams.n_text_layer;
586
+
587
+ const int n_mels = hparams.n_mels;
588
+
589
+ // encoder
590
+ {
591
+ // TODO: F16 .. maybe not?
592
+ ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
593
+
594
+ ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
595
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
596
+
597
+ ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
598
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
599
+
600
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
601
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
602
+ }
603
+
604
+ // decoder
605
+ {
606
+ // TODO: F16 .. maybe not?
607
+ ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
608
+
609
+ ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
610
+
611
+ ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
612
+ ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
613
+ }
614
+
615
+ // encoder layers
616
+ {
617
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
618
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
619
+
620
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
621
+ ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
622
+
623
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
624
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
625
+
626
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
627
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
628
+
629
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
630
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
631
+
632
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
633
+
634
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
635
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
636
+
637
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
638
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
639
+ }
640
+
641
+ // decoder layers
642
+ {
643
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
644
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
645
+
646
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
647
+ ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
648
+
649
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
650
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
651
+
652
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
653
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
654
+
655
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
656
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
657
+
658
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
659
+
660
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
661
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
662
+
663
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
664
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
665
+ //
666
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
667
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
668
+
669
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
670
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
671
+
672
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
673
+
674
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
675
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
676
+
677
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
678
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
679
+ }
680
+
681
+ ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
682
+ ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
683
+
684
+ ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
685
+ ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
686
+
687
+ ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
688
+
689
+ printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
690
+ }
691
+
692
+ // create the ggml context
693
+ {
694
+ struct ggml_init_params params = {
695
+ .mem_size = wctx.buf_model.size(),
696
+ .mem_buffer = wctx.buf_model.data(),
697
+ };
698
+
699
+ model.ctx = ggml_init(params);
700
+ if (!model.ctx) {
701
+ fprintf(stderr, "%s: ggml_init() failed\n", __func__);
702
+ return false;
703
+ }
704
+ }
705
+
706
+ // prepare memory for the weights
707
+ {
708
+ auto & ctx = model.ctx;
709
+
710
+ const auto & hparams = model.hparams;
711
+
712
+ const int n_vocab = hparams.n_vocab;
713
+
714
+ const int n_audio_ctx = hparams.n_audio_ctx;
715
+ const int n_audio_state = hparams.n_audio_state;
716
+ const int n_audio_layer = hparams.n_audio_layer;
717
+
718
+ const int n_text_ctx = hparams.n_text_ctx;
719
+ const int n_text_state = hparams.n_text_state;
720
+ const int n_text_layer = hparams.n_text_layer;
721
+
722
+ const int n_mels = hparams.n_mels;
723
+
724
+ model.layers_encoder.resize(n_audio_layer);
725
+ model.layers_decoder.resize(n_text_layer);
726
+
727
+ // encoder
728
+ {
729
+ model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
730
+
731
+ model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
732
+ model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
733
+
734
+ model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
735
+ model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
736
+
737
+ model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
738
+ model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
739
+
740
+ // map by name
741
+ model.tensors["encoder.positional_embedding"] = model.e_pe;
742
+
743
+ model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
744
+ model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
745
+
746
+ model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
747
+ model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
748
+
749
+ model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
750
+ model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
751
+
752
+ for (int i = 0; i < n_audio_layer; ++i) {
753
+ auto & layer = model.layers_encoder[i];
754
+
755
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
756
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
757
+
758
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
759
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
760
+
761
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
762
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
763
+
764
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
765
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
766
+
767
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
768
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
769
+
770
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
771
+
772
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
773
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
774
+
775
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
776
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
777
+
778
+ // map by name
779
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
780
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
781
+
782
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
783
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
784
+
785
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
786
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
787
+
788
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
789
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
790
+
791
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
792
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
793
+
794
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
795
+
796
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
797
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
798
+
799
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
800
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
801
+ }
802
+ }
803
+
804
+ // decoder
805
+ {
806
+ model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
807
+
808
+ model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
809
+
810
+ model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
811
+ model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
812
+
813
+ // map by name
814
+ model.tensors["decoder.positional_embedding"] = model.d_pe;
815
+
816
+ model.tensors["decoder.token_embedding.weight"] = model.d_te;
817
+
818
+ model.tensors["decoder.ln.weight"] = model.d_ln_w;
819
+ model.tensors["decoder.ln.bias"] = model.d_ln_b;
820
+
821
+ for (int i = 0; i < n_text_layer; ++i) {
822
+ auto & layer = model.layers_decoder[i];
823
+
824
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
825
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
826
+
827
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
828
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
829
+
830
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
831
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
832
+
833
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
834
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
835
+
836
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
837
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
838
+
839
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
840
+
841
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
842
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
843
+
844
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
845
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
846
+
847
+ layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
848
+ layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
849
+
850
+ layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
851
+ layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
852
+
853
+ layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
854
+
855
+ layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
856
+ layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
857
+
858
+ layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
859
+ layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
860
+
861
+ // map by name
862
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
863
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
864
+
865
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
866
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
867
+
868
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
869
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
870
+
871
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
872
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
873
+
874
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
875
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
876
+
877
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
878
+
879
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
880
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
881
+
882
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
883
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
884
+
885
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
886
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
887
+
888
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
889
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
890
+
891
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
892
+
893
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
894
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
895
+
896
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
897
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
898
+ }
899
+ }
900
+ }
901
+
902
+ // key + value memory
903
+ {
904
+ auto & ctx = model.ctx;
905
+
906
+ const auto & hparams = model.hparams;
907
+
908
+ const int n_text_state = hparams.n_text_state;
909
+ const int n_text_layer = hparams.n_text_layer;
910
+ const int n_text_ctx = hparams.n_text_ctx;
911
+
912
+ // key/value memory for the self-attention layer
913
+ {
914
+ const int n_mem = n_text_layer*n_text_ctx;
915
+ const int n_elements = n_text_state*n_mem;
916
+
917
+ model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
918
+ model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
919
+ }
920
+
921
+ // key/value memory for the cross-attention layer
922
+ {
923
+ const int n_audio_ctx = hparams.n_audio_ctx;
924
+
925
+ const int n_mem = n_text_layer*n_audio_ctx;
926
+ const int n_elements = n_text_state*n_mem;
927
+
928
+ model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
929
+ model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
930
+ }
931
+
932
+ const size_t memory_size =
933
+ ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
934
+ ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
935
+
936
+ printf("%s: memory size = %8.2f MB \n", __func__, memory_size/1024.0/1024.0);
937
+ }
938
+
939
+ // load weights
940
+ {
941
+ size_t total_size = 0;
942
+
943
+ while (true) {
944
+ int32_t n_dims;
945
+ int32_t length;
946
+ int32_t ftype;
947
+
948
+ fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
949
+ fin.read(reinterpret_cast<char *>(&length), sizeof(length));
950
+ fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
951
+
952
+ if (fin.eof()) {
953
+ break;
954
+ }
955
+
956
+ int32_t nelements = 1;
957
+ int32_t ne[3] = { 1, 1, 1 };
958
+ for (int i = 0; i < n_dims; ++i) {
959
+ fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
960
+ nelements *= ne[i];
961
+ }
962
+
963
+ std::string name(length, 0);
964
+ fin.read(&name[0], length);
965
+
966
+ if (model.tensors.find(name.data()) == model.tensors.end()) {
967
+ fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
968
+ return false;
969
+ }
970
+
971
+ auto tensor = model.tensors[name.data()];
972
+ if (ggml_nelements(tensor) != nelements) {
973
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
974
+ return false;
975
+ }
976
+
977
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
978
+ fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
979
+ __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
980
+ return false;
981
+ }
982
+
983
+ const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
984
+
985
+ if (nelements*bpe != ggml_nbytes(tensor)) {
986
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
987
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
988
+ return false;
989
+ }
990
+
991
+ fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
992
+
993
+ //printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
994
+ total_size += ggml_nbytes(tensor);
995
+ }
996
+
997
+ printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
998
+ }
999
+
1000
+ fin.close();
1001
+
1002
+ return true;
1003
+ }
1004
+
1005
+ // evaluate the encoder
1006
+ //
1007
+ // given an audio recording (more specifically, its log mel spectrogram), runs the forward pass of the encoder
1008
+ // part of the transformer model and returns the encoded features
1009
+ //
1010
+ // - wctx: the whisper context (model + state)
1011
+ // - n_threads: number of threads to use
1012
+ // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1013
+ // - (the input mel spectrogram is taken from wctx.mel)
1014
+ // - (the encoded features are used to pre-compute the cross-attention memory)
1015
+ //
1016
+ bool whisper_encode(
1017
+ whisper_context & wctx,
1018
+ const int n_threads,
1019
+ const int mel_offset) {
1020
+ const auto & model = wctx.model;
1021
+ const auto & mel_inp = wctx.mel;
1022
+ const auto & hparams = model.hparams;
1023
+
1024
+ const int n_vocab = hparams.n_vocab;
1025
+
1026
+ const int n_ctx = hparams.n_audio_ctx;
1027
+ const int n_state = hparams.n_audio_state;
1028
+ const int n_head = hparams.n_audio_head;
1029
+ const int n_layer = hparams.n_audio_layer;
1030
+
1031
+ const int N = n_ctx;
1032
+
1033
+ const int n_mels = hparams.n_mels;
1034
+ assert(mel_inp.n_mel == n_mels);
1035
+
1036
+ struct ggml_init_params params = {
1037
+ .mem_size = wctx.buf_compute.size(),
1038
+ .mem_buffer = wctx.buf_compute.data(),
1039
+ };
1040
+
1041
+ struct ggml_context * ctx0 = ggml_init(params);
1042
+
1043
+ struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
1044
+ assert(mel->type == GGML_TYPE_F32);
1045
+ {
1046
+ float * dst = (float *) mel->data;
1047
+ memset(dst, 0, ggml_nbytes(mel));
1048
+
1049
+ const int i0 = std::min(mel_offset, mel_inp.n_len);
1050
+ const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
1051
+
1052
+ for (int j = 0; j < mel_inp.n_mel; ++j) {
1053
+ for (int i = i0; i < i1; ++i) {
1054
+ dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
1055
+ }
1056
+ }
1057
+ }
1058
+
1059
+ struct ggml_tensor * cur;
1060
+
1061
+ // convolution + gelu
1062
+ {
1063
+ cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
1064
+ cur = ggml_add(ctx0,
1065
+ ggml_repeat(ctx0,
1066
+ model.e_conv_1_b,
1067
+ cur),
1068
+ cur);
1069
+
1070
+ cur = ggml_gelu(ctx0, cur);
1071
+
1072
+ cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
1073
+ cur = ggml_add(ctx0,
1074
+ ggml_repeat(ctx0,
1075
+ model.e_conv_2_b,
1076
+ cur),
1077
+ cur);
1078
+
1079
+ cur = ggml_gelu(ctx0, cur);
1080
+ }
1081
+
1082
+ cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
1083
+
1084
+ struct ggml_tensor * inpL = cur;
1085
+
1086
+ for (int il = 0; il < n_layer; ++il) {
1087
+ const auto & layer = model.layers_encoder[il];
1088
+
1089
+ // create separate context for each layer to reduce memory usage
1090
+
1091
+ struct ggml_init_params paramsL = {
1092
+ .mem_size = wctx.buf_compute_layer.size(),
1093
+ .mem_buffer = wctx.buf_compute_layer.data(),
1094
+ };
1095
+
1096
+ struct ggml_context * ctxL = ggml_init(paramsL);
1097
+
1098
+ // norm
1099
+ {
1100
+ cur = ggml_norm(ctxL, inpL);
1101
+
1102
+ // cur = ln_0_w*cur + ln_0_b
1103
+ cur = ggml_add(ctxL,
1104
+ ggml_mul(ctxL,
1105
+ ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
1106
+ cur),
1107
+ ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
1108
+ }
1109
+
1110
+ // self-attention
1111
+ {
1112
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
1113
+ layer.attn_q_w,
1114
+ cur);
1115
+
1116
+ Qcur = ggml_add(ctxL,
1117
+ ggml_repeat(ctxL,
1118
+ layer.attn_q_b,
1119
+ Qcur),
1120
+ Qcur);
1121
+
1122
+ //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1123
+
1124
+ // note: no bias for Key
1125
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
1126
+ layer.attn_k_w,
1127
+ cur);
1128
+
1129
+ //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1130
+
1131
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
1132
+ layer.attn_v_w,
1133
+ cur);
1134
+
1135
+ Vcur = ggml_add(ctxL,
1136
+ ggml_repeat(ctxL,
1137
+ layer.attn_v_b,
1138
+ Vcur),
1139
+ Vcur);
1140
+
1141
+ // ------
1142
+
1143
+ #ifdef USE_FLASH_ATTN
1144
+ struct ggml_tensor * Q =
1145
+ ggml_permute(ctxL,
1146
+ ggml_cpy(ctxL,
1147
+ Qcur,
1148
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1149
+ 0, 2, 1, 3);
1150
+
1151
+ struct ggml_tensor * K =
1152
+ ggml_permute(ctxL,
1153
+ ggml_cpy(ctxL,
1154
+ Kcur,
1155
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1156
+ 0, 2, 1, 3);
1157
+
1158
+ struct ggml_tensor * V =
1159
+ ggml_cpy(ctxL,
1160
+ ggml_permute(ctxL,
1161
+ ggml_reshape_3d(ctxL,
1162
+ Vcur,
1163
+ n_state/n_head, n_head, N),
1164
+ 1, 2, 0, 3),
1165
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
1166
+ );
1167
+
1168
+ struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
1169
+ #else
1170
+ struct ggml_tensor * Q =
1171
+ ggml_permute(ctxL,
1172
+ ggml_cpy(ctxL,
1173
+ Qcur,
1174
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1175
+ 0, 2, 1, 3);
1176
+
1177
+ struct ggml_tensor * K =
1178
+ ggml_permute(ctxL,
1179
+ ggml_cpy(ctxL,
1180
+ Kcur,
1181
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1182
+ 0, 2, 1, 3);
1183
+
1184
+ // K * Q
1185
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1186
+
1187
+ struct ggml_tensor * KQ_scaled =
1188
+ ggml_scale(ctxL,
1189
+ KQ,
1190
+ ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1191
+ );
1192
+
1193
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);
1194
+
1195
+ //struct ggml_tensor * V_trans =
1196
+ // ggml_permute(ctxL,
1197
+ // ggml_cpy(ctxL,
1198
+ // Vcur,
1199
+ // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1200
+ // 1, 2, 0, 3);
1201
+
1202
+ //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1203
+
1204
+ struct ggml_tensor * V =
1205
+ ggml_cpy(ctxL,
1206
+ ggml_permute(ctxL,
1207
+ ggml_reshape_3d(ctxL,
1208
+ Vcur,
1209
+ n_state/n_head, n_head, N),
1210
+ 0, 2, 1, 3),
1211
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
1212
+ );
1213
+
1214
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
1215
+ #endif
1216
+
1217
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1218
+
1219
+ cur = ggml_cpy(ctxL,
1220
+ KQV_merged,
1221
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1222
+ }
1223
+
1224
+ // projection
1225
+ {
1226
+ cur = ggml_mul_mat(ctxL,
1227
+ layer.attn_ln_1_w,
1228
+ cur);
1229
+
1230
+ cur = ggml_add(ctxL,
1231
+ ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
1232
+ cur);
1233
+ }
1234
+
1235
+ // add the input
1236
+ cur = ggml_add(ctxL, cur, inpL);
1237
+
1238
+ struct ggml_tensor * inpFF = cur;
1239
+
1240
+ // feed-forward network
1241
+ {
1242
+ // norm
1243
+ {
1244
+ cur = ggml_norm(ctxL, inpFF);
1245
+
1246
+ // cur = mlp_ln_w*cur + mlp_ln_b
1247
+ cur = ggml_add(ctxL,
1248
+ ggml_mul(ctxL,
1249
+ ggml_repeat(ctxL, layer.mlp_ln_w, cur),
1250
+ cur),
1251
+ ggml_repeat(ctxL, layer.mlp_ln_b, cur));
1252
+ }
1253
+
1254
+ #ifdef USE_FLASH_FF
1255
+ cur = ggml_flash_ff(ctxL,
1256
+ ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
1257
+ layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1258
+ #else
1259
+ // fully connected
1260
+ cur = ggml_mul_mat(ctxL,
1261
+ layer.mlp_0_w,
1262
+ cur);
1263
+
1264
+ cur = ggml_add(ctxL,
1265
+ ggml_repeat(ctxL, layer.mlp_0_b, cur),
1266
+ cur);
1267
+
1268
+ // GELU activation
1269
+ cur = ggml_gelu(ctxL, cur);
1270
+
1271
+ // projection
1272
+ cur = ggml_mul_mat(ctxL,
1273
+ layer.mlp_1_w,
1274
+ cur);
1275
+
1276
+ cur = ggml_add(ctxL,
1277
+ ggml_repeat(ctxL, layer.mlp_1_b, cur),
1278
+ cur);
1279
+ #endif
1280
+ }
1281
+
1282
+ // output from this layer
1283
+ struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
1284
+
1285
+ {
1286
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1287
+
1288
+ ggml_build_forward_expand(&gf, inpO);
1289
+ ggml_graph_compute (ctxL, &gf);
1290
+
1291
+ //ggml_graph_print(&gf);
1292
+ }
1293
+
1294
+ // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
1295
+ // input for next layer (inpO -> inpL)
1296
+ memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
1297
+ inpL->op = GGML_OP_NONE;
1298
+ inpL->src0 = NULL;
1299
+ inpL->src1 = NULL;
1300
+
1301
+ //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
1302
+
1303
+ ggml_free(ctxL);
1304
+ }
1305
+
1306
+ cur = inpL;
1307
+
1308
+ // norm
1309
+ {
1310
+ cur = ggml_norm(ctx0, cur);
1311
+
1312
+ // cur = ln_f_g*cur + ln_f_b
1313
+ cur = ggml_add(ctx0,
1314
+ ggml_mul(ctx0,
1315
+ ggml_repeat(ctx0, model.e_ln_w, cur),
1316
+ cur),
1317
+ ggml_repeat(ctx0, model.e_ln_b, cur));
1318
+ }
1319
+
1320
+ // run the computation
1321
+ {
1322
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1323
+
1324
+ ggml_build_forward_expand(&gf, cur);
1325
+ ggml_graph_compute (ctx0, &gf);
1326
+
1327
+ //ggml_graph_print(&gf);
1328
+ }
1329
+
1330
+ // cur
1331
+ //{
1332
+ // printf("ne0 = %d\n", cur->ne[0]);
1333
+ // printf("ne1 = %d\n", cur->ne[1]);
1334
+ // for (int i = 0; i < 10; ++i) {
1335
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1336
+ // }
1337
+ // printf("... ");
1338
+ // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1339
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1340
+ // }
1341
+ // printf("\n");
1342
+ //}
1343
+
1344
+ // pre-compute cross-attention memory
1345
+ {
1346
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1347
+
1348
+ // TODO: hack to disconnect the encoded features from the previous graph
1349
+ cur->op = GGML_OP_NONE;
1350
+ cur->src0 = NULL;
1351
+ cur->src1 = NULL;
1352
+
1353
+ for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1354
+ auto & layer = model.layers_decoder[il];
1355
+
1356
+ struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
1357
+ layer.cross_attn_k_w,
1358
+ cur);
1359
+
1360
+ Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1361
+
1362
+ struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
1363
+ layer.cross_attn_v_w,
1364
+ cur);
1365
+
1366
+ Vcross = ggml_add(ctx0,
1367
+ ggml_repeat(ctx0,
1368
+ layer.cross_attn_v_b,
1369
+ Vcross),
1370
+ Vcross);
1371
+
1372
+ struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
1373
+ struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
1374
+
1375
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
1376
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
1377
+ }
1378
+
1379
+ ggml_graph_compute(ctx0, &gf);
1380
+ }
1381
+
1382
+ ////////////////////////////////////////////////////////////////////////////
1383
+
1384
+ //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0);
1385
+
1386
+ ggml_free(ctx0);
1387
+
1388
+ return true;
1389
+ }
1390
+
1391
+ // evaluate the decoder
1392
+ //
1393
+ // given a text prompt and the encoded audio features, predicts the probabilities for the next token
1394
+ //
1395
+ // - wctx: the whisper context (model + state)
1396
+ // - n_threads: number of threads to use
1397
+ // - n_past: number of tokens already processed (offset into the KV memory)
1398
+ // - tokens, n_tokens: the new tokens to process
1399
+ // - logits_out: output logits
1400
+ // - probs_out: output probabilities
1401
+ //
1402
+ bool whisper_decode(
1403
+ whisper_context & wctx,
1404
+ const int n_threads,
1405
+ const whisper_token * tokens,
1406
+ const int n_tokens,
1407
+ const int n_past) {
1408
+ const auto & model = wctx.model;
1409
+ const auto & hparams = model.hparams;
1410
+
1411
+ auto & logits_out = wctx.logits;
1412
+ auto & probs_out = wctx.probs;
1413
+
1414
+ const int n_vocab = hparams.n_vocab;
1415
+
1416
+ const int n_ctx = hparams.n_text_ctx;
1417
+ const int n_state = hparams.n_text_state;
1418
+ const int n_head = hparams.n_text_head;
1419
+ const int n_layer = hparams.n_text_layer;
1420
+
1421
+ const int N = n_tokens;
1422
+ const int M = hparams.n_audio_ctx;
1423
+
1424
+ struct ggml_init_params params = {
1425
+ .mem_size = wctx.buf_compute.size(),
1426
+ .mem_buffer = wctx.buf_compute.data(),
1427
+ };
1428
+
1429
+ struct ggml_context * ctx0 = ggml_init(params);
1430
+
1431
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1432
+ memcpy(embd->data, tokens, N*ggml_element_size(embd));
1433
+
1434
+ struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1435
+ for (int i = 0; i < N; ++i) {
1436
+ ((int32_t *) position->data)[i] = n_past + i;
1437
+ }
1438
+
1439
+ // token encoding + position encoding
1440
+ struct ggml_tensor * cur =
1441
+ ggml_add(ctx0,
1442
+ ggml_get_rows(ctx0, model.d_te, embd),
1443
+ ggml_get_rows(ctx0, model.d_pe, position));
1444
+
1445
+ struct ggml_tensor * inpL = cur;
1446
+
1447
+ for (int il = 0; il < n_layer; ++il) {
1448
+ const auto & layer = model.layers_decoder[il];
1449
+
1450
+ struct ggml_init_params paramsL = {
1451
+ .mem_size = wctx.buf_compute_layer.size(),
1452
+ .mem_buffer = wctx.buf_compute_layer.data(),
1453
+ };
1454
+
1455
+ struct ggml_context * ctxL = ggml_init(paramsL);
1456
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1457
+
1458
+ // norm
1459
+ {
1460
+ cur = ggml_norm(ctxL, inpL);
1461
+
1462
+ // cur = ln_0_w*cur + ln_0_b
1463
+ cur = ggml_add(ctxL,
1464
+ ggml_mul(ctxL,
1465
+ ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
1466
+ cur),
1467
+ ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
1468
+ }
1469
+
1470
+ // self-attention
1471
+ {
1472
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
1473
+ layer.attn_q_w,
1474
+ cur);
1475
+
1476
+ Qcur = ggml_add(ctxL,
1477
+ ggml_repeat(ctxL,
1478
+ layer.attn_q_b,
1479
+ Qcur),
1480
+ Qcur);
1481
+
1482
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1483
+
1484
+ // note: no bias for Key
1485
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
1486
+ layer.attn_k_w,
1487
+ cur);
1488
+
1489
+ Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1490
+
1491
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
1492
+ layer.attn_v_w,
1493
+ cur);
1494
+
1495
+ Vcur = ggml_add(ctxL,
1496
+ ggml_repeat(ctxL,
1497
+ layer.attn_v_b,
1498
+ Vcur),
1499
+ Vcur);
1500
+
1501
+ // store key and value to memory
1502
+ {
1503
+ struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
1504
+ struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
1505
+
1506
+ ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
1507
+ ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
1508
+ }
1509
+
1510
+ // ------
1511
+
1512
+ struct ggml_tensor * Q =
1513
+ ggml_permute(ctxL,
1514
+ ggml_cpy(ctxL,
1515
+ Qcur,
1516
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1517
+ 0, 2, 1, 3);
1518
+
1519
+ struct ggml_tensor * K =
1520
+ ggml_permute(ctxL,
1521
+ ggml_reshape_3d(ctxL,
1522
+ ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
1523
+ n_state/n_head, n_head, n_past + N),
1524
+ 0, 2, 1, 3);
1525
+
1526
+ // K * Q
1527
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1528
+
1529
+ //struct ggml_tensor * KQ_scaled =
1530
+ // ggml_scale(ctxL,
1531
+ // KQ,
1532
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1533
+ // );
1534
+
1535
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
1536
+
1537
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
1538
+
1539
+ struct ggml_tensor * V_trans =
1540
+ ggml_permute(ctxL,
1541
+ ggml_reshape_3d(ctxL,
1542
+ ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
1543
+ n_state/n_head, n_head, n_past + N),
1544
+ 1, 2, 0, 3);
1545
+
1546
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1547
+
1548
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1549
+
1550
+ cur = ggml_cpy(ctxL,
1551
+ KQV_merged,
1552
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1553
+ }
1554
+
1555
+ {
1556
+ cur = ggml_mul_mat(ctxL,
1557
+ layer.attn_ln_1_w,
1558
+ cur);
1559
+
1560
+ cur = ggml_add(ctxL,
1561
+ ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
1562
+ cur);
1563
+ }
1564
+
1565
+ // add the input
1566
+ struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL);
1567
+
1568
+ // norm
1569
+ {
1570
+ cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here
1571
+
1572
+ // cur = ln_0_w*cur + ln_0_b
1573
+ cur = ggml_add(ctxL,
1574
+ ggml_mul(ctxL,
1575
+ ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur),
1576
+ cur),
1577
+ ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur));
1578
+ }
1579
+
1580
+ // cross-attention
1581
+ {
1582
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
1583
+ layer.cross_attn_q_w,
1584
+ cur);
1585
+
1586
+ Qcur = ggml_add(ctxL,
1587
+ ggml_repeat(ctxL,
1588
+ layer.cross_attn_q_b,
1589
+ Qcur),
1590
+ Qcur);
1591
+
1592
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1593
+
1594
+ // Kcross is already scaled
1595
+ struct ggml_tensor * Kcross =
1596
+ ggml_reshape_3d(ctxL,
1597
+ ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
1598
+ n_state/n_head, n_head, M);
1599
+
1600
+ struct ggml_tensor * Vcross =
1601
+ ggml_reshape_3d(ctxL,
1602
+ ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
1603
+ n_state/n_head, n_head, M);
1604
+
1605
+ // ------
1606
+
1607
+ struct ggml_tensor * Q =
1608
+ ggml_permute(ctxL,
1609
+ ggml_cpy(ctxL,
1610
+ Qcur,
1611
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1612
+ 0, 2, 1, 3);
1613
+
1614
+ struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
1615
+
1616
+ // K * Q
1617
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
1618
+
1619
+ //struct ggml_tensor * KQ_scaled =
1620
+ // ggml_scale(ctxL,
1621
+ // KQ,
1622
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1623
+ // );
1624
+
1625
+ // no masking for cross-attention
1626
+ //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
1627
+
1628
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
1629
+
1630
+ struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
1631
+
1632
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
1633
+
1634
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
1635
+
1636
+ // cur = KQV_merged.contiguous().view(n_state, N)
1637
+ cur = ggml_cpy(ctxL,
1638
+ KQV_merged,
1639
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1640
+ }
1641
+
1642
+ // projection
1643
+ {
1644
+ cur = ggml_mul_mat(ctxL,
1645
+ layer.cross_attn_ln_1_w,
1646
+ cur);
1647
+
1648
+ cur = ggml_add(ctxL,
1649
+ ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur),
1650
+ cur);
1651
+ }
1652
+
1653
+ // add the input
1654
+ cur = ggml_add(ctxL, cur, inpCA);
1655
+
1656
+ struct ggml_tensor * inpFF = cur;
1657
+
1658
+ // feed-forward network
1659
+ {
1660
+ // norm
1661
+ {
1662
+ cur = ggml_norm(ctxL, inpFF);
1663
+
1664
+ // cur = mlp_ln_w*cur + mlp_ln_b
1665
+ cur = ggml_add(ctxL,
1666
+ ggml_mul(ctxL,
1667
+ ggml_repeat(ctxL, layer.mlp_ln_w, cur),
1668
+ cur),
1669
+ ggml_repeat(ctxL, layer.mlp_ln_b, cur));
1670
+ }
1671
+
1672
+ // fully connected
1673
+ cur = ggml_mul_mat(ctxL,
1674
+ layer.mlp_0_w,
1675
+ cur);
1676
+
1677
+ cur = ggml_add(ctxL,
1678
+ ggml_repeat(ctxL, layer.mlp_0_b, cur),
1679
+ cur);
1680
+
1681
+ // GELU activation
1682
+ cur = ggml_gelu(ctxL, cur);
1683
+
1684
+ // projection
1685
+ cur = ggml_mul_mat(ctxL,
1686
+ layer.mlp_1_w,
1687
+ cur);
1688
+
1689
+ cur = ggml_add(ctxL,
1690
+ ggml_repeat(ctxL, layer.mlp_1_b, cur),
1691
+ cur);
1692
+ }
1693
+
1694
+ // output from this layer
1695
+ struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
1696
+
1697
+ {
1698
+ ggml_build_forward_expand(&gf, inpO);
1699
+ ggml_graph_compute (ctxL, &gf);
1700
+
1701
+ //ggml_graph_print(&gf);
1702
+ }
1703
+
1704
+ // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
1705
+ // input for next layer (inpO -> inpL)
1706
+ memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
1707
+ inpL->op = GGML_OP_NONE;
1708
+ inpL->src0 = NULL;
1709
+ inpL->src1 = NULL;
1710
+
1711
+ if (N > 1) {
1712
+ //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
1713
+ }
1714
+
1715
+ ggml_free(ctxL);
1716
+ }
1717
+
1718
+ cur = inpL;
1719
+
1720
+ // norm
1721
+ {
1722
+ cur = ggml_norm(ctx0, cur);
1723
+
1724
+ cur = ggml_add(ctx0,
1725
+ ggml_mul(ctx0,
1726
+ ggml_repeat(ctx0, model.d_ln_w, cur),
1727
+ cur),
1728
+ ggml_repeat(ctx0, model.d_ln_b, cur));
1729
+ }
1730
+
1731
+ struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
1732
+
1733
+ // logits -> probs
1734
+ cur = ggml_dup(ctx0, logits);
1735
+ cur = ggml_soft_max(ctx0, cur); // in-place
1736
+
1737
+ // run the computation
1738
+ {
1739
+ struct ggml_cgraph gf = { .n_threads = n_threads };
1740
+
1741
+ ggml_build_forward_expand(&gf, cur);
1742
+ ggml_graph_compute (ctx0, &gf);
1743
+ }
1744
+
1745
+ logits_out.resize(N*n_vocab);
1746
+ memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
1747
+
1748
+ probs_out.resize(N*n_vocab);
1749
+ memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
1750
+
1751
+ if (N > 1) {
1752
+ //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
1753
+ //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
1754
+ //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
1755
+ }
1756
+
1757
+ ggml_free(ctx0);
1758
+
1759
+ return true;
1760
+ }
1761
+
1762
+ // the most basic sampling scheme - select the top token
1763
+ // TODO: beam search
1764
+ // TODO: temperature
1765
+ whisper_vocab::id whisper_sample_best(
1766
+ const whisper_vocab & vocab,
1767
+ const float * probs, bool need_timestamp) {
1768
+ int n_logits = vocab.id_to_token.size();
1769
+
1770
+ std::vector<std::pair<double, whisper_vocab::id>> probs_id;
1771
+ probs_id.reserve(n_logits);
1772
+
1773
+ for (int i = 0; i < n_logits; i++) {
1774
+ probs_id.push_back(std::make_pair(probs[i], i));
1775
+ }
1776
+
1777
+ const int top_k = 4;
1778
+
1779
+ // find the top K tokens
1780
+ std::partial_sort(
1781
+ probs_id.begin(),
1782
+ probs_id.begin() + top_k, probs_id.end(),
1783
+ [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
1784
+ return a.first > b.first;
1785
+ });
1786
+
1787
+ probs_id.resize(top_k);
1788
+
1789
+ //printf("\n");
1790
+ //for (int i = 0; i < (int) probs_id.size(); i++) {
1791
+ // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
1792
+ //}
1793
+
1794
+ if (need_timestamp) {
1795
+ // at the end of the 30-second audio segment, we start giving preference to time tokens
1796
+ for (int i = 0; i < top_k; i++) {
1797
+ if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
1798
+ return probs_id[i].second;
1799
+ }
1800
+ }
1801
+ }
1802
+
1803
+ int res = 0;
1804
+ while ((probs_id[res].second == vocab.token_sot ||
1805
+ probs_id[res].second == vocab.token_solm ||
1806
+ probs_id[res].second == vocab.token_not) &&
1807
+ res < (int) probs_id.size() - 1) {
1808
+ res++;
1809
+ }
1810
+
1811
+ return probs_id[res].second;
1812
+ }
1813
+
1814
+ // samples only from the timestamp tokens
1815
+ whisper_vocab::id whisper_sample_timestamp(
1816
+ const whisper_vocab & vocab,
1817
+ const float * probs) {
1818
+ int n_logits = vocab.id_to_token.size();
1819
+
1820
+ std::vector<std::pair<double, whisper_vocab::id>> probs_id;
1821
+ probs_id.reserve(n_logits);
1822
+
1823
+ for (int i = vocab.token_beg + 1; i < n_logits; i++) {
1824
+ probs_id.push_back(std::make_pair(probs[i], i));
1825
+ }
1826
+
1827
+ const int top_k = 10;
1828
+
1829
+ // find the top K tokens
1830
+ std::partial_sort(
1831
+ probs_id.begin(),
1832
+ probs_id.begin() + top_k, probs_id.end(),
1833
+ [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
1834
+ return a.first > b.first;
1835
+ });
1836
+
1837
+ probs_id.resize(top_k);
1838
+
1839
+ //printf("\n");
1840
+ //for (int i = 0; i < (int) probs_id.size(); i++) {
1841
+ // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
1842
+ //}
1843
+
1844
+ return probs_id[0].second;
1845
+ }
1846
+
1847
+ // naive Discrete Fourier Transform
1848
+ // input is real-valued
1849
+ // output is complex-valued
1850
+ void dft(const std::vector<float> & in, std::vector<float> & out) {
1851
+ int N = in.size();
1852
+
1853
+ out.resize(N*2);
1854
+
1855
+ for (int k = 0; k < N; k++) {
1856
+ float re = 0;
1857
+ float im = 0;
1858
+
1859
+ for (int n = 0; n < N; n++) {
1860
+ float angle = 2*M_PI*k*n/N;
1861
+ re += in[n]*cos(angle);
1862
+ im -= in[n]*sin(angle);
1863
+ }
1864
+
1865
+ out[k*2 + 0] = re;
1866
+ out[k*2 + 1] = im;
1867
+ }
1868
+ }
1869
+
1870
+ // Cooley-Tukey FFT
1871
+ // poor man's implementation - use something better
1872
+ // input is real-valued
1873
+ // output is complex-valued
1874
+ void fft(const std::vector<float> & in, std::vector<float> & out) {
1875
+ out.resize(in.size()*2);
1876
+
1877
+ int N = in.size();
1878
+
1879
+ if (N == 1) {
1880
+ out[0] = in[0];
1881
+ out[1] = 0;
1882
+ return;
1883
+ }
1884
+
1885
+ if (N%2 == 1) {
1886
+ dft(in, out);
1887
+ return;
1888
+ }
1889
+
1890
+ std::vector<float> even;
1891
+ std::vector<float> odd;
1892
+
1893
+ for (int i = 0; i < N; i++) {
1894
+ if (i % 2 == 0) {
1895
+ even.push_back(in[i]);
1896
+ } else {
1897
+ odd.push_back(in[i]);
1898
+ }
1899
+ }
1900
+
1901
+ std::vector<float> even_fft;
1902
+ std::vector<float> odd_fft;
1903
+
1904
+ fft(even, even_fft);
1905
+ fft(odd, odd_fft);
1906
+
1907
+ for (int k = 0; k < N/2; k++) {
1908
+ float theta = 2*M_PI*k/N;
1909
+
1910
+ float re = cos(theta);
1911
+ float im = -sin(theta);
1912
+
1913
+ float re_odd = odd_fft[2*k + 0];
1914
+ float im_odd = odd_fft[2*k + 1];
1915
+
1916
+ out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
1917
+ out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
1918
+
1919
+ out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
1920
+ out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
1921
+ }
1922
+ }
1923
+
1924
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
1925
+ bool log_mel_spectrogram(
1926
+ const float * samples,
1927
+ const int n_samples,
1928
+ const int sample_rate,
1929
+ const int fft_size,
1930
+ const int fft_step,
1931
+ const int n_mel,
1932
+ const int n_threads,
1933
+ const whisper_filters & filters,
1934
+ whisper_mel & mel) {
1935
+
1936
+ // Hanning window
1937
+ std::vector<float> hann;
1938
+ hann.resize(fft_size);
1939
+ for (int i = 0; i < fft_size; i++) {
1940
+ hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
1941
+ }
1942
+
1943
+ mel.n_mel = n_mel;
1944
+ mel.n_len = (n_samples)/fft_step;
1945
+ mel.data.resize(mel.n_mel*mel.n_len);
1946
+
1947
+ const int n_fft = 1 + fft_size/2;
1948
+
1949
+ printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
1950
+ printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
1951
+
1952
+ std::vector<std::thread> workers(n_threads);
1953
+ for (int iw = 0; iw < n_threads; ++iw) {
1954
+ workers[iw] = std::thread([&](int ith) {
1955
+ std::vector<float> fft_in;
1956
+ fft_in.resize(fft_size);
1957
+ for (int i = 0; i < fft_size; i++) {
1958
+ fft_in[i] = 0.0;
1959
+ }
1960
+
1961
+ std::vector<float> fft_out;
1962
+ fft_out.resize(2*fft_size);
1963
+
1964
+ for (int i = ith; i < mel.n_len; i += n_threads) {
1965
+ const int offset = i*fft_step;
1966
+
1967
+ // apply Hanning window
1968
+ for (int j = 0; j < fft_size; j++) {
1969
+ if (offset + j < n_samples) {
1970
+ fft_in[j] = hann[j]*samples[offset + j];
1971
+ } else {
1972
+ fft_in[j] = 0.0;
1973
+ }
1974
+ }
1975
+
1976
+ // FFT -> mag^2
1977
+ fft(fft_in, fft_out);
1978
+
1979
+ for (int j = 0; j < fft_size; j++) {
1980
+ fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
1981
+ }
1982
+ for (int j = 1; j < fft_size/2; j++) {
1983
+ //if (i == 0) {
1984
+ // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
1985
+ //}
1986
+ fft_out[j] += fft_out[fft_size - j];
1987
+ }
1988
+ if (i == 0) {
1989
+ //for (int j = 0; j < fft_size; j++) {
1990
+ // printf("%d: %e\n", j, fft_out[j]);
1991
+ //}
1992
+ }
1993
+
1994
+ // mel spectrogram
1995
+ for (int j = 0; j < mel.n_mel; j++) {
1996
+ double sum = 0.0;
1997
+
1998
+ for (int k = 0; k < n_fft; k++) {
1999
+ sum += fft_out[k]*filters.data[j*n_fft + k];
2000
+ }
2001
+ if (sum < 1e-10) {
2002
+ sum = 1e-10;
2003
+ }
2004
+
2005
+ sum = log10(sum);
2006
+
2007
+ mel.data[j*mel.n_len + i] = sum;
2008
+ }
2009
+ }
2010
+ }, iw);
2011
+ }
2012
+
2013
+ for (int iw = 0; iw < n_threads; ++iw) {
2014
+ workers[iw].join();
2015
+ }
2016
+
2017
+ // clamping and normalization
2018
+ double mmax = -1e20;
2019
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
2020
+ if (mel.data[i] > mmax) {
2021
+ mmax = mel.data[i];
2022
+ }
2023
+ }
2024
+ //printf("%s: max = %f\n", __func__, mmax);
2025
+
2026
+ mmax -= 8.0;
2027
+
2028
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
2029
+ if (mel.data[i] < mmax) {
2030
+ mel.data[i] = mmax;
2031
+ }
2032
+
2033
+ mel.data[i] = (mel.data[i] + 4.0)/4.0;
2034
+ }
2035
+
2036
+ return true;
2037
+ }
2038
+
2039
+ //
2040
+ // interface implementation
2041
+ //
2042
+
2043
+ struct whisper_context * whisper_init(const char * path_model) {
2044
+ whisper_context * ctx = new whisper_context;
2045
+
2046
+ const int64_t t_start_us = ggml_time_us();
2047
+
2048
+ ctx->t_start_us = t_start_us;
2049
+
2050
+ if (!whisper_model_load(path_model, *ctx)) {
2051
+ fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
2052
+ return NULL;
2053
+ }
2054
+
2055
+ ctx->t_load_us = ggml_time_us() - t_start_us;
2056
+
2057
+ return ctx;
2058
+ }
2059
+
2060
+ void whisper_free(struct whisper_context * ctx) {
2061
+ if (ctx) {
2062
+ delete ctx;
2063
+ }
2064
+ }
2065
+
2066
+ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
2067
+ const int64_t t_start_us = ggml_time_us();
2068
+
2069
+ if (!log_mel_spectrogram(samples, n_samples, SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MEL, n_threads, ctx->model.filters, ctx->mel)) {
2070
+ fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2071
+ return -1;
2072
+ }
2073
+
2074
+ ctx->t_mel_us = ggml_time_us() - t_start_us;
2075
+
2076
+ return 0;
2077
+ }
2078
+
2079
+ int whisper_set_mel(
2080
+ struct whisper_context * ctx,
2081
+ const float * data,
2082
+ int n_len,
2083
+ int n_mel) {
2084
+ if (n_mel != N_MEL) {
2085
+ fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, N_MEL);
2086
+ return -1;
2087
+ }
2088
+
2089
+ ctx->mel.n_len = n_len;
2090
+ ctx->mel.n_mel = n_mel;
2091
+
2092
+ ctx->mel.data.resize(n_len*n_mel);
2093
+ memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
2094
+
2095
+ return 0;
2096
+ }
2097
+
2098
+ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
2099
+ const int64_t t_start_us = ggml_time_us();
2100
+
2101
+ if (!whisper_encode(*ctx, n_threads, offset)) {
2102
+ fprintf(stderr, "%s: failed to eval\n", __func__);
2103
+ return -1;
2104
+ }
2105
+
2106
+ ctx->t_encode_us += ggml_time_us() - t_start_us;
2107
+
2108
+ return 0;
2109
+ }
2110
+
2111
+ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
2112
+ const int64_t t_start_us = ggml_time_us();
2113
+
2114
+ if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) {
2115
+ fprintf(stderr, "%s: failed to eval\n", __func__);
2116
+ return -1;
2117
+ }
2118
+
2119
+ ctx->t_decode_us += ggml_time_us() - t_start_us;
2120
+
2121
+ return 0;
2122
+ }
2123
+
2124
+ whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp) {
2125
+ const int64_t t_start_sample_us = ggml_time_us();
2126
+
2127
+ // TODO: simplify
2128
+ auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), need_timestamp);
2129
+
2130
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
2131
+
2132
+ return res;
2133
+ }
2134
+
2135
+ whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
2136
+ const int64_t t_start_sample_us = ggml_time_us();
2137
+
2138
+ // TODO: simplify
2139
+ auto res = whisper_sample_timestamp(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
2140
+
2141
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
2142
+
2143
+ return res;
2144
+ }
2145
+
2146
+ int whisper_lang_id(const char * lang) {
2147
+ if (!g_lang.count(lang)) {
2148
+ fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
2149
+ return -1;
2150
+ }
2151
+
2152
+ return g_lang.at(lang).first;
2153
+ }
2154
+
2155
+ int whisper_n_len(struct whisper_context * ctx) {
2156
+ return ctx->mel.n_len;
2157
+ }
2158
+
2159
+ int whisper_n_vocab(struct whisper_context * ctx) {
2160
+ return ctx->vocab.n_vocab;
2161
+ }
2162
+
2163
+ int whisper_n_text_ctx(struct whisper_context * ctx) {
2164
+ return ctx->model.hparams.n_text_ctx;
2165
+ }
2166
+
2167
+ int whisper_is_multilingual(struct whisper_context * ctx) {
2168
+ return ctx->vocab.is_multilingual() ? 1 : 0;
2169
+ }
2170
+
2171
+ float * whisper_get_probs(struct whisper_context * ctx) {
2172
+ return ctx->probs.data();
2173
+ }
2174
+
2175
+ const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
2176
+ return ctx->vocab.id_to_token.at(token).c_str();
2177
+ }
2178
+
2179
+ whisper_token whisper_token_eot(struct whisper_context * ctx) {
2180
+ return ctx->vocab.token_eot;
2181
+ }
2182
+
2183
+ whisper_token whisper_token_sot(struct whisper_context * ctx) {
2184
+ return ctx->vocab.token_sot;
2185
+ }
2186
+
2187
+ whisper_token whisper_token_prev(struct whisper_context * ctx) {
2188
+ return ctx->vocab.token_prev;
2189
+ }
2190
+
2191
+ whisper_token whisper_token_solm(struct whisper_context * ctx) {
2192
+ return ctx->vocab.token_solm;
2193
+ }
2194
+
2195
+ whisper_token whisper_token_not(struct whisper_context * ctx) {
2196
+ return ctx->vocab.token_not;
2197
+ }
2198
+
2199
+ whisper_token whisper_token_beg(struct whisper_context * ctx) {
2200
+ return ctx->vocab.token_beg;
2201
+ }
2202
+
2203
+ whisper_token whisper_token_translate() {
2204
+ return whisper_vocab::token_translate;
2205
+ }
2206
+
2207
+ whisper_token whisper_token_transcribe() {
2208
+ return whisper_vocab::token_transcribe;
2209
+ }
2210
+
2211
+ void whisper_print_timings(struct whisper_context * ctx) {
2212
+ const int64_t t_end_us = ggml_time_us();
2213
+
2214
+ printf("\n\n");
2215
+ printf("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
2216
+ printf("%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
2217
+ printf("%s: sample time = %8.2f ms\n", __func__, ctx->t_sample_us/1000.0f);
2218
+ printf("%s: encode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_encode_us/1000.0f, ctx->t_encode_us/1000.0f/ctx->model.hparams.n_audio_layer);
2219
+ printf("%s: decode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_decode_us/1000.0f, ctx->t_decode_us/1000.0f/ctx->model.hparams.n_text_layer);
2220
+ printf("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
2221
+ }
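
For a rough sense of scale of the "key + value memory" block above: the self-attention memory stores one F16 key and one F16 value vector per text position per decoder layer, and the cross-attention memory stores the same per position of the encoded audio context. Below is a minimal sketch of that arithmetic, assuming base-model hyperparameters (n_text_state = 512, n_text_layer = 6, n_text_ctx = 448, n_audio_ctx = 1500); these values are assumptions for illustration - for other model sizes, substitute the hyperparameters printed at load time.

#include <cstdio>

int main() {
    // assumed base-model hyperparameters (not read from the model file)
    const size_t n_text_state = 512;
    const size_t n_text_layer = 6;
    const size_t n_text_ctx   = 448;
    const size_t n_audio_ctx  = 1500;

    const size_t fp16 = 2; // bytes per GGML_TYPE_F16 element

    // memory_k + memory_v: one key and one value slot per text position, per layer
    const size_t self_kv  = 2*fp16*n_text_layer*n_text_ctx*n_text_state;

    // memory_cross_k + memory_cross_v: same, but over the audio context
    const size_t cross_kv = 2*fp16*n_text_layer*n_audio_ctx*n_text_state;

    printf("memory size = %8.2f MB\n", (self_kv + cross_kv)/1024.0/1024.0);
    return 0;
}

For these assumed values the program prints roughly 23 MB, mirroring the "memory size" line that the loader above computes from ggml_nbytes().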
whisper.h ADDED
@@ -0,0 +1,133 @@
1
+ #ifndef WHISPER_H
2
+ #define WHISPER_H
3
+
4
+ #ifdef WHISPER_SHARED
5
+ # ifdef _WIN32
6
+ # ifdef WHISPER_BUILD
7
+ # define WHISPER_API __declspec(dllexport)
8
+ # else
9
+ # define WHISPER_API __declspec(dllimport)
10
+ # endif
11
+ # else
12
+ # define WHISPER_API __attribute__ ((visibility ("default")))
13
+ # endif
14
+ #else
15
+ # define WHISPER_API
16
+ #endif
17
+
18
+ #ifdef __cplusplus
19
+ extern "C" {
20
+ #endif
21
+
22
+ //
23
+ // C interface
24
+ //
25
+
26
+ #define SAMPLE_RATE 16000
27
+ #define N_FFT 400
28
+ #define N_MEL 80
29
+ #define HOP_LENGTH 160
30
+ #define CHUNK_SIZE 30
31
+
32
+ // TODO: documentation will come soon
33
+
34
+ struct whisper_context;
35
+
36
+ typedef int whisper_token;
37
+
38
+ WHISPER_API struct whisper_context * whisper_init(const char * path_model);
39
+ WHISPER_API void whisper_free(struct whisper_context * ctx);
40
+
41
+ WHISPER_API int whisper_pcm_to_mel(
42
+ struct whisper_context * ctx,
43
+ const float * samples,
44
+ int n_samples,
45
+ int n_threads);
46
+
47
+ // n_mel must be 80
48
+ WHISPER_API int whisper_set_mel(
49
+ struct whisper_context * ctx,
50
+ const float * data,
51
+ int n_len,
52
+ int n_mel);
53
+
54
+ WHISPER_API int whisper_encode(
55
+ struct whisper_context * ctx,
56
+ int offset,
57
+ int n_threads);
58
+
59
+ WHISPER_API int whisper_decode(
60
+ struct whisper_context * ctx,
61
+ const whisper_token * tokens,
62
+ int n_tokens,
63
+ int n_past,
64
+ int n_threads);
65
+
66
+ WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp);
67
+ WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
68
+
69
+ // returns the id of the specified language, or -1 if not found
70
+ WHISPER_API int whisper_lang_id(const char * lang);
71
+
72
+ WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
73
+ WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
74
+ WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
75
+ WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
76
+ WHISPER_API float * whisper_get_probs (struct whisper_context * ctx);
77
+
78
+ WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
79
+
80
+ WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
81
+ WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
82
+ WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
83
+ WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
84
+ WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
85
+ WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
86
+
87
+ WHISPER_API whisper_token whisper_token_translate ();
88
+ WHISPER_API whisper_token whisper_token_transcribe();
89
+
90
+ WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
91
+
92
+ ////////////////////////////////////////////////////////////////////////////
93
+
94
+ enum whisper_decode_strategy {
95
+ WHISPER_DECODE_GREEDY,
96
+ WHISPER_DECODE_BEAM_SEARCH,
97
+ };
98
+
99
+ struct whisper_full_params {
100
+ enum whisper_decode_strategy strategy;
101
+
102
+ int n_threads;
103
+
104
+ bool transcribe;
105
+
106
+ const char * language;
107
+
108
+ union {
109
+ struct {
110
+ int n_past;
111
+ } greedy;
112
+
113
+ struct {
114
+ int n_past;
115
+ int beam_width;
116
+ int n_best;
117
+ } beam_search;
118
+ };
119
+ };
120
+
121
+ // full whisper run - encode + decode
122
+ // TODO: implement
123
+ WHISPER_API int whisper_full(
124
+ struct whisper_context * ctx,
125
+ struct whisper_full_params * params,
126
+ const float * samples,
127
+ int n_samples);
128
+
129
+ #ifdef __cplusplus
130
+ }
131
+ #endif
132
+
133
+ #endif
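
Taken together, the header sketches the intended call sequence: load a model, turn 16 kHz mono PCM into a mel spectrogram (at HOP_LENGTH 160, a 30-second CHUNK_SIZE gives 3000 mel frames, which the stride-2 convolution in the encoder reduces to the audio context), encode, then decode token by token. The following is a minimal greedy-transcription sketch, not part of the commit: the model path, the thread count, and the 32-token cap are placeholders, and loading the PCM samples (e.g. with dr_wav) is omitted.

#include "whisper.h"

#include <cstdio>
#include <vector>

int main(int argc, char ** argv) {
    // placeholder model path - pass the real path as the first argument
    struct whisper_context * ctx = whisper_init(argc > 1 ? argv[1] : "models/ggml-base.en.bin");
    if (ctx == NULL) {
        return 1;
    }

    std::vector<float> pcm; // 16 kHz mono float samples - loading omitted in this sketch

    const int n_threads = 4; // placeholder

    if (whisper_pcm_to_mel(ctx, pcm.data(), (int) pcm.size(), n_threads) != 0) return 1;
    if (whisper_encode(ctx, 0, n_threads) != 0) return 1;

    // prime the decoder with the start-of-transcript token, then sample greedily;
    // after the first call, each new token is fed back with an updated n_past
    std::vector<whisper_token> tokens = { whisper_token_sot(ctx) };

    int n_past = 0;
    for (int i = 0; i < 32; ++i) { // small cap, just for the sketch
        if (whisper_decode(ctx, tokens.data(), (int) tokens.size(), n_past, n_threads) != 0) {
            return 1;
        }
        n_past += (int) tokens.size();

        const whisper_token id = whisper_sample_best(ctx, false);
        if (id == whisper_token_eot(ctx)) break;

        printf("%s", whisper_token_to_str(ctx, id));
        tokens = { id };
    }
    printf("\n");

    whisper_print_timings(ctx);
    whisper_free(ctx);

    return 0;
}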