sachaarbonel committed on
Commit
647c7e7
·
unverified ·
1 Parent(s): 30197de

server : add option to suppress non-speech tokens (#2649)

Browse files

* Add a `suppress_non_speech_tokens` parameter that, when enabled, suppresses non-speech tokens such as [LAUGH], [SIGH], etc. from the output.

* add the `-sns` / `--suppress-non-speech` flag to whisper_params_parse

* add missing param

Files changed (1) hide show
  1. examples/server/server.cpp +9 -0
examples/server/server.cpp CHANGED
@@ -76,6 +76,7 @@ struct whisper_params {
76
  bool no_timestamps = false;
77
  bool use_gpu = true;
78
  bool flash_attn = false;
 
79
 
80
  std::string language = "en";
81
  std::string prompt = "";
@@ -135,6 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
135
  fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str());
136
  fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
137
  fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
 
138
  fprintf(stderr, "\n");
139
  }
140
 
@@ -179,6 +181,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
179
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
180
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
181
  else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
 
182
  // server params
183
  else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
184
  else if ( arg == "--host") { sparams.hostname = argv[++i]; }
@@ -472,6 +475,10 @@ void get_req_parameters(const Request & req, whisper_params & params)
472
  {
473
  params.temperature_inc = std::stof(req.get_file_value("temperature_inc").content);
474
  }
 
 
 
 
475
  }
476
 
477
  } // namespace
@@ -786,6 +793,8 @@ int main(int argc, char ** argv) {
786
  wparams.no_timestamps = params.no_timestamps;
787
  wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
788
 
 
 
789
  whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
790
 
791
  // this callback is called on each new segment
 
76
  bool no_timestamps = false;
77
  bool use_gpu = true;
78
  bool flash_attn = false;
79
+ bool suppress_non_speech_tokens = false;
80
 
81
  std::string language = "en";
82
  std::string prompt = "";
 
136
  fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str());
137
  fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
138
  fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
139
+ fprintf(stderr, " -sns, --suppress-non-speech [%-7s] suppress non-speech tokens\n", params.suppress_non_speech_tokens ? "true" : "false");
140
  fprintf(stderr, "\n");
141
  }
142
 
 
181
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
182
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
183
  else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
184
+ else if (arg == "-sns" || arg == "--suppress-non-speech") { params.suppress_non_speech_tokens = true; }
185
  // server params
186
  else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
187
  else if ( arg == "--host") { sparams.hostname = argv[++i]; }
 
475
  {
476
  params.temperature_inc = std::stof(req.get_file_value("temperature_inc").content);
477
  }
478
+ if (req.has_file("suppress_non_speech"))
479
+ {
480
+ params.suppress_non_speech_tokens = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
481
+ }
482
  }
483
 
484
  } // namespace
 
793
  wparams.no_timestamps = params.no_timestamps;
794
  wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
795
 
796
+ wparams.suppress_non_speech_tokens = params.suppress_non_speech_tokens;
797
+
798
  whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
799
 
800
  // this callback is called on each new segment