rakksor committed on
Commit
542e8da
·
unverified ·
1 Parent(s): f928f33

talk-llama : optional wake-up command and audio confirmation (#1765)

Browse files

* talk-llama: add optional wake-word detection from command

* talk-llama: add optional audio confirmation before generating answer

* talk-llama: fix small formatting issue in output

* talk-llama.cpp: fix Windows build

Files changed (1) hide show
  1. examples/talk-llama/talk-llama.cpp +62 -2
examples/talk-llama/talk-llama.cpp CHANGED
@@ -14,6 +14,7 @@
14
  #include <thread>
15
  #include <vector>
16
  #include <regex>
 
17
 
18
  std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
19
  auto * model = llama_get_model(ctx);
@@ -68,6 +69,8 @@ struct whisper_params {
68
 
69
  std::string person = "Georgi";
70
  std::string bot_name = "LLaMA";
 
 
71
  std::string language = "en";
72
  std::string model_wsp = "models/ggml-base.en.bin";
73
  std::string model_llama = "models/ggml-llama-7B.bin";
@@ -104,6 +107,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
104
  else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
105
  else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; }
106
  else if (arg == "--session") { params.path_session = argv[++i]; }
 
 
107
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
108
  else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
109
  else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
@@ -149,6 +154,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
149
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
150
  fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
151
  fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str());
 
 
152
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
153
  fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
154
  fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
@@ -227,6 +234,18 @@ std::string transcribe(
227
  return result;
228
  }
229
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
231
 
232
  const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
@@ -441,6 +460,16 @@ int main(int argc, char ** argv) {
441
  bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
442
 
443
  printf("%s : done! start speaking in the microphone\n", __func__);
 
 
 
 
 
 
 
 
 
 
444
  printf("\n");
445
  printf("%s%s", params.person.c_str(), chat_symb.c_str());
446
  fflush(stdout);
@@ -486,10 +515,41 @@ int main(int argc, char ** argv) {
486
 
487
  audio.get(params.voice_ms, pcmf32_cur);
488
 
489
- std::string text_heard;
490
 
491
  if (!force_speak) {
492
- text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  }
494
 
495
  // remove text between brackets using regex
 
14
  #include <thread>
15
  #include <vector>
16
  #include <regex>
17
+ #include <sstream>
18
 
19
  std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
20
  auto * model = llama_get_model(ctx);
 
69
 
70
  std::string person = "Georgi";
71
  std::string bot_name = "LLaMA";
72
+ std::string wake_cmd = "";
73
+ std::string heard_ok = "";
74
  std::string language = "en";
75
  std::string model_wsp = "models/ggml-base.en.bin";
76
  std::string model_llama = "models/ggml-llama-7B.bin";
 
107
  else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
108
  else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; }
109
  else if (arg == "--session") { params.path_session = argv[++i]; }
110
+ else if (arg == "-w" || arg == "--wake-command") { params.wake_cmd = argv[++i]; }
111
+ else if (arg == "-ho" || arg == "--heard-ok") { params.heard_ok = argv[++i]; }
112
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
113
  else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
114
  else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
 
154
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
155
  fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
156
  fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str());
157
+ fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str());
158
+ fprintf(stderr, " -ho TEXT, --heard-ok TEXT [%-7s] said by TTS before generating reply\n", params.heard_ok.c_str());
159
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
160
  fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
161
  fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
 
234
  return result;
235
  }
236
 
237
// Split a text string into whitespace-separated tokens.
// Leading/trailing/repeated whitespace is ignored; an empty or
// all-whitespace input yields an empty vector.
std::vector<std::string> get_words(const std::string &txt) {
    std::vector<std::string> res;

    std::istringstream ss(txt);
    for (std::string tok; ss >> tok; ) {
        res.push_back(tok);
    }

    return res;
}
248
+
249
  const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
250
 
251
  const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
 
460
  bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
461
 
462
  printf("%s : done! start speaking in the microphone\n", __func__);
463
+
464
+ // show wake command if enabled
465
+ const std::string wake_cmd = params.wake_cmd;
466
+ const int wake_cmd_length = get_words(wake_cmd).size();
467
+ const bool use_wake_cmd = wake_cmd_length > 0;
468
+
469
+ if (use_wake_cmd) {
470
+ printf("%s : the wake-up command is: '%s%s%s'\n", __func__, "\033[1m", wake_cmd.c_str(), "\033[0m");
471
+ }
472
+
473
  printf("\n");
474
  printf("%s%s", params.person.c_str(), chat_symb.c_str());
475
  fflush(stdout);
 
515
 
516
  audio.get(params.voice_ms, pcmf32_cur);
517
 
518
+ std::string all_heard;
519
 
520
  if (!force_speak) {
521
+ all_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
522
+ }
523
+
524
+ const auto words = get_words(all_heard);
525
+
526
+ std::string wake_cmd_heard;
527
+ std::string text_heard;
528
+
529
+ for (int i = 0; i < (int) words.size(); ++i) {
530
+ if (i < wake_cmd_length) {
531
+ wake_cmd_heard += words[i] + " ";
532
+ } else {
533
+ text_heard += words[i] + " ";
534
+ }
535
+ }
536
+
537
+ // check if audio starts with the wake-up command if enabled
538
+ if (use_wake_cmd) {
539
+ const float sim = similarity(wake_cmd_heard, wake_cmd);
540
+
541
+ if ((sim < 0.7f) || (text_heard.empty())) {
542
+ audio.clear();
543
+ continue;
544
+ }
545
+ }
546
+
547
+ // optionally give audio feedback that the current text is being processed
548
+ if (!params.heard_ok.empty()) {
549
+ int ret = system((params.speak + " " + std::to_string(voice_id) + " '" + params.heard_ok + "'").c_str());
550
+ if (ret != 0) {
551
+ fprintf(stderr, "%s: failed to speak\n", __func__);
552
+ }
553
  }
554
 
555
  // remove text between brackets using regex