@@ -60,7 +60,13 @@ struct whisper_params {
60
60
int32_t max_tokens = 32 ;
61
61
int32_t audio_ctx = 0 ;
62
62
int32_t n_gpu_layers = 999 ;
63
-
63
+ int32_t seed = 0 ;
64
+ int32_t top_k = 5 ;
65
+ int32_t min_keep = 1 ;
66
+ float top_p = 0 .80f ;
67
+ float min_p = 0 .01f ;
68
+ float temp = 0 .30f ;
69
+
64
70
float vad_thold = 0 .6f ;
65
71
float freq_thold = 100 .0f ;
66
72
@@ -102,6 +108,12 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
102
108
else if (arg == " -mt" || arg == " --max-tokens" ) { params.max_tokens = std::stoi (argv[++i]); }
103
109
else if (arg == " -ac" || arg == " --audio-ctx" ) { params.audio_ctx = std::stoi (argv[++i]); }
104
110
else if (arg == " -ngl" || arg == " --n-gpu-layers" ) { params.n_gpu_layers = std::stoi (argv[++i]); }
111
+ else if (arg == " --seed" ) { params.seed = std::stoi (argv[++i]); }
112
+ else if (arg == " --top-k" ) { params.top_k = std::stoi (argv[++i]); }
113
+ else if (arg == " --min-keep" ) { params.min_keep = std::stoul (argv[++i]);}
114
+ else if (arg == " --top-p" ) { params.top_p = std::stof (argv[++i]); }
115
+ else if (arg == " --min-p" ) { params.min_p = std::stof (argv[++i]); }
116
+ else if (arg == " --temp" ) { params.temp = std::stof (argv[++i]); }
105
117
else if (arg == " -vth" || arg == " --vad-thold" ) { params.vad_thold = std::stof (argv[++i]); }
106
118
else if (arg == " -fth" || arg == " --freq-thold" ) { params.freq_thold = std::stof (argv[++i]); }
107
119
else if (arg == " -tr" || arg == " --translate" ) { params.translate = true ; }
@@ -150,6 +162,12 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
150
162
fprintf (stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n " , params.max_tokens );
151
163
fprintf (stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n " , params.audio_ctx );
152
164
fprintf (stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n " , params.n_gpu_layers );
165
+ fprintf (stderr, " --seed N [%-7d] seed sampling\n " , params.seed );
166
+ fprintf (stderr, " --top-k N [%-7d] top-k sampling (0 = disabled)\n " , params.top_k );
167
+ fprintf (stderr, " --min-keep N [%-7d] minimum number of tokens to keep\n " , params.min_keep );
168
+ fprintf (stderr, " --top-p N [%-7.2f] top-p sampling\n " , params.top_p );
169
+ fprintf (stderr, " --min-p N [%-7.2f] min-p sampling\n " , params.min_p );
170
+ fprintf (stderr, " --temp N [%-7.2f] temperature\n " , params.temp );
153
171
fprintf (stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n " , params.vad_thold );
154
172
fprintf (stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n " , params.freq_thold );
155
173
fprintf (stderr, " -tr, --translate [%-7s] translate from source language to english\n " , params.translate ? " true" : " false" );
@@ -409,21 +427,16 @@ int main(int argc, char ** argv) {
409
427
llama_batch batch = llama_batch_init (llama_n_ctx (ctx_llama), 0 , 1 );
410
428
411
429
// init sampler
412
- const float top_k = 5 ;
413
- const float top_p = 0 .80f ;
414
- const float temp = 0 .30f ;
415
-
416
- const int seed = 0 ;
417
-
418
430
auto sparams = llama_sampler_chain_default_params ();
419
431
420
432
llama_sampler * smpl = llama_sampler_chain_init (sparams);
421
433
422
- if (temp > 0 .0f ) {
423
- llama_sampler_chain_add (smpl, llama_sampler_init_top_k (top_k));
424
- llama_sampler_chain_add (smpl, llama_sampler_init_top_p (top_p, 1 ));
425
- llama_sampler_chain_add (smpl, llama_sampler_init_temp (temp));
426
- llama_sampler_chain_add (smpl, llama_sampler_init_dist (seed));
434
+ if (params.temp > 0 .0f ) {
435
+ llama_sampler_chain_add (smpl, llama_sampler_init_top_k (params.top_k ));
436
+ llama_sampler_chain_add (smpl, llama_sampler_init_top_p (params.top_p , params.min_keep ));
437
+ llama_sampler_chain_add (smpl, llama_sampler_init_temp (params.temp ));
438
+ llama_sampler_chain_add (smpl, llama_sampler_init_dist (params.seed ));
439
+ llama_sampler_chain_add (smpl, llama_sampler_init_min_p (params.min_p , params.min_keep ));
427
440
} else {
428
441
llama_sampler_chain_add (smpl, llama_sampler_init_greedy ());
429
442
}
@@ -615,7 +628,7 @@ int main(int argc, char ** argv) {
615
628
}
616
629
617
630
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
618
- text_heard = std::regex_replace (text_heard, std::regex (" [^a-zA-Z0-9 \\ .,\\ ?!\\ s\\ :\\ '\\ -]" ), " " );
631
+ text_heard = std::regex_replace (text_heard, std::regex (" [^a-zA-Z0-9åäöÅÄÖ \\ .,\\ ?!\\ s\\ :\\ '\\ -]" ), " " );
619
632
620
633
// take first line
621
634
text_heard = text_heard.substr (0 , text_heard.find_first_of (' \n ' ));
0 commit comments