ggerganov committed on
Commit
a488eb7
·
unverified ·
1 Parent(s): 7a2aa88

main : add "--prompt" command line argument (#90)

Browse files

This allows providing an initial prompt to be used at the start of
processing.

Files changed (1) hide show
  1. examples/main/main.cpp +24 -3
examples/main/main.cpp CHANGED
@@ -73,8 +73,9 @@ struct whisper_params {
73
  bool print_colors = false;
74
  bool no_timestamps = false;
75
 
76
- std::string language = "en";
77
- std::string model = "models/ggml-base.en.bin";
 
78
 
79
  std::vector<std::string> fname_inp = {};
80
  };
@@ -113,6 +114,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
113
  else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
114
  else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
115
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
 
116
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
117
  else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
118
  else {
@@ -150,6 +152,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
150
  fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
151
  fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
152
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
 
153
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
154
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
155
  fprintf(stderr, "\n");
@@ -462,6 +465,22 @@ int main(int argc, char ** argv) {
462
  return 3;
463
  }
464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
466
  const auto fname_inp = params.fname_inp[f];
467
 
@@ -577,7 +596,6 @@ int main(int argc, char ** argv) {
577
  fprintf(stderr, "\n");
578
  }
579
 
580
-
581
  // run the inference
582
  {
583
  whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
@@ -599,6 +617,9 @@ int main(int argc, char ** argv) {
599
 
600
  wparams.speed_up = params.speed_up;
601
 
 
 
 
602
  whisper_print_user_data user_data = { &params, &pcmf32s };
603
 
604
  // this callback is called on each new segment
 
73
  bool print_colors = false;
74
  bool no_timestamps = false;
75
 
76
+ std::string language = "en";
77
+ std::string prompt = "";
78
+ std::string model = "models/ggml-base.en.bin";
79
 
80
  std::vector<std::string> fname_inp = {};
81
  };
 
114
  else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
115
  else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
116
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
117
+ else if ( arg == "--prompt") { params.prompt = argv[++i]; }
118
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
119
  else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
120
  else {
 
152
  fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
153
  fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
154
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
155
+ fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
156
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
157
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
158
  fprintf(stderr, "\n");
 
465
  return 3;
466
  }
467
 
468
+ // initial prompt
469
+ std::vector<whisper_token> prompt_tokens;
470
+
471
+ if (params.prompt.size() > 0) {
472
+ prompt_tokens.resize(1024);
473
+ prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
474
+
475
+ fprintf(stderr, "\n");
476
+ fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
477
+ fprintf(stderr, "initial tokens: [ ");
478
+ for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
479
+ fprintf(stderr, "%d ", prompt_tokens[i]);
480
+ }
481
+ fprintf(stderr, "]\n");
482
+ }
483
+
484
  for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
485
  const auto fname_inp = params.fname_inp[f];
486
 
 
596
  fprintf(stderr, "\n");
597
  }
598
 
 
599
  // run the inference
600
  {
601
  whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
617
 
618
  wparams.speed_up = params.speed_up;
619
 
620
+ wparams.prompt_tokens = prompt_tokens.size() == 0 ? nullptr : prompt_tokens.data();
621
+ wparams.prompt_n_tokens = prompt_tokens.size() == 0 ? 0 : prompt_tokens.size();
622
+
623
  whisper_print_user_data user_data = { &params, &pcmf32s };
624
 
625
  // this callback is called on each new segment