1
+ #define _USE_MATH_DEFINES // for M_PI
2
+
1
3
#include " common.h"
2
4
3
5
// third-party utilities
13
15
#include < codecvt>
14
16
#include < sstream>
15
17
16
- #ifndef M_PI
17
- #define M_PI 3.14159265358979323846
18
- #endif
19
-
20
18
#if defined(_MSC_VER)
21
19
#pragma warning(disable: 4244 4267) // possible loss of data
22
20
#endif
23
21
22
+ // Function to check if the next argument exists
23
+ std::string get_next_arg (int & i, int argc, char ** argv, const std::string& flag, gpt_params& params) {
24
+ if (i + 1 < argc && argv[i + 1 ][0 ] != ' -' ) {
25
+ return argv[++i];
26
+ } else {
27
+ fprintf (stderr, " error: %s requires one argument.\n " , flag.c_str ());
28
+ gpt_print_usage (argc, argv, params);
29
+ exit (0 );
30
+ }
31
+ }
32
+
24
33
bool gpt_params_parse (int argc, char ** argv, gpt_params & params) {
25
34
for (int i = 1 ; i < argc; i++) {
26
35
std::string arg = argv[i];
27
36
28
37
if (arg == " -s" || arg == " --seed" ) {
29
- params.seed = std::stoi (argv[++i] );
38
+ params.seed = std::stoi (get_next_arg (i, argc, argv, arg, params) );
30
39
} else if (arg == " -t" || arg == " --threads" ) {
31
- params.n_threads = std::stoi (argv[++i]);
40
+ params.n_threads = std::stoi (get_next_arg (i, argc, argv, arg, params));
41
+ } else if (arg == " -ngl" || arg == " --gpu-layers" || arg == " --n-gpu-layers" ) {
42
+ params.n_gpu_layers = std::stoi (get_next_arg (i, argc, argv, arg, params));
32
43
} else if (arg == " -p" || arg == " --prompt" ) {
33
- params.prompt = argv[++i] ;
44
+ params.prompt = get_next_arg (i, argc, argv, arg, params) ;
34
45
} else if (arg == " -n" || arg == " --n_predict" ) {
35
- params.n_predict = std::stoi (argv[++i] );
46
+ params.n_predict = std::stoi (get_next_arg (i, argc, argv, arg, params) );
36
47
} else if (arg == " --top_k" ) {
37
- params.top_k = std::max ( 1 , std::stoi ( argv[++i] ));
48
+ params.top_k = std::stoi ( get_next_arg (i, argc, argv, arg, params ));
38
49
} else if (arg == " --top_p" ) {
39
- params.top_p = std::stof (argv[++i] );
50
+ params.top_p = std::stof (get_next_arg (i, argc, argv, arg, params) );
40
51
} else if (arg == " --temp" ) {
41
- params.temp = std::stof (argv[++i] );
52
+ params.temp = std::stof (get_next_arg (i, argc, argv, arg, params) );
42
53
} else if (arg == " --repeat-last-n" ) {
43
- params.repeat_last_n = std::stof (argv[++i] );
54
+ params.repeat_last_n = std::stoi ( get_next_arg (i, argc, argv, arg, params) );
44
55
} else if (arg == " --repeat-penalty" ) {
45
- params.repeat_penalty = std::stof (argv[++i] );
56
+ params.repeat_penalty = std::stof (get_next_arg (i, argc, argv, arg, params) );
46
57
} else if (arg == " -b" || arg == " --batch_size" ) {
47
- params.n_batch = std::stoi (argv[++i] );
58
+ params.n_batch = std::stoi (get_next_arg (i, argc, argv, arg, params) );
48
59
} else if (arg == " -m" || arg == " --model" ) {
49
- params.model = argv[++i] ;
60
+ params.model = get_next_arg (i, argc, argv, arg, params) ;
50
61
} else if (arg == " -i" || arg == " --interactive" ) {
51
62
params.interactive = true ;
52
63
} else if (arg == " -ip" || arg == " --interactive-port" ) {
53
64
params.interactive = true ;
54
- params.interactive_port = std::stoi (argv[++i] );
65
+ params.interactive_port = std::stoi (get_next_arg (i, argc, argv, arg, params) );
55
66
} else if (arg == " -h" || arg == " --help" ) {
56
67
gpt_print_usage (argc, argv, params);
57
68
exit (0 );
58
69
} else if (arg == " -f" || arg == " --file" ) {
59
- if (++i > argc) {
60
- fprintf (stderr, " Invalid file param" );
61
- break ;
62
- }
70
+ get_next_arg (i, argc, argv, arg, params);
63
71
std::ifstream file (argv[i]);
64
72
if (!file) {
65
73
fprintf (stderr, " error: failed to open file '%s'\n " , argv[i]);
@@ -70,7 +78,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
70
78
params.prompt .pop_back ();
71
79
}
72
80
} else if (arg == " -tt" || arg == " --token_test" ) {
73
- params.token_test = argv[++i] ;
81
+ params.token_test = get_next_arg (i, argc, argv, arg, params) ;
74
82
}
75
83
else {
76
84
fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
@@ -89,6 +97,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
89
97
fprintf (stderr, " -h, --help show this help message and exit\n " );
90
98
fprintf (stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n " );
91
99
fprintf (stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n " , params.n_threads );
100
+ fprintf (stderr, " -ngl N, --gpu-layers N number of layers to offload to GPU on supported models (default: %d)\n " , params.n_gpu_layers );
92
101
fprintf (stderr, " -p PROMPT, --prompt PROMPT\n " );
93
102
fprintf (stderr, " prompt to start generation with (default: random)\n " );
94
103
fprintf (stderr, " -f FNAME, --file FNAME\n " );
@@ -755,3 +764,46 @@ float similarity(const std::string & s0, const std::string & s1) {
755
764
756
765
return 1 .0f - (dist / std::max (s0.size (), s1.size ()));
757
766
}
767
+
768
+ bool sam_params_parse (int argc, char ** argv, sam_params & params) {
769
+ for (int i = 1 ; i < argc; i++) {
770
+ std::string arg = argv[i];
771
+
772
+ if (arg == " -s" || arg == " --seed" ) {
773
+ params.seed = std::stoi (argv[++i]);
774
+ } else if (arg == " -t" || arg == " --threads" ) {
775
+ params.n_threads = std::stoi (argv[++i]);
776
+ } else if (arg == " -m" || arg == " --model" ) {
777
+ params.model = argv[++i];
778
+ } else if (arg == " -i" || arg == " --inp" ) {
779
+ params.fname_inp = argv[++i];
780
+ } else if (arg == " -o" || arg == " --out" ) {
781
+ params.fname_out = argv[++i];
782
+ } else if (arg == " -h" || arg == " --help" ) {
783
+ sam_print_usage (argc, argv, params);
784
+ exit (0 );
785
+ } else {
786
+ fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
787
+ sam_print_usage (argc, argv, params);
788
+ exit (0 );
789
+ }
790
+ }
791
+
792
+ return true ;
793
+ }
794
+
795
+ void sam_print_usage (int argc, char ** argv, const sam_params & params) {
796
+ fprintf (stderr, " usage: %s [options]\n " , argv[0 ]);
797
+ fprintf (stderr, " \n " );
798
+ fprintf (stderr, " options:\n " );
799
+ fprintf (stderr, " -h, --help show this help message and exit\n " );
800
+ fprintf (stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n " );
801
+ fprintf (stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n " , params.n_threads );
802
+ fprintf (stderr, " -m FNAME, --model FNAME\n " );
803
+ fprintf (stderr, " model path (default: %s)\n " , params.model .c_str ());
804
+ fprintf (stderr, " -i FNAME, --inp FNAME\n " );
805
+ fprintf (stderr, " input file (default: %s)\n " , params.fname_inp .c_str ());
806
+ fprintf (stderr, " -o FNAME, --out FNAME\n " );
807
+ fprintf (stderr, " output file (default: %s)\n " , params.fname_out .c_str ());
808
+ fprintf (stderr, " \n " );
809
+ }
0 commit comments