felrock ggerganov commited on
Commit
8a0d34c
·
unverified ·
1 Parent(s): 5a48cf5

server : add a REST Whisper server example with OAI-like API (#1380)

Browse files

* Add first draft of server

* Added json support and base funcs for server.cpp

* Add more user input via api-request

also some clean up

* Add reqest params and load post function

Also some general clean up

* Remove unused function

* Add readme

* Add exception handlers

* Update examples/server/server.cpp

* make : add server target

* Add magic curl syntax

Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>

.gitignore CHANGED
@@ -31,6 +31,7 @@ build-sanitize-thread/
31
  /talk-llama
32
  /bench
33
  /quantize
 
34
  /lsp
35
 
36
  arm_neon.h
 
31
  /talk-llama
32
  /bench
33
  /quantize
34
+ /server
35
  /lsp
36
 
37
  arm_neon.h
Makefile CHANGED
@@ -1,4 +1,4 @@
1
- default: main bench quantize
2
 
3
  ifndef UNAME_S
4
  UNAME_S := $(shell uname -s)
@@ -338,7 +338,7 @@ libwhisper.so: $(WHISPER_OBJ)
338
  $(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
339
 
340
  clean:
341
- rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
342
 
343
  #
344
  # Examples
@@ -359,6 +359,9 @@ bench: examples/bench/bench.cpp $(WHISPER_OBJ)
359
  quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
360
  $(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
361
 
 
 
 
362
  stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
363
  $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
364
 
 
1
+ default: main bench quantize server
2
 
3
  ifndef UNAME_S
4
  UNAME_S := $(shell uname -s)
 
338
  $(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
339
 
340
  clean:
341
+ rm -f *.o main stream command talk talk-llama bench quantize server lsp libwhisper.a libwhisper.so
342
 
343
  #
344
  # Examples
 
359
  quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
360
  $(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
361
 
362
+ server: examples/server/server.cpp $(SRC_COMMON) $(WHISPER_OBJ)
363
+ $(CXX) $(CXXFLAGS) examples/server/server.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o server $(LDFLAGS)
364
+
365
  stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
366
  $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
367
 
examples/CMakeLists.txt CHANGED
@@ -65,6 +65,7 @@ elseif(CMAKE_JS_VERSION)
65
  else()
66
  add_subdirectory(main)
67
  add_subdirectory(stream)
 
68
  add_subdirectory(command)
69
  add_subdirectory(bench)
70
  add_subdirectory(quantize)
 
65
  else()
66
  add_subdirectory(main)
67
  add_subdirectory(stream)
68
+ add_subdirectory(server)
69
  add_subdirectory(command)
70
  add_subdirectory(bench)
71
  add_subdirectory(quantize)
examples/main/main.cpp CHANGED
@@ -165,8 +165,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
165
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
166
  else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
167
  else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
168
- else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
169
- else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
170
  else {
171
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
172
  whisper_print_usage(argc, argv, params);
 
165
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
166
  else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
167
  else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
168
+ else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
169
+ else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
170
  else {
171
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
172
  whisper_print_usage(argc, argv, params);
examples/server/CMakeLists.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ set(TARGET server)
2
+ add_executable(${TARGET} server.cpp httplib.h json.hpp)
3
+
4
+ include(DefaultTargetOptions)
5
+
6
+ target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
examples/server/README.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # whisper.cpp http server
2
+
3
+ Simple http server. WAV files are passed to the inference model via http requests.
4
+
5
+ ```
6
+ ./server -h
7
+
8
+ usage: ./bin/server [options]
9
+
10
+ options:
11
+ -h, --help [default] show this help message and exit
12
+ -t N, --threads N [4 ] number of threads to use during computation
13
+ -p N, --processors N [1 ] number of processors to use during computation
14
+ -ot N, --offset-t N [0 ] time offset in milliseconds
15
+ -on N, --offset-n N [0 ] segment index offset
16
+ -d N, --duration N [0 ] duration of audio to process in milliseconds
17
+ -mc N, --max-context N [-1 ] maximum number of text context tokens to store
18
+ -ml N, --max-len N [0 ] maximum segment length in characters
19
+ -sow, --split-on-word [false ] split on word rather than on token
20
+ -bo N, --best-of N [2 ] number of best candidates to keep
21
+ -bs N, --beam-size N [-1 ] beam size for beam search
22
+ -wt N, --word-thold N [0.01 ] word timestamp probability threshold
23
+ -et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
24
+ -lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
25
+ -debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
26
+ -tr, --translate [false ] translate from source language to english
27
+ -di, --diarize [false ] stereo audio diarization
28
+ -tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
29
+ -nf, --no-fallback [false ] do not use temperature fallback while decoding
30
+ -ps, --print-special [false ] print special tokens
31
+ -pc, --print-colors [false ] print colors
32
+ -pp, --print-progress [false ] print progress
33
+ -nt, --no-timestamps [false ] do not print timestamps
34
+ -l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
35
+ -dl, --detect-language [false ] exit after automatically detecting language
36
+ --prompt PROMPT [ ] initial prompt
37
+ -m FNAME, --model FNAME [models/ggml-base.en.bin] model path
38
+ -oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
39
+ --host HOST, [127.0.0.1] Hostname/ip-adress for the server
40
+ --port PORT, [8080 ] Port number for the server
41
+ ```
42
+
43
+ ## request examples
44
+
45
+ **/inference**
46
+ ```
47
+ curl 127.0.0.1:8080/inference \
48
+ -H "Content-Type: multipart/form-data" \
49
+ -F file="@<file-path>" \
50
+ -F temperature="0.2" \
51
+ -F response-format="json"
52
+ ```
53
+
54
+ **/load**
55
+ ```
56
+ curl 127.0.0.1:8080/load \
57
+ -H "Content-Type: multipart/form-data" \
58
+ -F model="<path-to-model-file>"
59
+ ```
examples/server/httplib.h ADDED
The diff for this file is too large to render. See raw diff
 
examples/server/json.hpp ADDED
The diff for this file is too large to render. See raw diff
 
examples/server/server.cpp ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "common.h"
2
+
3
+ #include "whisper.h"
4
+ #include "httplib.h"
5
+ #include "json.hpp"
6
+
7
+ #include <cmath>
8
+ #include <fstream>
9
+ #include <cstdio>
10
+ #include <string>
11
+ #include <thread>
12
+ #include <vector>
13
+ #include <cstring>
14
+
15
+ #if defined(_MSC_VER)
16
+ #pragma warning(disable: 4244 4267) // possible loss of data
17
+ #endif
18
+
19
+ using namespace httplib;
20
+ using json = nlohmann::json;
21
+
22
+ namespace {
23
+
24
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
// Lowest is red, middle is yellow, highest is green.
// Used by whisper_print_segment_callback to color tokens by probability.
const std::vector<std::string> k_colors = {
    "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
    "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
};

// output formats — accepted values for the "response-format" request field
const std::string json_format  = "json";
const std::string text_format  = "text";
const std::string srt_format   = "srt";
const std::string vjson_format = "verbose_json";
const std::string vtt_format   = "vtt";
37
+
38
// HTTP server configuration, set from the --host/--port/--public CLI flags.
struct server_params
{
    std::string hostname    = "127.0.0.1";              // address to bind the listener to
    std::string public_path = "examples/server/public"; // folder with static files to serve

    int32_t port          = 8080; // TCP port to listen on
    int32_t read_timeout  = 600;  // socket read timeout — presumably seconds; confirm against httplib usage
    int32_t write_timeout = 600;  // socket write timeout — presumably seconds; confirm against httplib usage
};
47
+
48
// Inference and output options. Defaults can be overridden on the command line
// (whisper_params_parse) and, per request, by multipart form fields
// (get_req_parameters).
struct whisper_params {
    int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency()); // compute threads
    int32_t n_processors  = 1;  // parallel processors passed to whisper_full_parallel
    int32_t offset_t_ms   = 0;  // start offset into the audio, ms
    int32_t offset_n      = 0;  // segment index offset
    int32_t duration_ms   = 0;  // amount of audio to process, ms (0 = all)
    int32_t progress_step = 5;  // report progress every N percent
    int32_t max_context   = -1; // max text context tokens (-1 = keep whisper default)
    int32_t max_len       = 0;  // max segment length in characters
    int32_t best_of       = 2;  // candidates kept for greedy sampling
    int32_t beam_size     = -1; // beam size (>1 switches to beam search)

    float word_thold    =  0.01f; // word timestamp probability threshold
    float entropy_thold =  2.40f; // entropy threshold for decoder fallback
    float logprob_thold = -1.00f; // log-probability threshold for decoder fallback
    float userdef_temp  =  0.20f; // temperature_inc; set by request field "temperature"

    bool speed_up        = false; // forwarded to wparams.speed_up (CLI flag is commented out)
    bool debug_mode      = false; // enable debug mode (eg. dump log_mel)
    bool translate       = false; // translate to english instead of transcribing
    bool detect_language = false; // auto-detect spoken language
    bool diarize         = false; // stereo-channel speaker diarization
    bool tinydiarize     = false; // tinydiarize speaker-turn detection (needs tdrz model)
    bool split_on_word   = false; // split segments on word rather than on token
    bool no_fallback     = false; // disable temperature fallback while decoding
    bool print_special   = false; // print special tokens
    bool print_colors    = false; // color tokens by probability on stdout
    bool print_progress  = false; // print progress callbacks
    bool no_timestamps   = false; // suppress timestamps in printed output
    bool use_gpu         = true;  // cleared by -ng/--no-gpu

    std::string language  = "en"; // spoken language ('auto' for auto-detect)
    std::string prompt    = "";   // initial prompt fed to the decoder
    std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; // NOTE(review): set via -fp but not used elsewhere in this file
    std::string model     = "models/ggml-base.en.bin"; // path to ggml model file

    std::string response_format = json_format; // one of the *_format constants above

    // [TDRZ] speaker turn string
    std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line

    std::string openvino_encode_device = "CPU"; // OpenVINO device used for encode inference
};
91
+
92
// Render a whisper timestamp (units of 10 ms) as "HH:MM:SS.mmm"
// (or "HH:MM:SS,mmm" when comma == true, as SRT requires).
//  500  -> 00:05.000
//  6000 -> 01:00.000
std::string to_timestamp(int64_t t, bool comma = false) {
    int64_t remaining = t * 10; // total milliseconds

    const int64_t hours = remaining / (1000 * 60 * 60);
    remaining -= hours * (1000 * 60 * 60);

    const int64_t minutes = remaining / (1000 * 60);
    remaining -= minutes * (1000 * 60);

    const int64_t seconds = remaining / 1000;
    remaining -= seconds * 1000;

    char out[32];
    snprintf(out, sizeof(out), "%02d:%02d:%02d%s%03d",
             (int) hours, (int) minutes, (int) seconds, comma ? "," : ".", (int) remaining);

    return std::string(out);
}
108
+
109
+ int timestamp_to_sample(int64_t t, int n_samples) {
110
+ return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
111
+ }
112
+
113
// Return true when fileName exists and can be opened for reading.
bool is_file_exist(const char *fileName)
{
    return std::ifstream(fileName).good();
}
118
+
119
+ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params,
120
+ const server_params& sparams) {
121
+ fprintf(stderr, "\n");
122
+ fprintf(stderr, "usage: %s [options] \n", argv[0]);
123
+ fprintf(stderr, "\n");
124
+ fprintf(stderr, "options:\n");
125
+ fprintf(stderr, " -h, --help [default] show this help message and exit\n");
126
+ fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
127
+ fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
128
+ fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
129
+ fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
130
+ fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
131
+ fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
132
+ fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
133
+ fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
134
+ fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
135
+ fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
136
+ fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
137
+ fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
138
+ fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
139
+ // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
140
+ fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
141
+ fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
142
+ fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
143
+ fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
144
+ fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
145
+ fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
146
+ fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
147
+ fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
148
+ fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
149
+ fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
150
+ fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
151
+ fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
152
+ fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
153
+ fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
154
+ // server params
155
+ fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
156
+ fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port);
157
+ fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
158
+ fprintf(stderr, "\n");
159
+ }
160
+
161
// Parse command-line arguments into params (inference options) and sparams
// (server options). Prints usage and exits on -h/--help and on any unknown
// argument; always returns true otherwise.
// NOTE(review): value-taking flags read argv[++i] without checking i+1 < argc,
// so a trailing flag with a missing value reads past argv — confirm/handle.
bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams) {
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];

        if (arg == "-h" || arg == "--help") {
            whisper_print_usage(argc, argv, params, sparams);
            exit(0);
        }
        // inference options
        else if (arg == "-t"    || arg == "--threads")         { params.n_threads       = std::stoi(argv[++i]); }
        else if (arg == "-p"    || arg == "--processors")      { params.n_processors    = std::stoi(argv[++i]); }
        else if (arg == "-ot"   || arg == "--offset-t")        { params.offset_t_ms     = std::stoi(argv[++i]); }
        else if (arg == "-on"   || arg == "--offset-n")        { params.offset_n        = std::stoi(argv[++i]); }
        else if (arg == "-d"    || arg == "--duration")        { params.duration_ms     = std::stoi(argv[++i]); }
        else if (arg == "-mc"   || arg == "--max-context")     { params.max_context     = std::stoi(argv[++i]); }
        else if (arg == "-ml"   || arg == "--max-len")         { params.max_len         = std::stoi(argv[++i]); }
        else if (arg == "-bo"   || arg == "--best-of")         { params.best_of         = std::stoi(argv[++i]); }
        else if (arg == "-bs"   || arg == "--beam-size")       { params.beam_size       = std::stoi(argv[++i]); }
        else if (arg == "-wt"   || arg == "--word-thold")      { params.word_thold      = std::stof(argv[++i]); }
        else if (arg == "-et"   || arg == "--entropy-thold")   { params.entropy_thold   = std::stof(argv[++i]); }
        else if (arg == "-lpt"  || arg == "--logprob-thold")   { params.logprob_thold   = std::stof(argv[++i]); }
        // else if (arg == "-su"   || arg == "--speed-up")        { params.speed_up        = true; }
        else if (arg == "-debug"|| arg == "--debug-mode")      { params.debug_mode      = true; }
        else if (arg == "-tr"   || arg == "--translate")       { params.translate       = true; }
        else if (arg == "-di"   || arg == "--diarize")         { params.diarize         = true; }
        else if (arg == "-tdrz" || arg == "--tinydiarize")     { params.tinydiarize     = true; }
        else if (arg == "-sow"  || arg == "--split-on-word")   { params.split_on_word   = true; }
        else if (arg == "-nf"   || arg == "--no-fallback")     { params.no_fallback     = true; }
        // NOTE(review): -fp/--font-path is accepted but font_path is not used by the server
        else if (arg == "-fp"   || arg == "--font-path")       { params.font_path       = argv[++i]; }
        else if (arg == "-ps"   || arg == "--print-special")   { params.print_special   = true; }
        else if (arg == "-pc"   || arg == "--print-colors")    { params.print_colors    = true; }
        else if (arg == "-pp"   || arg == "--print-progress")  { params.print_progress  = true; }
        else if (arg == "-nt"   || arg == "--no-timestamps")   { params.no_timestamps   = true; }
        else if (arg == "-l"    || arg == "--language")        { params.language        = argv[++i]; }
        else if (arg == "-dl"   || arg == "--detect-language") { params.detect_language = true; }
        else if (                  arg == "--prompt")          { params.prompt          = argv[++i]; }
        else if (arg == "-m"    || arg == "--model")           { params.model           = argv[++i]; }
        else if (arg == "-oved" || arg == "--ov-e-device")     { params.openvino_encode_device = argv[++i]; }
        else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
        // server params
        else if (                  arg == "--port")            { sparams.port           = std::stoi(argv[++i]); }
        else if (                  arg == "--host")            { sparams.hostname       = argv[++i]; }
        else if (                  arg == "--public")          { sparams.public_path    = argv[++i]; }
        else {
            // unknown flag: print help and stop (exit code 0 mirrors the -h path)
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            whisper_print_usage(argc, argv, params, sparams);
            exit(0);
        }
    }

    return true;
}
212
+
213
// Bundle handed to whisper's C callbacks through their void* user_data
// argument (see whisper_print_progress_callback / whisper_print_segment_callback).
struct whisper_print_user_data {
    const whisper_params * params;  // active inference/print options

    const std::vector<std::vector<float>> * pcmf32s; // stereo PCM, used for diarization
    int progress_prev; // last progress percentage already reported
};
219
+
220
+ std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
221
+ std::string speaker = "";
222
+ const int64_t n_samples = pcmf32s[0].size();
223
+
224
+ const int64_t is0 = timestamp_to_sample(t0, n_samples);
225
+ const int64_t is1 = timestamp_to_sample(t1, n_samples);
226
+
227
+ double energy0 = 0.0f;
228
+ double energy1 = 0.0f;
229
+
230
+ for (int64_t j = is0; j < is1; j++) {
231
+ energy0 += fabs(pcmf32s[0][j]);
232
+ energy1 += fabs(pcmf32s[1][j]);
233
+ }
234
+
235
+ if (energy0 > 1.1*energy1) {
236
+ speaker = "0";
237
+ } else if (energy1 > 1.1*energy0) {
238
+ speaker = "1";
239
+ } else {
240
+ speaker = "?";
241
+ }
242
+
243
+ //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
244
+
245
+ if (!id_only) {
246
+ speaker.insert(0, "(speaker ");
247
+ speaker.append(")");
248
+ }
249
+
250
+ return speaker;
251
+ }
252
+
253
+ void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
254
+ int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
255
+ int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev);
256
+ if (progress >= *progress_prev + progress_step) {
257
+ *progress_prev += progress_step;
258
+ fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress);
259
+ }
260
+ }
261
+
262
// whisper new-segment callback: pretty-prints each freshly decoded segment to
// stdout, with optional timestamps, per-token confidence coloring (k_colors),
// stereo diarization speaker labels, and tinydiarize speaker-turn markers.
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
    const auto & params  = *((whisper_print_user_data *) user_data)->params;
    const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;

    const int n_segments = whisper_full_n_segments(ctx);

    std::string speaker = "";

    int64_t t0 = 0;
    int64_t t1 = 0;

    // print the last n_new segments
    const int s0 = n_segments - n_new;

    if (s0 == 0) {
        printf("\n");
    }

    for (int i = s0; i < n_segments; i++) {
        // segment times are needed both for timestamp printing and diarization
        if (!params.no_timestamps || params.diarize) {
            t0 = whisper_full_get_segment_t0(ctx, i);
            t1 = whisper_full_get_segment_t1(ctx, i);
        }

        if (!params.no_timestamps) {
            printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
        }

        if (params.diarize && pcmf32s.size() == 2) {
            speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
        }

        if (params.print_colors) {
            // token-by-token output, colored by token probability
            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
                if (params.print_special == false) {
                    // skip special tokens (id >= EOT) unless explicitly requested
                    const whisper_token id = whisper_full_get_token_id(ctx, i, j);
                    if (id >= whisper_token_eot(ctx)) {
                        continue;
                    }
                }

                const char * text = whisper_full_get_token_text(ctx, i, j);
                const float p = whisper_full_get_token_p (ctx, i, j);

                // map p^3 onto the color table index range, clamped to valid indices
                const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));

                // "\033[0m" resets the terminal color after each token
                printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
            }
        } else {
            const char * text = whisper_full_get_segment_text(ctx, i);

            printf("%s%s", speaker.c_str(), text);
        }

        if (params.tinydiarize) {
            if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
                printf("%s", params.tdrz_speaker_turn.c_str());
            }
        }

        // with timestamps or speakers: each segment on new line
        if (!params.no_timestamps || params.diarize) {
            printf("\n");
        }
        fflush(stdout);
    }
}
329
+
330
+ std::string output_str(struct whisper_context * ctx, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
331
+ std::stringstream result;
332
+ const int n_segments = whisper_full_n_segments(ctx);
333
+ for (int i = 0; i < n_segments; ++i) {
334
+ const char * text = whisper_full_get_segment_text(ctx, i);
335
+ std::string speaker = "";
336
+
337
+ if (params.diarize && pcmf32s.size() == 2)
338
+ {
339
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
340
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
341
+ speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
342
+ }
343
+
344
+ result << speaker << text << "\n";
345
+ }
346
+ return result.str();
347
+ }
348
+
349
+ void get_req_parameters(const Request & req, whisper_params & params)
350
+ {
351
+ // user model configu.has_fileion
352
+ if (req.has_file("offset-t"))
353
+ {
354
+ params.offset_t_ms = std::stoi(req.get_file_value("offset-t").content);
355
+ }
356
+ if (req.has_file("offset-n"))
357
+ {
358
+ params.offset_n = std::stoi(req.get_file_value("offset-n").content);
359
+ }
360
+ if (req.has_file("duration"))
361
+ {
362
+ params.duration_ms = std::stoi(req.get_file_value("duration").content);
363
+ }
364
+ if (req.has_file("max-context"))
365
+ {
366
+ params.max_context = std::stoi(req.get_file_value("max-context").content);
367
+ }
368
+ if (req.has_file("prompt"))
369
+ {
370
+ params.prompt = req.get_file_value("prompt").content;
371
+ }
372
+ if (req.has_file("response-format"))
373
+ {
374
+ params.response_format = req.get_file_value("response-format").content;
375
+ }
376
+ if (req.has_file("temerature"))
377
+ {
378
+ params.userdef_temp = std::stof(req.get_file_value("temperature").content);
379
+ }
380
+ }
381
+
382
+ } // namespace
383
+
384
+ int main(int argc, char ** argv) {
385
+ whisper_params params;
386
+ server_params sparams;
387
+
388
+ std::mutex whisper_mutex;
389
+
390
+ if (whisper_params_parse(argc, argv, params, sparams) == false) {
391
+ whisper_print_usage(argc, argv, params, sparams);
392
+ return 1;
393
+ }
394
+
395
+ if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
396
+ fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
397
+ whisper_print_usage(argc, argv, params, sparams);
398
+ exit(0);
399
+ }
400
+
401
+ if (params.diarize && params.tinydiarize) {
402
+ fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
403
+ whisper_print_usage(argc, argv, params, sparams);
404
+ exit(0);
405
+ }
406
+
407
+ // whisper init
408
+ struct whisper_context_params cparams;
409
+ cparams.use_gpu = params.use_gpu;
410
+
411
+ struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
412
+
413
+ if (ctx == nullptr) {
414
+ fprintf(stderr, "error: failed to initialize whisper context\n");
415
+ return 3;
416
+ }
417
+
418
+ // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
419
+ whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
420
+
421
+ Server svr;
422
+
423
+ std::string const default_content = "<html>hello</html>";
424
+
425
+ // this is only called if no index.html is found in the public --path
426
+ svr.Get("/", [&default_content](const Request &, Response &res){
427
+ res.set_content(default_content, "text/html");
428
+ return false;
429
+ });
430
+
431
+ svr.Post("/inference", [&](const Request &req, Response &res){
432
+ // aquire whisper model mutex lock
433
+ whisper_mutex.lock();
434
+
435
+ // first check user requested fields of the request
436
+ if (!req.has_file("file"))
437
+ {
438
+ fprintf(stderr, "error: no 'file' field in the request\n");
439
+ const std::string error_resp = "{\"error\":\"no 'file' field in the request\"}";
440
+ res.set_content(error_resp, "application/json");
441
+ whisper_mutex.unlock();
442
+ return;
443
+ }
444
+ auto audio_file = req.get_file_value("file");
445
+
446
+ // check non-required fields
447
+ get_req_parameters(req, params);
448
+
449
+ std::string filename{audio_file.filename};
450
+ printf("Received request: %s\n", filename.c_str());
451
+
452
+ // audio arrays
453
+ std::vector<float> pcmf32; // mono-channel F32 PCM
454
+ std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
455
+
456
+ // write file to temporary file
457
+ std::ofstream temp_file{filename, std::ios::binary};
458
+ temp_file << audio_file.content;
459
+
460
+ // read wav content into pcmf32
461
+ if (!::read_wav(filename, pcmf32, pcmf32s, params.diarize)) {
462
+ fprintf(stderr, "error: failed to read WAV file '%s'\n", filename.c_str());
463
+ const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
464
+ res.set_content(error_resp, "application/json");
465
+ whisper_mutex.unlock();
466
+ return;
467
+ }
468
+ // remove temp file
469
+ std::remove(filename.c_str());
470
+
471
+ printf("Successfully loaded %s\n", filename.c_str());
472
+
473
+ // print system information
474
+ {
475
+ fprintf(stderr, "\n");
476
+ fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
477
+ params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
478
+ }
479
+
480
+ // print some info about the processing
481
+ {
482
+ fprintf(stderr, "\n");
483
+ if (!whisper_is_multilingual(ctx)) {
484
+ if (params.language != "en" || params.translate) {
485
+ params.language = "en";
486
+ params.translate = false;
487
+ fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
488
+ }
489
+ }
490
+ if (params.detect_language) {
491
+ params.language = "auto";
492
+ }
493
+ fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
494
+ __func__, filename.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
495
+ params.n_threads, params.n_processors,
496
+ params.language.c_str(),
497
+ params.translate ? "translate" : "transcribe",
498
+ params.tinydiarize ? "tdrz = 1, " : "",
499
+ params.no_timestamps ? 0 : 1);
500
+
501
+ fprintf(stderr, "\n");
502
+ }
503
+
504
+ // run the inference
505
+ {
506
+
507
+ printf("Running whisper.cpp inference on %s\n", filename.c_str());
508
+ whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
509
+
510
+ wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
511
+
512
+ wparams.print_realtime = false;
513
+ wparams.print_progress = params.print_progress;
514
+ wparams.print_timestamps = !params.no_timestamps;
515
+ wparams.print_special = params.print_special;
516
+ wparams.translate = params.translate;
517
+ wparams.language = params.language.c_str();
518
+ wparams.detect_language = params.detect_language;
519
+ wparams.n_threads = params.n_threads;
520
+ wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
521
+ wparams.offset_ms = params.offset_t_ms;
522
+ wparams.duration_ms = params.duration_ms;
523
+
524
+ wparams.thold_pt = params.word_thold;
525
+ wparams.split_on_word = params.split_on_word;
526
+
527
+ wparams.speed_up = params.speed_up;
528
+ wparams.debug_mode = params.debug_mode;
529
+
530
+ wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
531
+
532
+ wparams.initial_prompt = params.prompt.c_str();
533
+
534
+ wparams.greedy.best_of = params.best_of;
535
+ wparams.beam_search.beam_size = params.beam_size;
536
+
537
+ wparams.temperature_inc = params.userdef_temp;
538
+ wparams.entropy_thold = params.entropy_thold;
539
+ wparams.logprob_thold = params.logprob_thold;
540
+
541
+ whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
542
+
543
+ // this callback is called on each new segment
544
+ if (!wparams.print_realtime) {
545
+ wparams.new_segment_callback = whisper_print_segment_callback;
546
+ wparams.new_segment_callback_user_data = &user_data;
547
+ }
548
+
549
+ if (wparams.print_progress) {
550
+ wparams.progress_callback = whisper_print_progress_callback;
551
+ wparams.progress_callback_user_data = &user_data;
552
+ }
553
+
554
+ // examples for abort mechanism
555
+ // in examples below, we do not abort the processing, but we could if the flag is set to true
556
+
557
+ // the callback is called before every encoder run - if it returns false, the processing is aborted
558
+ {
559
+ static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
560
+
561
+ wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
562
+ bool is_aborted = *(bool*)user_data;
563
+ return !is_aborted;
564
+ };
565
+ wparams.encoder_begin_callback_user_data = &is_aborted;
566
+ }
567
+
568
+ // the callback is called before every computation - if it returns true, the computation is aborted
569
+ {
570
+ static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
571
+
572
+ wparams.abort_callback = [](void * user_data) {
573
+ bool is_aborted = *(bool*)user_data;
574
+ return is_aborted;
575
+ };
576
+ wparams.abort_callback_user_data = &is_aborted;
577
+ }
578
+
579
+ if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
580
+ fprintf(stderr, "%s: failed to process audio\n", argv[0]);
581
+ const std::string error_resp = "{\"error\":\"failed to process audio\"}";
582
+ res.set_content(error_resp, "application/json");
583
+ whisper_mutex.unlock();
584
+ return;
585
+ }
586
+ }
587
+
588
+ // return results to user
589
+ if (params.response_format == text_format)
590
+ {
591
+ std::string results = output_str(ctx, params, pcmf32s);
592
+ res.set_content(results.c_str(), "text/html");
593
+ }
594
+ // TODO add more output formats
595
+ else
596
+ {
597
+ std::string results = output_str(ctx, params, pcmf32s);
598
+ json jres = json{
599
+ {"text", results}
600
+ };
601
+ res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace),
602
+ "application/json");
603
+ }
604
+
605
+ // return whisper model mutex lock
606
+ whisper_mutex.unlock();
607
+ });
608
+ svr.Post("/load", [&](const Request &req, Response &res){
609
+ whisper_mutex.lock();
610
+ if (!req.has_file("model"))
611
+ {
612
+ fprintf(stderr, "error: no 'model' field in the request\n");
613
+ const std::string error_resp = "{\"error\":\"no 'model' field in the request\"}";
614
+ res.set_content(error_resp, "application/json");
615
+ whisper_mutex.unlock();
616
+ return;
617
+ }
618
+ std::string model = req.get_file_value("model").content;
619
+ if (!is_file_exist(model.c_str()))
620
+ {
621
+ fprintf(stderr, "error: 'model': %s not found!\n", model.c_str());
622
+ const std::string error_resp = "{\"error\":\"model not found!\"}";
623
+ res.set_content(error_resp, "application/json");
624
+ whisper_mutex.unlock();
625
+ return;
626
+ }
627
+
628
+ // clean up
629
+ whisper_free(ctx);
630
+
631
+ // whisper init
632
+ ctx = whisper_init_from_file_with_params(model.c_str(), cparams);
633
+
634
+ // TODO perhaps load prior model here instead of exit
635
+ if (ctx == nullptr) {
636
+ fprintf(stderr, "error: model init failed, no model loaded must exit\n");
637
+ exit(1);
638
+ }
639
+
640
+ // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
641
+ whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
642
+
643
+ const std::string success = "Load was successful!";
644
+ res.set_content(success, "application/text");
645
+
646
+ // check if the model is in the file system
647
+ whisper_mutex.unlock();
648
+ });
649
+
650
+ svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
651
+ const char fmt[] = "500 Internal Server Error\n%s";
652
+ char buf[BUFSIZ];
653
+ try {
654
+ std::rethrow_exception(std::move(ep));
655
+ } catch (std::exception &e) {
656
+ snprintf(buf, sizeof(buf), fmt, e.what());
657
+ } catch (...) {
658
+ snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
659
+ }
660
+ res.set_content(buf, "text/plain");
661
+ res.status = 500;
662
+ });
663
+
664
+ svr.set_error_handler([](const Request &, Response &res) {
665
+ if (res.status == 400) {
666
+ res.set_content("Invalid request", "text/plain");
667
+ } else if (res.status != 500) {
668
+ res.set_content("File Not Found", "text/plain");
669
+ res.status = 404;
670
+ }
671
+ });
672
+
673
+ // set timeouts and change hostname and port
674
+ svr.set_read_timeout(sparams.read_timeout);
675
+ svr.set_write_timeout(sparams.write_timeout);
676
+
677
+ if (!svr.bind_to_port(sparams.hostname, sparams.port))
678
+ {
679
+ fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
680
+ sparams.hostname.c_str(), sparams.port);
681
+ return 1;
682
+ }
683
+
684
+ // Set the base directory for serving static files
685
+ svr.set_base_dir(sparams.public_path);
686
+
687
+ // to make it ctrl+clickable:
688
+ printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
689
+
690
+ if (!svr.listen_after_bind())
691
+ {
692
+ return 1;
693
+ }
694
+
695
+ whisper_print_timings(ctx);
696
+ whisper_free(ctx);
697
+
698
+ return 0;
699
+ }