Andy Maloney committed on
Commit
e0255d4
·
unverified ·
1 Parent(s): 550fbf8

examples : small code cleanups (#322)

Browse files

- remove unnecessary initialization of string to ""
- use empty() instead of checking size()
- use emplace_back instead of push_back
- use nullptr instead of NULL
- remove unnecessary call to .data() on string
- use character overload of find_first_of() instead of passing a string

examples/command/command.cpp CHANGED
@@ -41,8 +41,8 @@ struct whisper_params {
41
 
42
  std::string language = "en";
43
  std::string model = "models/ggml-base.en.bin";
44
- std::string fname_out = "";
45
- std::string commands = "";
46
  };
47
 
48
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -576,10 +576,10 @@ int main(int argc, char ** argv) {
576
  std::vector<std::string> allowed_commands;
577
  std::vector<std::vector<whisper_token>> allowed_tokens;
578
 
579
- std::string k_prompt = "";
580
  std::vector<whisper_token> k_tokens;
581
 
582
- if (params.commands != "") {
583
  fprintf(stderr, "\n");
584
  fprintf(stderr, "%s: guided mode\n", __func__);
585
 
@@ -808,7 +808,7 @@ int main(int argc, char ** argv) {
808
 
809
  double psum = 0.0;
810
  for (int i = 0; i < (int) allowed_commands.size(); ++i) {
811
- probs_id.push_back(std::make_pair(probs[allowed_tokens[i][0]], i));
812
  for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
813
  probs_id.back().first += probs[allowed_tokens[i][j]];
814
  }
 
41
 
42
  std::string language = "en";
43
  std::string model = "models/ggml-base.en.bin";
44
+ std::string fname_out;
45
+ std::string commands;
46
  };
47
 
48
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
 
576
  std::vector<std::string> allowed_commands;
577
  std::vector<std::vector<whisper_token>> allowed_tokens;
578
 
579
+ std::string k_prompt;
580
  std::vector<whisper_token> k_tokens;
581
 
582
+ if (!params.commands.empty()) {
583
  fprintf(stderr, "\n");
584
  fprintf(stderr, "%s: guided mode\n", __func__);
585
 
 
808
 
809
  double psum = 0.0;
810
  for (int i = 0; i < (int) allowed_commands.size(); ++i) {
811
+ probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
812
  for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
813
  probs_id.back().first += probs[allowed_tokens[i][j]];
814
  }
examples/main/main.cpp CHANGED
@@ -75,7 +75,7 @@ struct whisper_params {
75
  bool no_timestamps = false;
76
 
77
  std::string language = "en";
78
- std::string prompt = "";
79
  std::string model = "models/ggml-base.en.bin";
80
 
81
  std::vector<std::string> fname_inp = {};
@@ -118,7 +118,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
118
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
119
  else if ( arg == "--prompt") { params.prompt = argv[++i]; }
120
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
121
- else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
122
  else {
123
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
124
  whisper_print_usage(argc, argv, params);
@@ -206,7 +206,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
206
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
207
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
208
 
209
- std::string speaker = "";
210
 
211
  if (params.diarize && pcmf32s.size() == 2) {
212
  const int64_t n_samples = pcmf32s[0].size();
@@ -468,7 +468,7 @@ int main(int argc, char ** argv) {
468
  // initial prompt
469
  std::vector<whisper_token> prompt_tokens;
470
 
471
- if (params.prompt.size() > 0) {
472
  prompt_tokens.resize(1024);
473
  prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
474
 
@@ -505,14 +505,14 @@ int main(int argc, char ** argv) {
505
  }
506
  }
507
 
508
- if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
509
  fprintf(stderr, "error: failed to open WAV file from stdin\n");
510
  return 4;
511
  }
512
 
513
  fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
514
  }
515
- else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
516
  fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
517
  return 5;
518
  }
@@ -617,8 +617,8 @@ int main(int argc, char ** argv) {
617
 
618
  wparams.speed_up = params.speed_up;
619
 
620
- wparams.prompt_tokens = prompt_tokens.size() == 0 ? nullptr : prompt_tokens.data();
621
- wparams.prompt_n_tokens = prompt_tokens.size() == 0 ? 0 : prompt_tokens.size();
622
 
623
  whisper_print_user_data user_data = { &params, &pcmf32s };
624
 
 
75
  bool no_timestamps = false;
76
 
77
  std::string language = "en";
78
+ std::string prompt;
79
  std::string model = "models/ggml-base.en.bin";
80
 
81
  std::vector<std::string> fname_inp = {};
 
118
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
119
  else if ( arg == "--prompt") { params.prompt = argv[++i]; }
120
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
121
+ else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
122
  else {
123
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
124
  whisper_print_usage(argc, argv, params);
 
206
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
207
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
208
 
209
+ std::string speaker;
210
 
211
  if (params.diarize && pcmf32s.size() == 2) {
212
  const int64_t n_samples = pcmf32s[0].size();
 
468
  // initial prompt
469
  std::vector<whisper_token> prompt_tokens;
470
 
471
+ if (!params.prompt.empty()) {
472
  prompt_tokens.resize(1024);
473
  prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
474
 
 
505
  }
506
  }
507
 
508
+ if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
509
  fprintf(stderr, "error: failed to open WAV file from stdin\n");
510
  return 4;
511
  }
512
 
513
  fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
514
  }
515
+ else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
516
  fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
517
  return 5;
518
  }
 
617
 
618
  wparams.speed_up = params.speed_up;
619
 
620
+ wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
621
+ wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
622
 
623
  whisper_print_user_data user_data = { &params, &pcmf32s };
624
 
examples/stream/stream.cpp CHANGED
@@ -51,7 +51,7 @@ struct whisper_params {
51
 
52
  std::string language = "en";
53
  std::string model = "models/ggml-base.en.bin";
54
- std::string fname_out = "";
55
  };
56
 
57
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
 
51
 
52
  std::string language = "en";
53
  std::string model = "models/ggml-base.en.bin";
54
+ std::string fname_out;
55
  };
56
 
57
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
examples/talk/gpt-2.cpp CHANGED
@@ -40,7 +40,7 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
40
  // find the longest tokens that form the words:
41
  std::vector<gpt_vocab::id> tokens;
42
  for (const auto & word : words) {
43
- if (word.size() == 0) continue;
44
 
45
  int i = 0;
46
  int n = word.size();
@@ -86,7 +86,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
86
  logits_id.reserve(n_logits);
87
 
88
  for (int i = 0; i < n_logits; i++) {
89
- logits_id.push_back(std::make_pair(logits[i], i));
90
  }
91
 
92
  // find the top K tokens
@@ -327,7 +327,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
327
  {
328
  struct ggml_init_params params;
329
  params.mem_size = ctx_size;
330
- params.mem_buffer = NULL;
331
 
332
  model.ctx = ggml_init(params);
333
  if (!model.ctx) {
@@ -448,7 +448,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
448
  std::string name(length, 0);
449
  fin.read(&name[0], length);
450
 
451
- if (model.tensors.find(name.data()) == model.tensors.end()) {
452
  fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
453
  return false;
454
  }
@@ -833,7 +833,7 @@ Me too.
833
  struct gpt2_context * gpt2_init(const char * path_model) {
834
  gpt2_context * ctx = new gpt2_context;
835
 
836
- ctx->rng = std::mt19937(time(NULL));
837
 
838
  // load the model
839
  {
@@ -886,7 +886,7 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
886
 
887
  for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
888
  // predict
889
- if (embd.size() > 0) {
890
  if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
891
  printf("gpt-2: failed to generate text\n");
892
  return "";
 
40
  // find the longest tokens that form the words:
41
  std::vector<gpt_vocab::id> tokens;
42
  for (const auto & word : words) {
43
+ if (word.empty()) continue;
44
 
45
  int i = 0;
46
  int n = word.size();
 
86
  logits_id.reserve(n_logits);
87
 
88
  for (int i = 0; i < n_logits; i++) {
89
+ logits_id.emplace_back(logits[i], i);
90
  }
91
 
92
  // find the top K tokens
 
327
  {
328
  struct ggml_init_params params;
329
  params.mem_size = ctx_size;
330
+ params.mem_buffer = nullptr;
331
 
332
  model.ctx = ggml_init(params);
333
  if (!model.ctx) {
 
448
  std::string name(length, 0);
449
  fin.read(&name[0], length);
450
 
451
+ if (model.tensors.find(name) == model.tensors.end()) {
452
  fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
453
  return false;
454
  }
 
833
  struct gpt2_context * gpt2_init(const char * path_model) {
834
  gpt2_context * ctx = new gpt2_context;
835
 
836
+ ctx->rng = std::mt19937(time(nullptr));
837
 
838
  // load the model
839
  {
 
886
 
887
  for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
888
  // predict
889
+ if (!embd.empty()) {
890
  if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
891
  printf("gpt-2: failed to generate text\n");
892
  return "";
examples/talk/talk.cpp CHANGED
@@ -39,7 +39,7 @@ struct whisper_params {
39
  std::string model_wsp = "models/ggml-base.en.bin";
40
  std::string model_gpt = "models/ggml-gpt-2-117M.bin";
41
  std::string speak = "./examples/talk/speak.sh";
42
- std::string fname_out = "";
43
  };
44
 
45
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -588,7 +588,7 @@ int main(int argc, char ** argv) {
588
 
589
  audio.get(params.voice_ms, pcmf32_cur);
590
 
591
- std::string text_heard = "";
592
 
593
  if (!force_speak) {
594
  text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
@@ -610,7 +610,7 @@ int main(int argc, char ** argv) {
610
  text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
611
 
612
  // take first line
613
- text_heard = text_heard.substr(0, text_heard.find_first_of("\n"));
614
 
615
  // remove leading and trailing whitespace
616
  text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
@@ -640,18 +640,18 @@ int main(int argc, char ** argv) {
640
 
641
  text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
642
  text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
643
- text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
644
 
645
  // remove first 2 lines of base prompt
646
  if (n_iter > 4) {
647
  {
648
- const size_t pos = prompt_base.find_first_of("\n");
649
  if (pos != std::string::npos) {
650
  prompt_base = prompt_base.substr(pos + 1);
651
  }
652
  }
653
  {
654
- const size_t pos = prompt_base.find_first_of("\n");
655
  if (pos != std::string::npos) {
656
  prompt_base = prompt_base.substr(pos + 1);
657
  }
 
39
  std::string model_wsp = "models/ggml-base.en.bin";
40
  std::string model_gpt = "models/ggml-gpt-2-117M.bin";
41
  std::string speak = "./examples/talk/speak.sh";
42
+ std::string fname_out;
43
  };
44
 
45
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
 
588
 
589
  audio.get(params.voice_ms, pcmf32_cur);
590
 
591
+ std::string text_heard;
592
 
593
  if (!force_speak) {
594
  text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
 
610
  text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
611
 
612
  // take first line
613
+ text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
614
 
615
  // remove leading and trailing whitespace
616
  text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
 
640
 
641
  text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
642
  text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
643
+ text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));
644
 
645
  // remove first 2 lines of base prompt
646
  if (n_iter > 4) {
647
  {
648
+ const size_t pos = prompt_base.find_first_of('\n');
649
  if (pos != std::string::npos) {
650
  prompt_base = prompt_base.substr(pos + 1);
651
  }
652
  }
653
  {
654
+ const size_t pos = prompt_base.find_first_of('\n');
655
  if (pos != std::string::npos) {
656
  prompt_base = prompt_base.substr(pos + 1);
657
  }