matteng1 ggerganov commited on
Commit
1473e33
·
unverified ·
1 Parent(s): a7bcfbf

talk-llama : fix for swedish umlauts + expose model inference settings in talk-llama.cpp (#3187)

Browse files

Quick fix for not removing swedish umlauts.

* Update talk-llama.cpp

Expose model inference settings to user instead of hard coding them. Same defaults as previous defaults.

* Update examples/talk-llama/talk-llama.cpp

Co-authored-by: Georgi Gerganov <[email protected]>

Files changed (1) hide show
  1. examples/talk-llama/talk-llama.cpp +26 -13
examples/talk-llama/talk-llama.cpp CHANGED
@@ -60,7 +60,13 @@ struct whisper_params {
60
  int32_t max_tokens = 32;
61
  int32_t audio_ctx = 0;
62
  int32_t n_gpu_layers = 999;
63
-
 
 
 
 
 
 
64
  float vad_thold = 0.6f;
65
  float freq_thold = 100.0f;
66
 
@@ -102,6 +108,12 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
102
  else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
103
  else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
104
  else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
 
 
 
 
 
 
105
  else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
106
  else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
107
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
@@ -150,6 +162,12 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
150
  fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
151
  fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
152
  fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers);
 
 
 
 
 
 
153
  fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
154
  fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
155
  fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
@@ -409,21 +427,16 @@ int main(int argc, char ** argv) {
409
  llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
410
 
411
  // init sampler
412
- const float top_k = 5;
413
- const float top_p = 0.80f;
414
- const float temp = 0.30f;
415
-
416
- const int seed = 0;
417
-
418
  auto sparams = llama_sampler_chain_default_params();
419
 
420
  llama_sampler * smpl = llama_sampler_chain_init(sparams);
421
 
422
- if (temp > 0.0f) {
423
- llama_sampler_chain_add(smpl, llama_sampler_init_top_k(top_k));
424
- llama_sampler_chain_add(smpl, llama_sampler_init_top_p(top_p, 1));
425
- llama_sampler_chain_add(smpl, llama_sampler_init_temp (temp));
426
- llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));
 
427
  } else {
428
  llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
429
  }
@@ -615,7 +628,7 @@ int main(int argc, char ** argv) {
615
  }
616
 
617
  // remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
618
- text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
619
 
620
  // take first line
621
  text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
 
60
  int32_t max_tokens = 32;
61
  int32_t audio_ctx = 0;
62
  int32_t n_gpu_layers = 999;
63
+ int32_t seed = 0;
64
+ int32_t top_k = 5;
65
+ int32_t min_keep = 1;
66
+ float top_p = 0.80f;
67
+ float min_p = 0.01f;
68
+ float temp = 0.30f;
69
+
70
  float vad_thold = 0.6f;
71
  float freq_thold = 100.0f;
72
 
 
108
  else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
109
  else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
110
  else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
111
+ else if (arg == "--seed") { params.seed = std::stoi(argv[++i]); }
112
+ else if (arg == "--top-k") { params.top_k = std::stoi(argv[++i]); }
113
+ else if (arg == "--min-keep") { params.min_keep = std::stoul(argv[++i]);}
114
+ else if (arg == "--top-p") { params.top_p = std::stof(argv[++i]); }
115
+ else if (arg == "--min-p") { params.min_p = std::stof(argv[++i]); }
116
+ else if (arg == "--temp") { params.temp = std::stof(argv[++i]); }
117
  else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
118
  else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
119
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
 
162
  fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
163
  fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
164
  fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers);
165
+ fprintf(stderr, " --seed N [%-7d] seed sampling\n", params.seed);
166
+ fprintf(stderr, " --top-k N [%-7d] top-k sampling (0 = disabled)\n", params.top_k);
167
+ fprintf(stderr, " --min-keep N [%-7d] minimum number of tokens to keep\n", params.min_keep);
168
+ fprintf(stderr, " --top-p N [%-7.2f] top-p sampling\n", params.top_p);
169
+ fprintf(stderr, " --min-p N [%-7.2f] min-p sampling\n", params.min_p);
170
+ fprintf(stderr, " --temp N [%-7.2f] temperature\n", params.temp);
171
  fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
172
  fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
173
  fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
 
427
  llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
428
 
429
  // init sampler
 
 
 
 
 
 
430
  auto sparams = llama_sampler_chain_default_params();
431
 
432
  llama_sampler * smpl = llama_sampler_chain_init(sparams);
433
 
434
+ if (params.temp > 0.0f) {
435
+ llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.top_k));
436
+ llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.top_p, params.min_keep));
437
+ llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.temp));
438
+ llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.seed));
439
+ llama_sampler_chain_add(smpl, llama_sampler_init_min_p (params.min_p, params.min_keep));
440
  } else {
441
  llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
442
  }
 
628
  }
629
 
630
  // remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
631
+ text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9åäöÅÄÖ\\.,\\?!\\s\\:\\'\\-]"), "");
632
 
633
  // take first line
634
  text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));