ggerganov commited on
Commit
fdd70c9
·
unverified ·
1 Parent(s): d11f3b5

command : clean-up / refactoring / formatting (#383)

Browse files
Files changed (1) hide show
  1. examples/command/command.cpp +324 -359
examples/command/command.cpp CHANGED
@@ -11,7 +11,6 @@
11
  #include <SDL.h>
12
  #include <SDL_audio.h>
13
 
14
- #include <iostream>
15
  #include <sstream>
16
  #include <cassert>
17
  #include <cstdio>
@@ -515,440 +514,406 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
515
  return allowed_commands;
516
  }
517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  // command-list mode
519
  // guide the transcription to match the most likely command from a provided list
520
  int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
521
- fprintf(stderr, "\n");
522
- fprintf(stderr, "%s: guided mode\n", __func__);
523
-
524
- std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
525
-
526
- if (allowed_commands.empty()) {
527
- fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
528
- return 2;
529
- }
530
-
531
- int max_len = 0;
532
-
533
- std::vector<std::vector<whisper_token>> allowed_tokens;
534
-
535
- for (const auto & cmd : allowed_commands) {
536
- whisper_token tokens[1024];
537
- allowed_tokens.emplace_back();
538
-
539
- for (int l = 0; l < (int) cmd.size(); ++l) {
540
- // NOTE: very important to add the whitespace !
541
- // the reason is that the first decoded token starts with a whitespace too!
542
- std::string ss = std::string(" ") + cmd.substr(0, l + 1);
543
-
544
- const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
545
- if (n < 0) {
546
- fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
547
- return 3;
548
- }
549
-
550
- if (n == 1) {
551
- allowed_tokens.back().push_back(tokens[0]);
552
- }
553
- }
554
-
555
- max_len = std::max(max_len, (int) cmd.size());
556
- }
557
-
558
- fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
559
- fprintf(stderr, "\n");
560
- for (int i = 0; i < (int) allowed_commands.size(); ++i) {
561
- fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
562
- for (const auto & token : allowed_tokens[i]) {
563
- fprintf(stderr, " %5d", token);
564
- }
565
- fprintf(stderr, " ]\n");
566
- }
567
-
568
- std::string k_prompt = "select one from the available words: ";
569
- for (int i = 0; i < (int) allowed_commands.size(); ++i) {
570
- if (i > 0) {
571
- k_prompt += ", ";
572
- }
573
- k_prompt += allowed_commands[i];
574
- }
575
- k_prompt += ". selected word: ";
576
-
577
- // tokenize prompt
578
- std::vector<whisper_token> k_tokens;
579
- {
580
- k_tokens.resize(1024);
581
- const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
582
- if (n < 0) {
583
- fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
584
- return 4;
585
- }
586
- k_tokens.resize(n);
587
- }
588
-
589
- fprintf(stderr, "\n");
590
- fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
591
- fprintf(stderr, "%s: tokens: [", __func__);
592
- for (const auto & token : k_tokens) {
593
- fprintf(stderr, " %d", token);
594
- }
595
- fprintf(stderr, " ]\n");
596
-
597
- fprintf(stderr, "\n");
598
- fprintf(stderr, "%s: listening for a command ...\n", __func__);
599
- fprintf(stderr, "\n");
600
-
601
- bool is_running = true;
602
-
603
- std::vector<float> pcmf32_cur;
604
- std::vector<float> pcmf32_prompt;
605
-
606
- // main loop
607
- while (is_running) {
608
- // handle Ctrl + C
609
- {
610
- SDL_Event event;
611
- while (SDL_PollEvent(&event)) {
612
- switch (event.type) {
613
- case SDL_QUIT:
614
- {
615
- is_running = false;
616
- } break;
617
- default:
618
- break;
619
  }
620
- }
621
 
622
- if (!is_running) {
623
- return 0;
624
- }
625
- }
626
 
627
- // delay
628
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
629
 
630
- audio.get(2000, pcmf32_cur);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
631
 
632
- if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
633
- fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
634
 
635
- const auto t_start = std::chrono::high_resolution_clock::now();
 
 
 
636
 
637
- whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
 
 
 
 
 
638
 
639
- wparams.print_progress = false;
640
- wparams.print_special = params.print_special;
641
- wparams.print_realtime = false;
642
- wparams.print_timestamps = !params.no_timestamps;
643
- wparams.translate = params.translate;
644
- wparams.no_context = true;
645
- wparams.single_segment = true;
646
- wparams.max_tokens = 1;
647
- wparams.language = params.language.c_str();
648
- wparams.n_threads = params.n_threads;
649
 
650
- wparams.audio_ctx = params.audio_ctx;
651
- wparams.speed_up = params.speed_up;
652
 
653
- wparams.prompt_tokens = k_tokens.data();
654
- wparams.prompt_n_tokens = k_tokens.size();
 
 
 
 
 
 
 
 
655
 
656
- // run the transformer and a single decoding pass
657
- if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
658
- fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
659
- break;
660
- }
661
 
662
- const auto * probs = whisper_get_probs(ctx);
663
- std::vector<std::pair<float, int>> probs_id;
664
 
665
- double psum = 0.0;
666
- for (int i = 0; i < (int) allowed_commands.size(); ++i) {
667
- probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
668
- for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
669
- probs_id.back().first += probs[allowed_tokens[i][j]];
670
  }
671
- probs_id.back().first /= allowed_tokens[i].size();
672
- psum += probs_id.back().first;
673
- }
674
-
675
- // normalize
676
- for (auto & p : probs_id) {
677
- p.first /= psum;
678
- }
679
-
680
- // sort descending
681
- {
682
- using pair_type = decltype(probs_id)::value_type;
683
- std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
684
- return a.first > b.first;
685
- });
686
- }
687
-
688
- // print the commands and the respective probabilities
689
- {
690
- fprintf(stdout, "\n");
691
- for (const auto & cmd : probs_id) {
692
- fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
693
- for (int token : allowed_tokens[cmd.second]) {
694
- fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
695
- }
696
- fprintf(stdout, "\n");
697
  }
698
- }
699
 
700
- // best command
701
- {
702
- const auto t_end = std::chrono::high_resolution_clock::now();
 
703
 
704
- const float prob = probs_id[0].first;
705
- const int index = probs_id[0].second;
 
 
 
 
 
706
 
707
- fprintf(stdout, "\n");
708
- fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
709
- "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
710
- (int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
711
- fprintf(stdout, "\n");
712
- }
 
 
 
 
 
713
 
714
- audio.clear();
715
- }
716
- }
717
 
718
- return 0;
719
- }
720
 
721
- // general-purpose mode
722
- // freely transcribe the voice into text
723
- int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
724
- bool is_running = true;
725
- bool have_prompt = false;
726
- bool ask_prompt = true;
727
-
728
- float prob0 = 0.0f;
729
- float prob = 0.0f;
730
-
731
- std::vector<float> pcmf32_cur;
732
- std::vector<float> pcmf32_prompt;
733
-
734
- const std::string k_prompt = "Ok Whisper, start listening for commands.";
735
-
736
- fprintf(stderr, "\n");
737
- fprintf(stderr, "%s: general-purpose mode\n", __func__);
738
-
739
- // main loop
740
- while (is_running) {
741
- // handle Ctrl + C
742
- {
743
- SDL_Event event;
744
- while (SDL_PollEvent(&event)) {
745
- switch (event.type) {
746
- case SDL_QUIT:
747
- {
748
- is_running = false;
749
- } break;
750
- default:
751
- break;
752
  }
753
- }
754
 
755
- if (!is_running) {
756
- return 0;
757
- }
758
- }
 
 
759
 
760
- // delay
761
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
 
 
 
762
 
763
- if (ask_prompt) {
764
- fprintf(stdout, "\n");
765
- fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
766
- fprintf(stdout, "\n");
767
 
768
- ask_prompt = false;
769
- }
770
 
771
- {
772
- audio.get(2000, pcmf32_cur);
773
 
774
- if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
775
- fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
776
 
777
- int64_t t_ms = 0;
 
 
 
 
 
 
778
 
779
- if (!have_prompt) {
780
- // wait for activation phrase
781
- audio.get(params.prompt_ms, pcmf32_cur);
782
 
783
- const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
 
 
 
784
 
785
- fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
 
786
 
787
- const float sim = similarity(txt, k_prompt);
 
788
 
789
- if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
790
- fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
791
- ask_prompt = true;
792
- } else {
793
- fprintf(stdout, "\n");
794
- fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
795
- fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
796
- fprintf(stdout, "\n");
797
 
798
- // save the audio for the prompt
799
- pcmf32_prompt = pcmf32_cur;
800
- have_prompt = true;
801
- }
802
- } else {
803
- // we have heard the activation phrase, now detect the commands
804
- audio.get(params.command_ms, pcmf32_cur);
805
 
806
- // prepend the prompt audio
807
- pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
808
 
809
- const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
810
 
811
- prob = 100.0f*(prob - prob0);
812
 
813
- //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
 
814
 
815
- // find the prompt in the text
816
- float best_sim = 0.0f;
817
- size_t best_len = 0;
818
- for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
819
- const auto prompt = txt.substr(0, n);
 
 
820
 
821
- const float sim = similarity(prompt, k_prompt);
822
 
823
- //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
 
824
 
825
- if (sim > best_sim) {
826
- best_sim = sim;
827
- best_len = n;
828
- }
829
- }
830
 
831
- const std::string command = ::trim(txt.substr(best_len));
832
 
833
- fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
834
- fprintf(stdout, "\n");
835
  }
 
 
836
 
837
- audio.clear();
838
- }
839
- }
840
- }
841
-
842
- return 0;
843
  }
844
 
 
 
 
 
 
 
845
 
846
- // always prompt mode
847
- // transcribe the voice into text after valid prompt
848
- int always_prompt_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
849
- bool is_running = true;
850
- bool ask_prompt = true;
851
 
852
- float prob = 0.0f;
 
853
 
854
- std::vector<float> pcmf32_cur;
855
 
856
- const std::string k_prompt = params.prompt;
 
857
 
858
- std::vector<std::string> words;
 
 
 
859
 
860
- std::istringstream iss(k_prompt);
861
- std::string word;
862
 
863
- while (iss >> word) {
864
- words.push_back(word);
865
- }
 
866
 
867
- int k_prompt_length = words.size();
 
868
 
869
- // main loop
870
- while (is_running) {
871
- // handle Ctrl + C
872
- {
873
- SDL_Event event;
874
- while (SDL_PollEvent(&event)) {
875
- switch (event.type) {
876
- case SDL_QUIT:
877
- {
878
- is_running = false;
879
- } break;
880
- default:
881
- break;
882
- }
883
- }
884
 
885
- if (!is_running) {
886
- return 0;
887
- }
888
- }
889
 
890
- // delay
891
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
892
 
893
- if (ask_prompt) {
894
- fprintf(stdout, "\n");
895
- fprintf(stdout, "%s: The prompt is: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
896
- fprintf(stdout, "\n");
897
 
898
- ask_prompt = false;
899
- }
900
 
901
- {
902
- audio.get(2000, pcmf32_cur);
903
 
904
- if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
905
- fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
906
 
907
- int64_t t_ms = 0;
 
 
 
 
 
 
 
908
 
909
- // detect the commands
910
- audio.get(params.command_ms, pcmf32_cur);
 
 
 
 
 
911
 
912
- const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
 
913
 
914
- std::istringstream iss(txt);
915
- std::string word;
916
- std::string prompt;
917
- std::string command;
918
- int i = 0;
919
- int command_length = 0;
920
- while (iss >> word) {
921
- if (i == k_prompt_length - 1) {
922
- prompt += word + ' ';
923
- break;
924
- }
925
- prompt += word + ' ';
926
- i++;
927
- }
928
- while (iss >> word) {
929
- command += word + ' ';
930
- command_length++;
931
- }
932
 
933
- const float sim = similarity(prompt, k_prompt);
934
 
935
- //debug
936
- //fprintf(stdout, "command size: %i\n", command_length);
937
 
 
 
 
 
 
938
 
939
- if ((sim > 0.7f) && (command_length >0)){
940
- fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
941
- }
942
 
943
- fprintf(stdout, "\n");
944
 
 
 
 
 
 
945
 
946
- audio.clear();
947
- }
948
- }
949
- }
 
 
 
 
 
 
950
 
951
- return 0;
952
  }
953
 
954
  int main(int argc, char ** argv) {
@@ -1005,11 +970,11 @@ int main(int argc, char ** argv) {
1005
  int ret_val = 0;
1006
 
1007
  if (!params.commands.empty()) {
1008
- ret_val = process_command_list(ctx, audio, params);
1009
  } else if (!params.prompt.empty()) {
1010
- ret_val = always_prompt_transcription(ctx, audio, params);
1011
  } else {
1012
- ret_val = process_general_transcription(ctx, audio, params);
1013
  }
1014
 
1015
  audio.pause();
 
11
  #include <SDL.h>
12
  #include <SDL_audio.h>
13
 
 
14
  #include <sstream>
15
  #include <cassert>
16
  #include <cstdio>
 
514
  return allowed_commands;
515
  }
516
 
517
+ std::vector<std::string> get_words(const std::string &txt) {
518
+ std::vector<std::string> words;
519
+
520
+ std::istringstream iss(txt);
521
+ std::string word;
522
+ while (iss >> word) {
523
+ words.push_back(word);
524
+ }
525
+
526
+ return words;
527
+ }
528
+
529
+ // returns true if no exit event was received
530
+ bool process_sdl_events() {
531
+ SDL_Event event;
532
+ while (SDL_PollEvent(&event)) {
533
+ switch (event.type) {
534
+ case SDL_QUIT:
535
+ {
536
+ return false;
537
+ } break;
538
+ default:
539
+ break;
540
+ }
541
+ }
542
+
543
+ return true;
544
+ }
545
+
546
  // command-list mode
547
  // guide the transcription to match the most likely command from a provided list
548
  int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
549
+ fprintf(stderr, "\n");
550
+ fprintf(stderr, "%s: guided mode\n", __func__);
551
+
552
+ std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
553
+
554
+ if (allowed_commands.empty()) {
555
+ fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
556
+ return 2;
557
+ }
558
+
559
+ int max_len = 0;
560
+
561
+ std::vector<std::vector<whisper_token>> allowed_tokens;
562
+
563
+ for (const auto & cmd : allowed_commands) {
564
+ whisper_token tokens[1024];
565
+ allowed_tokens.emplace_back();
566
+
567
+ for (int l = 0; l < (int) cmd.size(); ++l) {
568
+ // NOTE: very important to add the whitespace !
569
+ // the reason is that the first decoded token starts with a whitespace too!
570
+ std::string ss = std::string(" ") + cmd.substr(0, l + 1);
571
+
572
+ const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
573
+ if (n < 0) {
574
+ fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
575
+ return 3;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
  }
 
577
 
578
+ if (n == 1) {
579
+ allowed_tokens.back().push_back(tokens[0]);
580
+ }
581
+ }
582
 
583
+ max_len = std::max(max_len, (int) cmd.size());
584
+ }
585
 
586
+ fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
587
+ fprintf(stderr, "\n");
588
+ for (int i = 0; i < (int) allowed_commands.size(); ++i) {
589
+ fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
590
+ for (const auto & token : allowed_tokens[i]) {
591
+ fprintf(stderr, " %5d", token);
592
+ }
593
+ fprintf(stderr, " ]\n");
594
+ }
595
+
596
+ std::string k_prompt = "select one from the available words: ";
597
+ for (int i = 0; i < (int) allowed_commands.size(); ++i) {
598
+ if (i > 0) {
599
+ k_prompt += ", ";
600
+ }
601
+ k_prompt += allowed_commands[i];
602
+ }
603
+ k_prompt += ". selected word: ";
604
+
605
+ // tokenize prompt
606
+ std::vector<whisper_token> k_tokens;
607
+ {
608
+ k_tokens.resize(1024);
609
+ const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
610
+ if (n < 0) {
611
+ fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
612
+ return 4;
613
+ }
614
+ k_tokens.resize(n);
615
+ }
616
+
617
+ fprintf(stderr, "\n");
618
+ fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
619
+ fprintf(stderr, "%s: tokens: [", __func__);
620
+ for (const auto & token : k_tokens) {
621
+ fprintf(stderr, " %d", token);
622
+ }
623
+ fprintf(stderr, " ]\n");
624
+
625
+ fprintf(stderr, "\n");
626
+ fprintf(stderr, "%s: listening for a command ...\n", __func__);
627
+ fprintf(stderr, "\n");
628
+
629
+ bool is_running = true;
630
 
631
+ std::vector<float> pcmf32_cur;
632
+ std::vector<float> pcmf32_prompt;
633
 
634
+ // main loop
635
+ while (is_running) {
636
+ // handle Ctrl + C
637
+ is_running = process_sdl_events();
638
 
639
+ // delay
640
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
641
+
642
+ audio.get(2000, pcmf32_cur);
643
+
644
+ if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
645
+ fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
646
 
647
+ const auto t_start = std::chrono::high_resolution_clock::now();
 
 
 
 
 
 
 
 
 
648
 
649
+ whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 
650
 
651
+ wparams.print_progress = false;
652
+ wparams.print_special = params.print_special;
653
+ wparams.print_realtime = false;
654
+ wparams.print_timestamps = !params.no_timestamps;
655
+ wparams.translate = params.translate;
656
+ wparams.no_context = true;
657
+ wparams.single_segment = true;
658
+ wparams.max_tokens = 1;
659
+ wparams.language = params.language.c_str();
660
+ wparams.n_threads = params.n_threads;
661
 
662
+ wparams.audio_ctx = params.audio_ctx;
663
+ wparams.speed_up = params.speed_up;
 
 
 
664
 
665
+ wparams.prompt_tokens = k_tokens.data();
666
+ wparams.prompt_n_tokens = k_tokens.size();
667
 
668
+ // run the transformer and a single decoding pass
669
+ if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
670
+ fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
671
+ break;
 
672
  }
673
+
674
+ const auto * probs = whisper_get_probs(ctx);
675
+ std::vector<std::pair<float, int>> probs_id;
676
+
677
+ double psum = 0.0;
678
+ for (int i = 0; i < (int) allowed_commands.size(); ++i) {
679
+ probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
680
+ for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
681
+ probs_id.back().first += probs[allowed_tokens[i][j]];
682
+ }
683
+ probs_id.back().first /= allowed_tokens[i].size();
684
+ psum += probs_id.back().first;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685
  }
 
686
 
687
+ // normalize
688
+ for (auto & p : probs_id) {
689
+ p.first /= psum;
690
+ }
691
 
692
+ // sort descending
693
+ {
694
+ using pair_type = decltype(probs_id)::value_type;
695
+ std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
696
+ return a.first > b.first;
697
+ });
698
+ }
699
 
700
+ // print the commands and the respective probabilities
701
+ {
702
+ fprintf(stdout, "\n");
703
+ for (const auto & cmd : probs_id) {
704
+ fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
705
+ for (int token : allowed_tokens[cmd.second]) {
706
+ fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
707
+ }
708
+ fprintf(stdout, "\n");
709
+ }
710
+ }
711
 
712
+ // best command
713
+ {
714
+ const auto t_end = std::chrono::high_resolution_clock::now();
715
 
716
+ const float prob = probs_id[0].first;
717
+ const int index = probs_id[0].second;
718
 
719
+ fprintf(stdout, "\n");
720
+ fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
721
+ "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
722
+ (int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
723
+ fprintf(stdout, "\n");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724
  }
 
725
 
726
+ audio.clear();
727
+ }
728
+ }
729
+
730
+ return 0;
731
+ }
732
 
733
+ // always-prompt mode
734
+ // transcribe the voice into text after valid prompt
735
+ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
736
+ bool is_running = true;
737
+ bool ask_prompt = true;
738
 
739
+ float prob = 0.0f;
 
 
 
740
 
741
+ std::vector<float> pcmf32_cur;
 
742
 
743
+ const std::string k_prompt = params.prompt;
 
744
 
745
+ const int k_prompt_length = get_words(k_prompt).size();
 
746
 
747
+ fprintf(stderr, "\n");
748
+ fprintf(stderr, "%s: always-prompt mode\n", __func__);
749
+
750
+ // main loop
751
+ while (is_running) {
752
+ // handle Ctrl + C
753
+ is_running = process_sdl_events();
754
 
755
+ // delay
756
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
 
757
 
758
+ if (ask_prompt) {
759
+ fprintf(stdout, "\n");
760
+ fprintf(stdout, "%s: The prompt is: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
761
+ fprintf(stdout, "\n");
762
 
763
+ ask_prompt = false;
764
+ }
765
 
766
+ {
767
+ audio.get(2000, pcmf32_cur);
768
 
769
+ if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
770
+ fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
 
 
 
 
 
 
771
 
772
+ int64_t t_ms = 0;
 
 
 
 
 
 
773
 
774
+ // detect the commands
775
+ audio.get(params.command_ms, pcmf32_cur);
776
 
777
+ const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
778
 
779
+ const auto words = get_words(txt);
780
 
781
+ std::string prompt;
782
+ std::string command;
783
 
784
+ for (int i = 0; i < words.size(); ++i) {
785
+ if (i < k_prompt_length) {
786
+ prompt += words[i] + " ";
787
+ } else {
788
+ command += words[i] + " ";
789
+ }
790
+ }
791
 
792
+ const float sim = similarity(prompt, k_prompt);
793
 
794
+ //debug
795
+ //fprintf(stdout, "command size: %i\n", command_length);
796
 
797
+ if ((sim > 0.7f) && (command.size() > 0)) {
798
+ fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
799
+ }
 
 
800
 
801
+ fprintf(stdout, "\n");
802
 
803
+ audio.clear();
 
804
  }
805
+ }
806
+ }
807
 
808
+ return 0;
 
 
 
 
 
809
  }
810
 
811
+ // general-purpose mode
812
+ // freely transcribe the voice into text
813
+ int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
814
+ bool is_running = true;
815
+ bool have_prompt = false;
816
+ bool ask_prompt = true;
817
 
818
+ float prob0 = 0.0f;
819
+ float prob = 0.0f;
 
 
 
820
 
821
+ std::vector<float> pcmf32_cur;
822
+ std::vector<float> pcmf32_prompt;
823
 
824
+ const std::string k_prompt = "Ok Whisper, start listening for commands.";
825
 
826
+ fprintf(stderr, "\n");
827
+ fprintf(stderr, "%s: general-purpose mode\n", __func__);
828
 
829
+ // main loop
830
+ while (is_running) {
831
+ // handle Ctrl + C
832
+ is_running = process_sdl_events();
833
 
834
+ // delay
835
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
836
 
837
+ if (ask_prompt) {
838
+ fprintf(stdout, "\n");
839
+ fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
840
+ fprintf(stdout, "\n");
841
 
842
+ ask_prompt = false;
843
+ }
844
 
845
+ {
846
+ audio.get(2000, pcmf32_cur);
 
 
 
 
 
 
 
 
 
 
 
 
 
847
 
848
+ if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
849
+ fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
 
 
850
 
851
+ int64_t t_ms = 0;
 
852
 
853
+ if (!have_prompt) {
854
+ // wait for activation phrase
855
+ audio.get(params.prompt_ms, pcmf32_cur);
 
856
 
857
+ const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
 
858
 
859
+ fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
 
860
 
861
+ const float sim = similarity(txt, k_prompt);
 
862
 
863
+ if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
864
+ fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
865
+ ask_prompt = true;
866
+ } else {
867
+ fprintf(stdout, "\n");
868
+ fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
869
+ fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
870
+ fprintf(stdout, "\n");
871
 
872
+ // save the audio for the prompt
873
+ pcmf32_prompt = pcmf32_cur;
874
+ have_prompt = true;
875
+ }
876
+ } else {
877
+ // we have heard the activation phrase, now detect the commands
878
+ audio.get(params.command_ms, pcmf32_cur);
879
 
880
+ // prepend the prompt audio
881
+ pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
882
 
883
+ const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
884
 
885
+ prob = 100.0f*(prob - prob0);
886
 
887
+ //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
 
888
 
889
+ // find the prompt in the text
890
+ float best_sim = 0.0f;
891
+ size_t best_len = 0;
892
+ for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
893
+ const auto prompt = txt.substr(0, n);
894
 
895
+ const float sim = similarity(prompt, k_prompt);
 
 
896
 
897
+ //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
898
 
899
+ if (sim > best_sim) {
900
+ best_sim = sim;
901
+ best_len = n;
902
+ }
903
+ }
904
 
905
+ const std::string command = ::trim(txt.substr(best_len));
906
+
907
+ fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
908
+ fprintf(stdout, "\n");
909
+ }
910
+
911
+ audio.clear();
912
+ }
913
+ }
914
+ }
915
 
916
+ return 0;
917
  }
918
 
919
  int main(int argc, char ** argv) {
 
970
  int ret_val = 0;
971
 
972
  if (!params.commands.empty()) {
973
+ ret_val = process_command_list(ctx, audio, params);
974
  } else if (!params.prompt.empty()) {
975
+ ret_val = always_prompt_transcription(ctx, audio, params);
976
  } else {
977
+ ret_val = process_general_transcription(ctx, audio, params);
978
  }
979
 
980
  audio.pause();