Spaces:
Running
Running
command : clean-up / refactoring / formatting (#383)
Browse files- examples/command/command.cpp +324 -359
examples/command/command.cpp
CHANGED
|
@@ -11,7 +11,6 @@
|
|
| 11 |
#include <SDL.h>
|
| 12 |
#include <SDL_audio.h>
|
| 13 |
|
| 14 |
-
#include <iostream>
|
| 15 |
#include <sstream>
|
| 16 |
#include <cassert>
|
| 17 |
#include <cstdio>
|
|
@@ -515,440 +514,406 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
|
|
| 515 |
return allowed_commands;
|
| 516 |
}
|
| 517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
// command-list mode
|
| 519 |
// guide the transcription to match the most likely command from a provided list
|
| 520 |
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
}
|
| 549 |
-
|
| 550 |
-
if (n == 1) {
|
| 551 |
-
allowed_tokens.back().push_back(tokens[0]);
|
| 552 |
-
}
|
| 553 |
-
}
|
| 554 |
-
|
| 555 |
-
max_len = std::max(max_len, (int) cmd.size());
|
| 556 |
-
}
|
| 557 |
-
|
| 558 |
-
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
| 559 |
-
fprintf(stderr, "\n");
|
| 560 |
-
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 561 |
-
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
| 562 |
-
for (const auto & token : allowed_tokens[i]) {
|
| 563 |
-
fprintf(stderr, " %5d", token);
|
| 564 |
-
}
|
| 565 |
-
fprintf(stderr, " ]\n");
|
| 566 |
-
}
|
| 567 |
-
|
| 568 |
-
std::string k_prompt = "select one from the available words: ";
|
| 569 |
-
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 570 |
-
if (i > 0) {
|
| 571 |
-
k_prompt += ", ";
|
| 572 |
-
}
|
| 573 |
-
k_prompt += allowed_commands[i];
|
| 574 |
-
}
|
| 575 |
-
k_prompt += ". selected word: ";
|
| 576 |
-
|
| 577 |
-
// tokenize prompt
|
| 578 |
-
std::vector<whisper_token> k_tokens;
|
| 579 |
-
{
|
| 580 |
-
k_tokens.resize(1024);
|
| 581 |
-
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
| 582 |
-
if (n < 0) {
|
| 583 |
-
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
| 584 |
-
return 4;
|
| 585 |
-
}
|
| 586 |
-
k_tokens.resize(n);
|
| 587 |
-
}
|
| 588 |
-
|
| 589 |
-
fprintf(stderr, "\n");
|
| 590 |
-
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
| 591 |
-
fprintf(stderr, "%s: tokens: [", __func__);
|
| 592 |
-
for (const auto & token : k_tokens) {
|
| 593 |
-
fprintf(stderr, " %d", token);
|
| 594 |
-
}
|
| 595 |
-
fprintf(stderr, " ]\n");
|
| 596 |
-
|
| 597 |
-
fprintf(stderr, "\n");
|
| 598 |
-
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
| 599 |
-
fprintf(stderr, "\n");
|
| 600 |
-
|
| 601 |
-
bool is_running = true;
|
| 602 |
-
|
| 603 |
-
std::vector<float> pcmf32_cur;
|
| 604 |
-
std::vector<float> pcmf32_prompt;
|
| 605 |
-
|
| 606 |
-
// main loop
|
| 607 |
-
while (is_running) {
|
| 608 |
-
// handle Ctrl + C
|
| 609 |
-
{
|
| 610 |
-
SDL_Event event;
|
| 611 |
-
while (SDL_PollEvent(&event)) {
|
| 612 |
-
switch (event.type) {
|
| 613 |
-
case SDL_QUIT:
|
| 614 |
-
{
|
| 615 |
-
is_running = false;
|
| 616 |
-
} break;
|
| 617 |
-
default:
|
| 618 |
-
break;
|
| 619 |
}
|
| 620 |
-
}
|
| 621 |
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
|
| 627 |
-
|
| 628 |
-
|
| 629 |
|
| 630 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
|
| 632 |
-
|
| 633 |
-
|
| 634 |
|
| 635 |
-
|
|
|
|
|
|
|
|
|
|
| 636 |
|
| 637 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
|
| 639 |
-
|
| 640 |
-
wparams.print_special = params.print_special;
|
| 641 |
-
wparams.print_realtime = false;
|
| 642 |
-
wparams.print_timestamps = !params.no_timestamps;
|
| 643 |
-
wparams.translate = params.translate;
|
| 644 |
-
wparams.no_context = true;
|
| 645 |
-
wparams.single_segment = true;
|
| 646 |
-
wparams.max_tokens = 1;
|
| 647 |
-
wparams.language = params.language.c_str();
|
| 648 |
-
wparams.n_threads = params.n_threads;
|
| 649 |
|
| 650 |
-
|
| 651 |
-
wparams.speed_up = params.speed_up;
|
| 652 |
|
| 653 |
-
|
| 654 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 655 |
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
| 659 |
-
break;
|
| 660 |
-
}
|
| 661 |
|
| 662 |
-
|
| 663 |
-
|
| 664 |
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
probs_id.back().first += probs[allowed_tokens[i][j]];
|
| 670 |
}
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
| 684 |
-
return a.first > b.first;
|
| 685 |
-
});
|
| 686 |
-
}
|
| 687 |
-
|
| 688 |
-
// print the commands and the respective probabilities
|
| 689 |
-
{
|
| 690 |
-
fprintf(stdout, "\n");
|
| 691 |
-
for (const auto & cmd : probs_id) {
|
| 692 |
-
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
| 693 |
-
for (int token : allowed_tokens[cmd.second]) {
|
| 694 |
-
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
|
| 695 |
-
}
|
| 696 |
-
fprintf(stdout, "\n");
|
| 697 |
}
|
| 698 |
-
}
|
| 699 |
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
|
|
|
| 703 |
|
| 704 |
-
|
| 705 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 706 |
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
|
| 718 |
-
|
| 719 |
-
|
| 720 |
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
bool ask_prompt = true;
|
| 727 |
-
|
| 728 |
-
float prob0 = 0.0f;
|
| 729 |
-
float prob = 0.0f;
|
| 730 |
-
|
| 731 |
-
std::vector<float> pcmf32_cur;
|
| 732 |
-
std::vector<float> pcmf32_prompt;
|
| 733 |
-
|
| 734 |
-
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
| 735 |
-
|
| 736 |
-
fprintf(stderr, "\n");
|
| 737 |
-
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
| 738 |
-
|
| 739 |
-
// main loop
|
| 740 |
-
while (is_running) {
|
| 741 |
-
// handle Ctrl + C
|
| 742 |
-
{
|
| 743 |
-
SDL_Event event;
|
| 744 |
-
while (SDL_PollEvent(&event)) {
|
| 745 |
-
switch (event.type) {
|
| 746 |
-
case SDL_QUIT:
|
| 747 |
-
{
|
| 748 |
-
is_running = false;
|
| 749 |
-
} break;
|
| 750 |
-
default:
|
| 751 |
-
break;
|
| 752 |
}
|
| 753 |
-
}
|
| 754 |
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
|
|
|
|
|
|
| 759 |
|
| 760 |
-
|
| 761 |
-
|
|
|
|
|
|
|
|
|
|
| 762 |
|
| 763 |
-
|
| 764 |
-
fprintf(stdout, "\n");
|
| 765 |
-
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
| 766 |
-
fprintf(stdout, "\n");
|
| 767 |
|
| 768 |
-
|
| 769 |
-
}
|
| 770 |
|
| 771 |
-
|
| 772 |
-
audio.get(2000, pcmf32_cur);
|
| 773 |
|
| 774 |
-
|
| 775 |
-
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
| 776 |
|
| 777 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 778 |
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
audio.get(params.prompt_ms, pcmf32_cur);
|
| 782 |
|
| 783 |
-
|
|
|
|
|
|
|
|
|
|
| 784 |
|
| 785 |
-
|
|
|
|
| 786 |
|
| 787 |
-
|
|
|
|
| 788 |
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
ask_prompt = true;
|
| 792 |
-
} else {
|
| 793 |
-
fprintf(stdout, "\n");
|
| 794 |
-
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
| 795 |
-
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
| 796 |
-
fprintf(stdout, "\n");
|
| 797 |
|
| 798 |
-
|
| 799 |
-
pcmf32_prompt = pcmf32_cur;
|
| 800 |
-
have_prompt = true;
|
| 801 |
-
}
|
| 802 |
-
} else {
|
| 803 |
-
// we have heard the activation phrase, now detect the commands
|
| 804 |
-
audio.get(params.command_ms, pcmf32_cur);
|
| 805 |
|
| 806 |
-
|
| 807 |
-
|
| 808 |
|
| 809 |
-
|
| 810 |
|
| 811 |
-
|
| 812 |
|
| 813 |
-
|
|
|
|
| 814 |
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
|
| 819 |
-
|
|
|
|
|
|
|
| 820 |
|
| 821 |
-
|
| 822 |
|
| 823 |
-
|
|
|
|
| 824 |
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
}
|
| 829 |
-
}
|
| 830 |
|
| 831 |
-
|
| 832 |
|
| 833 |
-
|
| 834 |
-
fprintf(stdout, "\n");
|
| 835 |
}
|
|
|
|
|
|
|
| 836 |
|
| 837 |
-
|
| 838 |
-
}
|
| 839 |
-
}
|
| 840 |
-
}
|
| 841 |
-
|
| 842 |
-
return 0;
|
| 843 |
}
|
| 844 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 845 |
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
int always_prompt_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
| 849 |
-
bool is_running = true;
|
| 850 |
-
bool ask_prompt = true;
|
| 851 |
|
| 852 |
-
|
|
|
|
| 853 |
|
| 854 |
-
|
| 855 |
|
| 856 |
-
|
|
|
|
| 857 |
|
| 858 |
-
|
|
|
|
|
|
|
|
|
|
| 859 |
|
| 860 |
-
|
| 861 |
-
|
| 862 |
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
|
|
|
| 866 |
|
| 867 |
-
|
|
|
|
| 868 |
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
// handle Ctrl + C
|
| 872 |
-
{
|
| 873 |
-
SDL_Event event;
|
| 874 |
-
while (SDL_PollEvent(&event)) {
|
| 875 |
-
switch (event.type) {
|
| 876 |
-
case SDL_QUIT:
|
| 877 |
-
{
|
| 878 |
-
is_running = false;
|
| 879 |
-
} break;
|
| 880 |
-
default:
|
| 881 |
-
break;
|
| 882 |
-
}
|
| 883 |
-
}
|
| 884 |
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
}
|
| 888 |
-
}
|
| 889 |
|
| 890 |
-
|
| 891 |
-
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
| 892 |
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
fprintf(stdout, "\n");
|
| 897 |
|
| 898 |
-
|
| 899 |
-
}
|
| 900 |
|
| 901 |
-
|
| 902 |
-
audio.get(2000, pcmf32_cur);
|
| 903 |
|
| 904 |
-
|
| 905 |
-
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
| 906 |
|
| 907 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 908 |
|
| 909 |
-
|
| 910 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 911 |
|
| 912 |
-
|
|
|
|
| 913 |
|
| 914 |
-
|
| 915 |
-
std::string word;
|
| 916 |
-
std::string prompt;
|
| 917 |
-
std::string command;
|
| 918 |
-
int i = 0;
|
| 919 |
-
int command_length = 0;
|
| 920 |
-
while (iss >> word) {
|
| 921 |
-
if (i == k_prompt_length - 1) {
|
| 922 |
-
prompt += word + ' ';
|
| 923 |
-
break;
|
| 924 |
-
}
|
| 925 |
-
prompt += word + ' ';
|
| 926 |
-
i++;
|
| 927 |
-
}
|
| 928 |
-
while (iss >> word) {
|
| 929 |
-
command += word + ' ';
|
| 930 |
-
command_length++;
|
| 931 |
-
}
|
| 932 |
|
| 933 |
-
|
| 934 |
|
| 935 |
-
|
| 936 |
-
//fprintf(stdout, "command size: %i\n", command_length);
|
| 937 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 938 |
|
| 939 |
-
|
| 940 |
-
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
| 941 |
-
}
|
| 942 |
|
| 943 |
-
|
| 944 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 945 |
|
| 946 |
-
|
| 947 |
-
|
| 948 |
-
|
| 949 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 950 |
|
| 951 |
-
|
| 952 |
}
|
| 953 |
|
| 954 |
int main(int argc, char ** argv) {
|
|
@@ -1005,11 +970,11 @@ int main(int argc, char ** argv) {
|
|
| 1005 |
int ret_val = 0;
|
| 1006 |
|
| 1007 |
if (!params.commands.empty()) {
|
| 1008 |
-
|
| 1009 |
} else if (!params.prompt.empty()) {
|
| 1010 |
-
|
| 1011 |
} else {
|
| 1012 |
-
|
| 1013 |
}
|
| 1014 |
|
| 1015 |
audio.pause();
|
|
|
|
| 11 |
#include <SDL.h>
|
| 12 |
#include <SDL_audio.h>
|
| 13 |
|
|
|
|
| 14 |
#include <sstream>
|
| 15 |
#include <cassert>
|
| 16 |
#include <cstdio>
|
|
|
|
| 514 |
return allowed_commands;
|
| 515 |
}
|
| 516 |
|
| 517 |
+
std::vector<std::string> get_words(const std::string &txt) {
|
| 518 |
+
std::vector<std::string> words;
|
| 519 |
+
|
| 520 |
+
std::istringstream iss(txt);
|
| 521 |
+
std::string word;
|
| 522 |
+
while (iss >> word) {
|
| 523 |
+
words.push_back(word);
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
return words;
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
// returns true if no exit event was received
|
| 530 |
+
bool process_sdl_events() {
|
| 531 |
+
SDL_Event event;
|
| 532 |
+
while (SDL_PollEvent(&event)) {
|
| 533 |
+
switch (event.type) {
|
| 534 |
+
case SDL_QUIT:
|
| 535 |
+
{
|
| 536 |
+
return false;
|
| 537 |
+
} break;
|
| 538 |
+
default:
|
| 539 |
+
break;
|
| 540 |
+
}
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
return true;
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
// command-list mode
|
| 547 |
// guide the transcription to match the most likely command from a provided list
|
| 548 |
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
| 549 |
+
fprintf(stderr, "\n");
|
| 550 |
+
fprintf(stderr, "%s: guided mode\n", __func__);
|
| 551 |
+
|
| 552 |
+
std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
|
| 553 |
+
|
| 554 |
+
if (allowed_commands.empty()) {
|
| 555 |
+
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
|
| 556 |
+
return 2;
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
int max_len = 0;
|
| 560 |
+
|
| 561 |
+
std::vector<std::vector<whisper_token>> allowed_tokens;
|
| 562 |
+
|
| 563 |
+
for (const auto & cmd : allowed_commands) {
|
| 564 |
+
whisper_token tokens[1024];
|
| 565 |
+
allowed_tokens.emplace_back();
|
| 566 |
+
|
| 567 |
+
for (int l = 0; l < (int) cmd.size(); ++l) {
|
| 568 |
+
// NOTE: very important to add the whitespace !
|
| 569 |
+
// the reason is that the first decoded token starts with a whitespace too!
|
| 570 |
+
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
|
| 571 |
+
|
| 572 |
+
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
|
| 573 |
+
if (n < 0) {
|
| 574 |
+
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
| 575 |
+
return 3;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 576 |
}
|
|
|
|
| 577 |
|
| 578 |
+
if (n == 1) {
|
| 579 |
+
allowed_tokens.back().push_back(tokens[0]);
|
| 580 |
+
}
|
| 581 |
+
}
|
| 582 |
|
| 583 |
+
max_len = std::max(max_len, (int) cmd.size());
|
| 584 |
+
}
|
| 585 |
|
| 586 |
+
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
| 587 |
+
fprintf(stderr, "\n");
|
| 588 |
+
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 589 |
+
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
| 590 |
+
for (const auto & token : allowed_tokens[i]) {
|
| 591 |
+
fprintf(stderr, " %5d", token);
|
| 592 |
+
}
|
| 593 |
+
fprintf(stderr, " ]\n");
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
std::string k_prompt = "select one from the available words: ";
|
| 597 |
+
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 598 |
+
if (i > 0) {
|
| 599 |
+
k_prompt += ", ";
|
| 600 |
+
}
|
| 601 |
+
k_prompt += allowed_commands[i];
|
| 602 |
+
}
|
| 603 |
+
k_prompt += ". selected word: ";
|
| 604 |
+
|
| 605 |
+
// tokenize prompt
|
| 606 |
+
std::vector<whisper_token> k_tokens;
|
| 607 |
+
{
|
| 608 |
+
k_tokens.resize(1024);
|
| 609 |
+
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
| 610 |
+
if (n < 0) {
|
| 611 |
+
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
| 612 |
+
return 4;
|
| 613 |
+
}
|
| 614 |
+
k_tokens.resize(n);
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
fprintf(stderr, "\n");
|
| 618 |
+
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
| 619 |
+
fprintf(stderr, "%s: tokens: [", __func__);
|
| 620 |
+
for (const auto & token : k_tokens) {
|
| 621 |
+
fprintf(stderr, " %d", token);
|
| 622 |
+
}
|
| 623 |
+
fprintf(stderr, " ]\n");
|
| 624 |
+
|
| 625 |
+
fprintf(stderr, "\n");
|
| 626 |
+
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
| 627 |
+
fprintf(stderr, "\n");
|
| 628 |
+
|
| 629 |
+
bool is_running = true;
|
| 630 |
|
| 631 |
+
std::vector<float> pcmf32_cur;
|
| 632 |
+
std::vector<float> pcmf32_prompt;
|
| 633 |
|
| 634 |
+
// main loop
|
| 635 |
+
while (is_running) {
|
| 636 |
+
// handle Ctrl + C
|
| 637 |
+
is_running = process_sdl_events();
|
| 638 |
|
| 639 |
+
// delay
|
| 640 |
+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
| 641 |
+
|
| 642 |
+
audio.get(2000, pcmf32_cur);
|
| 643 |
+
|
| 644 |
+
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
| 645 |
+
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
| 646 |
|
| 647 |
+
const auto t_start = std::chrono::high_resolution_clock::now();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 648 |
|
| 649 |
+
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
|
|
|
| 650 |
|
| 651 |
+
wparams.print_progress = false;
|
| 652 |
+
wparams.print_special = params.print_special;
|
| 653 |
+
wparams.print_realtime = false;
|
| 654 |
+
wparams.print_timestamps = !params.no_timestamps;
|
| 655 |
+
wparams.translate = params.translate;
|
| 656 |
+
wparams.no_context = true;
|
| 657 |
+
wparams.single_segment = true;
|
| 658 |
+
wparams.max_tokens = 1;
|
| 659 |
+
wparams.language = params.language.c_str();
|
| 660 |
+
wparams.n_threads = params.n_threads;
|
| 661 |
|
| 662 |
+
wparams.audio_ctx = params.audio_ctx;
|
| 663 |
+
wparams.speed_up = params.speed_up;
|
|
|
|
|
|
|
|
|
|
| 664 |
|
| 665 |
+
wparams.prompt_tokens = k_tokens.data();
|
| 666 |
+
wparams.prompt_n_tokens = k_tokens.size();
|
| 667 |
|
| 668 |
+
// run the transformer and a single decoding pass
|
| 669 |
+
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
| 670 |
+
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
| 671 |
+
break;
|
|
|
|
| 672 |
}
|
| 673 |
+
|
| 674 |
+
const auto * probs = whisper_get_probs(ctx);
|
| 675 |
+
std::vector<std::pair<float, int>> probs_id;
|
| 676 |
+
|
| 677 |
+
double psum = 0.0;
|
| 678 |
+
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 679 |
+
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
| 680 |
+
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
| 681 |
+
probs_id.back().first += probs[allowed_tokens[i][j]];
|
| 682 |
+
}
|
| 683 |
+
probs_id.back().first /= allowed_tokens[i].size();
|
| 684 |
+
psum += probs_id.back().first;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 685 |
}
|
|
|
|
| 686 |
|
| 687 |
+
// normalize
|
| 688 |
+
for (auto & p : probs_id) {
|
| 689 |
+
p.first /= psum;
|
| 690 |
+
}
|
| 691 |
|
| 692 |
+
// sort descending
|
| 693 |
+
{
|
| 694 |
+
using pair_type = decltype(probs_id)::value_type;
|
| 695 |
+
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
| 696 |
+
return a.first > b.first;
|
| 697 |
+
});
|
| 698 |
+
}
|
| 699 |
|
| 700 |
+
// print the commands and the respective probabilities
|
| 701 |
+
{
|
| 702 |
+
fprintf(stdout, "\n");
|
| 703 |
+
for (const auto & cmd : probs_id) {
|
| 704 |
+
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
| 705 |
+
for (int token : allowed_tokens[cmd.second]) {
|
| 706 |
+
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
|
| 707 |
+
}
|
| 708 |
+
fprintf(stdout, "\n");
|
| 709 |
+
}
|
| 710 |
+
}
|
| 711 |
|
| 712 |
+
// best command
|
| 713 |
+
{
|
| 714 |
+
const auto t_end = std::chrono::high_resolution_clock::now();
|
| 715 |
|
| 716 |
+
const float prob = probs_id[0].first;
|
| 717 |
+
const int index = probs_id[0].second;
|
| 718 |
|
| 719 |
+
fprintf(stdout, "\n");
|
| 720 |
+
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
| 721 |
+
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
|
| 722 |
+
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
| 723 |
+
fprintf(stdout, "\n");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 724 |
}
|
|
|
|
| 725 |
|
| 726 |
+
audio.clear();
|
| 727 |
+
}
|
| 728 |
+
}
|
| 729 |
+
|
| 730 |
+
return 0;
|
| 731 |
+
}
|
| 732 |
|
| 733 |
+
// always-prompt mode
|
| 734 |
+
// transcribe the voice into text after valid prompt
|
| 735 |
+
int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
| 736 |
+
bool is_running = true;
|
| 737 |
+
bool ask_prompt = true;
|
| 738 |
|
| 739 |
+
float prob = 0.0f;
|
|
|
|
|
|
|
|
|
|
| 740 |
|
| 741 |
+
std::vector<float> pcmf32_cur;
|
|
|
|
| 742 |
|
| 743 |
+
const std::string k_prompt = params.prompt;
|
|
|
|
| 744 |
|
| 745 |
+
const int k_prompt_length = get_words(k_prompt).size();
|
|
|
|
| 746 |
|
| 747 |
+
fprintf(stderr, "\n");
|
| 748 |
+
fprintf(stderr, "%s: always-prompt mode\n", __func__);
|
| 749 |
+
|
| 750 |
+
// main loop
|
| 751 |
+
while (is_running) {
|
| 752 |
+
// handle Ctrl + C
|
| 753 |
+
is_running = process_sdl_events();
|
| 754 |
|
| 755 |
+
// delay
|
| 756 |
+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
|
|
|
| 757 |
|
| 758 |
+
if (ask_prompt) {
|
| 759 |
+
fprintf(stdout, "\n");
|
| 760 |
+
fprintf(stdout, "%s: The prompt is: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
| 761 |
+
fprintf(stdout, "\n");
|
| 762 |
|
| 763 |
+
ask_prompt = false;
|
| 764 |
+
}
|
| 765 |
|
| 766 |
+
{
|
| 767 |
+
audio.get(2000, pcmf32_cur);
|
| 768 |
|
| 769 |
+
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
| 770 |
+
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 771 |
|
| 772 |
+
int64_t t_ms = 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
|
| 774 |
+
// detect the commands
|
| 775 |
+
audio.get(params.command_ms, pcmf32_cur);
|
| 776 |
|
| 777 |
+
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
| 778 |
|
| 779 |
+
const auto words = get_words(txt);
|
| 780 |
|
| 781 |
+
std::string prompt;
|
| 782 |
+
std::string command;
|
| 783 |
|
| 784 |
+
for (int i = 0; i < words.size(); ++i) {
|
| 785 |
+
if (i < k_prompt_length) {
|
| 786 |
+
prompt += words[i] + " ";
|
| 787 |
+
} else {
|
| 788 |
+
command += words[i] + " ";
|
| 789 |
+
}
|
| 790 |
+
}
|
| 791 |
|
| 792 |
+
const float sim = similarity(prompt, k_prompt);
|
| 793 |
|
| 794 |
+
//debug
|
| 795 |
+
//fprintf(stdout, "command size: %i\n", command_length);
|
| 796 |
|
| 797 |
+
if ((sim > 0.7f) && (command.size() > 0)) {
|
| 798 |
+
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
| 799 |
+
}
|
|
|
|
|
|
|
| 800 |
|
| 801 |
+
fprintf(stdout, "\n");
|
| 802 |
|
| 803 |
+
audio.clear();
|
|
|
|
| 804 |
}
|
| 805 |
+
}
|
| 806 |
+
}
|
| 807 |
|
| 808 |
+
return 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 809 |
}
|
| 810 |
|
| 811 |
+
// general-purpose mode
|
| 812 |
+
// freely transcribe the voice into text
|
| 813 |
+
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
| 814 |
+
bool is_running = true;
|
| 815 |
+
bool have_prompt = false;
|
| 816 |
+
bool ask_prompt = true;
|
| 817 |
|
| 818 |
+
float prob0 = 0.0f;
|
| 819 |
+
float prob = 0.0f;
|
|
|
|
|
|
|
|
|
|
| 820 |
|
| 821 |
+
std::vector<float> pcmf32_cur;
|
| 822 |
+
std::vector<float> pcmf32_prompt;
|
| 823 |
|
| 824 |
+
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
| 825 |
|
| 826 |
+
fprintf(stderr, "\n");
|
| 827 |
+
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
| 828 |
|
| 829 |
+
// main loop
|
| 830 |
+
while (is_running) {
|
| 831 |
+
// handle Ctrl + C
|
| 832 |
+
is_running = process_sdl_events();
|
| 833 |
|
| 834 |
+
// delay
|
| 835 |
+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
| 836 |
|
| 837 |
+
if (ask_prompt) {
|
| 838 |
+
fprintf(stdout, "\n");
|
| 839 |
+
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
| 840 |
+
fprintf(stdout, "\n");
|
| 841 |
|
| 842 |
+
ask_prompt = false;
|
| 843 |
+
}
|
| 844 |
|
| 845 |
+
{
|
| 846 |
+
audio.get(2000, pcmf32_cur);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
|
| 848 |
+
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
| 849 |
+
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
|
|
|
|
|
|
| 850 |
|
| 851 |
+
int64_t t_ms = 0;
|
|
|
|
| 852 |
|
| 853 |
+
if (!have_prompt) {
|
| 854 |
+
// wait for activation phrase
|
| 855 |
+
audio.get(params.prompt_ms, pcmf32_cur);
|
|
|
|
| 856 |
|
| 857 |
+
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
|
|
|
| 858 |
|
| 859 |
+
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
|
|
|
| 860 |
|
| 861 |
+
const float sim = similarity(txt, k_prompt);
|
|
|
|
| 862 |
|
| 863 |
+
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
| 864 |
+
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
| 865 |
+
ask_prompt = true;
|
| 866 |
+
} else {
|
| 867 |
+
fprintf(stdout, "\n");
|
| 868 |
+
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
| 869 |
+
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
| 870 |
+
fprintf(stdout, "\n");
|
| 871 |
|
| 872 |
+
// save the audio for the prompt
|
| 873 |
+
pcmf32_prompt = pcmf32_cur;
|
| 874 |
+
have_prompt = true;
|
| 875 |
+
}
|
| 876 |
+
} else {
|
| 877 |
+
// we have heard the activation phrase, now detect the commands
|
| 878 |
+
audio.get(params.command_ms, pcmf32_cur);
|
| 879 |
|
| 880 |
+
// prepend the prompt audio
|
| 881 |
+
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
| 882 |
|
| 883 |
+
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 884 |
|
| 885 |
+
prob = 100.0f*(prob - prob0);
|
| 886 |
|
| 887 |
+
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
|
|
|
| 888 |
|
| 889 |
+
// find the prompt in the text
|
| 890 |
+
float best_sim = 0.0f;
|
| 891 |
+
size_t best_len = 0;
|
| 892 |
+
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
| 893 |
+
const auto prompt = txt.substr(0, n);
|
| 894 |
|
| 895 |
+
const float sim = similarity(prompt, k_prompt);
|
|
|
|
|
|
|
| 896 |
|
| 897 |
+
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
| 898 |
|
| 899 |
+
if (sim > best_sim) {
|
| 900 |
+
best_sim = sim;
|
| 901 |
+
best_len = n;
|
| 902 |
+
}
|
| 903 |
+
}
|
| 904 |
|
| 905 |
+
const std::string command = ::trim(txt.substr(best_len));
|
| 906 |
+
|
| 907 |
+
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
| 908 |
+
fprintf(stdout, "\n");
|
| 909 |
+
}
|
| 910 |
+
|
| 911 |
+
audio.clear();
|
| 912 |
+
}
|
| 913 |
+
}
|
| 914 |
+
}
|
| 915 |
|
| 916 |
+
return 0;
|
| 917 |
}
|
| 918 |
|
| 919 |
int main(int argc, char ** argv) {
|
|
|
|
| 970 |
int ret_val = 0;
|
| 971 |
|
| 972 |
if (!params.commands.empty()) {
|
| 973 |
+
ret_val = process_command_list(ctx, audio, params);
|
| 974 |
} else if (!params.prompt.empty()) {
|
| 975 |
+
ret_val = always_prompt_transcription(ctx, audio, params);
|
| 976 |
} else {
|
| 977 |
+
ret_val = process_general_transcription(ctx, audio, params);
|
| 978 |
}
|
| 979 |
|
| 980 |
audio.pause();
|