ggerganov committed on
Commit
9ed1355
·
1 Parent(s): cafe46d

main : fix sampling time + add max_context parameter

Browse files
Files changed (3) hide show
  1. examples/main/main.cpp +6 -0
  2. whisper.cpp +8 -14
  3. whisper.h +11 -2
examples/main/main.cpp CHANGED
@@ -42,6 +42,7 @@ struct whisper_params {
42
  int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
43
  int32_t offset_t_ms = 0;
44
  int32_t offset_n = 0;
 
45
 
46
  bool verbose = false;
47
  bool translate = false;
@@ -77,6 +78,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
77
  params.offset_t_ms = std::stoi(argv[++i]);
78
  } else if (arg == "-on" || arg == "--offset-n") {
79
  params.offset_n = std::stoi(argv[++i]);
 
 
80
  } else if (arg == "-v" || arg == "--verbose") {
81
  params.verbose = true;
82
  } else if (arg == "--translate") {
@@ -127,6 +130,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
127
  fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
128
  fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
129
  fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
 
130
  fprintf(stderr, " -v, --verbose verbose output\n");
131
  fprintf(stderr, " --translate translate from source language to english\n");
132
  fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
@@ -380,6 +384,8 @@ int main(int argc, char ** argv) {
380
  wparams.translate = params.translate;
381
  wparams.language = params.language.c_str();
382
  wparams.n_threads = params.n_threads;
 
 
383
  wparams.offset_ms = params.offset_t_ms;
384
 
385
  // this callback is called on each new segment
 
42
  int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
43
  int32_t offset_t_ms = 0;
44
  int32_t offset_n = 0;
45
+ int32_t max_context = -1;
46
 
47
  bool verbose = false;
48
  bool translate = false;
 
78
  params.offset_t_ms = std::stoi(argv[++i]);
79
  } else if (arg == "-on" || arg == "--offset-n") {
80
  params.offset_n = std::stoi(argv[++i]);
81
+ } else if (arg == "-mc" || arg == "--max-context") {
82
+ params.max_context = std::stoi(argv[++i]);
83
  } else if (arg == "-v" || arg == "--verbose") {
84
  params.verbose = true;
85
  } else if (arg == "--translate") {
 
130
  fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
131
  fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
132
  fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
133
+ fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
134
  fprintf(stderr, " -v, --verbose verbose output\n");
135
  fprintf(stderr, " --translate translate from source language to english\n");
136
  fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
 
384
  wparams.translate = params.translate;
385
  wparams.language = params.language.c_str();
386
  wparams.n_threads = params.n_threads;
387
+ wparams.n_processors = 1;
388
+ wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
389
  wparams.offset_ms = params.offset_t_ms;
390
 
391
  // this callback is called on each new segment
whisper.cpp CHANGED
@@ -211,14 +211,6 @@ struct whisper_vocab {
211
  }
212
  };
213
 
214
- struct whisper_token_data {
215
- whisper_token id; // token id
216
- whisper_token tid; // forced timestamp token id
217
-
218
- float p; // probability of the token
219
- float pt; // probability of the timestamp token
220
- };
221
-
222
  struct whisper_segment {
223
  int64_t t0;
224
  int64_t t1;
@@ -2219,7 +2211,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
2219
  return 0;
2220
  }
2221
 
2222
- whisper_token whisper_sample_best(struct whisper_context * ctx) {
2223
  const int64_t t_start_sample_us = ggml_time_us();
2224
 
2225
  // TODO: simplify
@@ -2227,7 +2219,7 @@ whisper_token whisper_sample_best(struct whisper_context * ctx) {
2227
 
2228
  ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
2229
 
2230
- return res.id;
2231
  }
2232
 
2233
  whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
@@ -2330,8 +2322,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2330
  /*.strategy =*/ WHISPER_SAMPLING_GREEDY,
2331
 
2332
  /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
2333
- /*.offset_ms =*/ 0,
2334
  /*.n_processors =*/ 1,
 
 
2335
 
2336
  /*.translate =*/ false,
2337
  /*.no_context =*/ false,
@@ -2362,8 +2355,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2362
  /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
2363
 
2364
  /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
2365
- /*.offset_ms =*/ 0,
2366
  /*.n_processors =*/ 1,
 
 
2367
 
2368
  /*.translate =*/ false,
2369
  /*.no_context =*/ false,
@@ -2470,7 +2464,7 @@ int whisper_full(
2470
 
2471
  // if we have already generated some text, use it as a prompt to condition the next generation
2472
  if (prompt_past.size() > 0) {
2473
- int n_take = std::min(whisper_n_text_ctx(ctx)/2, int(prompt_past.size()));
2474
 
2475
  prompt = { whisper_token_prev(ctx) };
2476
  prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
@@ -2512,7 +2506,7 @@ int whisper_full(
2512
  // feel free to experiment!
2513
  //
2514
  {
2515
- auto token = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
2516
 
2517
  if (i == 0) {
2518
  token.tid = whisper_token_beg(ctx);
 
211
  }
212
  };
213
 
 
 
 
 
 
 
 
 
214
  struct whisper_segment {
215
  int64_t t0;
216
  int64_t t1;
 
2211
  return 0;
2212
  }
2213
 
2214
+ whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
2215
  const int64_t t_start_sample_us = ggml_time_us();
2216
 
2217
  // TODO: simplify
 
2219
 
2220
  ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
2221
 
2222
+ return res;
2223
  }
2224
 
2225
  whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
 
2322
  /*.strategy =*/ WHISPER_SAMPLING_GREEDY,
2323
 
2324
  /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
 
2325
  /*.n_processors =*/ 1,
2326
+ /*.n_max_text_ctx =*/ 16384,
2327
+ /*.offset_ms =*/ 0,
2328
 
2329
  /*.translate =*/ false,
2330
  /*.no_context =*/ false,
 
2355
  /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
2356
 
2357
  /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
 
2358
  /*.n_processors =*/ 1,
2359
+ /*.n_max_text_ctx =*/ 16384,
2360
+ /*.offset_ms =*/ 0,
2361
 
2362
  /*.translate =*/ false,
2363
  /*.no_context =*/ false,
 
2464
 
2465
  // if we have already generated some text, use it as a prompt to condition the next generation
2466
  if (prompt_past.size() > 0) {
2467
+ int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
2468
 
2469
  prompt = { whisper_token_prev(ctx) };
2470
  prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
 
2506
  // feel free to experiment!
2507
  //
2508
  {
2509
+ auto token = whisper_sample_best(ctx);
2510
 
2511
  if (i == 0) {
2512
  token.tid = whisper_token_beg(ctx);
whisper.h CHANGED
@@ -68,6 +68,14 @@ extern "C" {
68
 
69
  typedef int whisper_token;
70
 
 
 
 
 
 
 
 
 
71
  // Allocates all memory needed for the model and loads the model from the given file.
72
  // Returns NULL on failure.
73
  WHISPER_API struct whisper_context * whisper_init(const char * path_model);
@@ -122,7 +130,7 @@ extern "C" {
122
  // You can also implement your own sampling method using the whisper_get_probs() function.
123
  // whisper_sample_best() returns the token with the highest probability
124
  // whisper_sample_timestamp() returns the most probable timestamp token
125
- WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
126
  WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
127
 
128
  // Return the id of the specified language, returns -1 if not found
@@ -171,8 +179,9 @@ extern "C" {
171
  enum whisper_sampling_strategy strategy;
172
 
173
  int n_threads;
174
- int offset_ms;
175
  int n_processors;
 
 
176
 
177
  bool translate;
178
  bool no_context;
 
68
 
69
  typedef int whisper_token;
70
 
71
+ struct whisper_token_data {
72
+ whisper_token id; // token id
73
+ whisper_token tid; // forced timestamp token id
74
+
75
+ float p; // probability of the token
76
+ float pt; // probability of the timestamp token
77
+ };
78
+
79
  // Allocates all memory needed for the model and loads the model from the given file.
80
  // Returns NULL on failure.
81
  WHISPER_API struct whisper_context * whisper_init(const char * path_model);
 
130
  // You can also implement your own sampling method using the whisper_get_probs() function.
131
  // whisper_sample_best() returns the token with the highest probability
132
  // whisper_sample_timestamp() returns the most probable timestamp token
133
+ WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
134
  WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
135
 
136
  // Return the id of the specified language, returns -1 if not found
 
179
  enum whisper_sampling_strategy strategy;
180
 
181
  int n_threads;
 
182
  int n_processors;
183
+ int n_max_text_ctx;
184
+ int offset_ms;
185
 
186
  bool translate;
187
  bool no_context;