#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
14+
15+ // Callback for streaming token output: invoked with the decoded text piece, its token id, and the caller-supplied user_data pointer. Return non-zero to stop generation.
16+ typedef int (* bn_token_callback )(const char * piece , int token_id , void * user_data );
1317
1418#if defined(__APPLE__ )
1519#include <sys/sysctl.h>
@@ -49,6 +53,26 @@ static void print_usage(const char *prog) {
4953 fprintf (stderr , " --no-prefill Disable batch prompt prefill (compute logits for every token)\n" );
5054}
5155
// Parse `s` as a base-10 integer that must fit in an `int`.
//
// s:    NUL-terminated text of the option value.
// name: option name, used only in the error message.
//
// Returns the parsed value. On any invalid input (no digits at all, trailing
// characters after the number, or a value outside [INT_MIN, INT_MAX]) prints
// "Invalid value for <name>: <s>" to stderr and terminates with exit(1).
static int parse_int(const char *s, const char *name) {
    char *end = NULL;

    // strtol reports overflow only through errno, so clear it first.
    errno = 0;
    long val = strtol(s, &end, 10);

    // end == s      -> no digits were consumed (e.g. "" or "abc")
    // *end != '\0'  -> trailing junk after the number (e.g. "12x")
    // ERANGE        -> value did not fit in a long (matters where long == int)
    if (end == s || *end != '\0' || errno == ERANGE ||
        val < INT_MIN || val > INT_MAX) {
        fprintf(stderr, "Invalid value for %s: %s\n", name, s);
        exit(1);
    }
    return (int)val;
}
65+
// Parse `s` as a floating-point number using strtof.
//
// s:    NUL-terminated text of the option value.
// name: option name, used only in the error message.
//
// Returns the parsed value. On invalid input (nothing parseable, trailing
// characters after the number, or a result strtof flags as out of range)
// prints "Invalid value for <name>: <s>" to stderr and terminates with exit(1).
static float parse_float(const char *s, const char *name) {
    char *end = NULL;

    // strtof signals out-of-range results only through errno, so clear it first.
    errno = 0;
    float val = strtof(s, &end);

    // end == s      -> no conversion was performed (e.g. "" or "abc")
    // *end != '\0'  -> trailing junk after the number (e.g. "0.5x")
    // ERANGE        -> result not representable as a float
    if (end == s || *end != '\0' || errno == ERANGE) {
        fprintf(stderr, "Invalid value for %s: %s\n", name, s);
        exit(1);
    }
    return val;
}
75+
5276static CLIArgs parse_args (int argc , char * * argv ) {
5377 CLIArgs args = {0 };
5478 args .prompt = "Hello" ;
@@ -69,16 +93,18 @@ static CLIArgs parse_args(int argc, char **argv) {
6993 if (strcmp (argv [i ], "-p" ) == 0 && i + 1 < argc ) {
7094 args .prompt = argv [++ i ];
7195 } else if (strcmp (argv [i ], "-n" ) == 0 && i + 1 < argc ) {
72- args .n_tokens = atoi (argv [++ i ]);
96+ args .n_tokens = parse_int (argv [++ i ], "-n" );
7397 } else if (strcmp (argv [i ], "--temp" ) == 0 && i + 1 < argc ) {
74- args .temperature = ( float ) atof ( argv [++ i ]);
98+ args .temperature = parse_float ( argv [++ i ], "--temp" );
7599 args .temp_set = 1 ;
76100 } else if (strcmp (argv [i ], "--topp" ) == 0 && i + 1 < argc ) {
77- args .topp = ( float ) atof ( argv [++ i ]);
101+ args .topp = parse_float ( argv [++ i ], "--topp" );
78102 } else if (strcmp (argv [i ], "--seed" ) == 0 && i + 1 < argc ) {
79- args .seed = (uint64_t )atoll (argv [++ i ]);
103+ char * end ;
104+ args .seed = (uint64_t )strtoull (argv [++ i ], & end , 10 );
105+ if (* end != '\0' ) { fprintf (stderr , "Invalid value for --seed: %s\n" , argv [i ]); exit (1 ); }
80106 } else if (strcmp (argv [i ], "--maxseq" ) == 0 && i + 1 < argc ) {
81- args .max_seq_len = atoi (argv [++ i ]);
107+ args .max_seq_len = parse_int (argv [++ i ], "--maxseq" );
82108 } else if (strcmp (argv [i ], "--flash" ) == 0 ) {
83109 args .flash_attn = 1 ;
84110 } else if (strcmp (argv [i ], "--chat" ) == 0 ) {
@@ -88,7 +114,7 @@ static CLIArgs parse_args(int argc, char **argv) {
88114 } else if (strcmp (argv [i ], "--no-prefill" ) == 0 ) {
89115 args .no_prefill = 1 ;
90116 } else if (strcmp (argv [i ], "--repeat-penalty" ) == 0 && i + 1 < argc ) {
91- args .repeat_penalty = ( float ) atof ( argv [++ i ]);
117+ args .repeat_penalty = parse_float ( argv [++ i ], "--repeat-penalty" );
92118 args .repeat_set = 1 ;
93119 } else {
94120 fprintf (stderr , "Unknown option: %s\n" , argv [i ]);
@@ -100,6 +126,67 @@ static CLIArgs parse_args(int argc, char **argv) {
100126 return args ;
101127}
102128
129+ // Loop detection constants
130+ #define LOOP_BUF_SIZE 32 // capacity of the ring buffer that records sampled tokens
131+ #define LOOP_NGRAM 4 // length of the trailing token run compared against the run just before it
132+
133+ // Generate tokens with callback-based streaming.
134+ // Returns: n_generated, -1 on loop, -2 on error.
135+ static int generate_response (BnModel * model , BnTokenizer * tok , BnSampler * sampler ,
136+ int max_tokens , int * pos ,
137+ bn_token_callback cb , void * user_data ) {
138+ int loop_buf [LOOP_BUF_SIZE ];
139+ int loop_idx = 0 , gen_count = 0 ;
140+ memset (loop_buf , -1 , sizeof (loop_buf ));
141+
142+ float * logits = model -> state .logits ;
143+ if (!logits ) return -2 ;
144+
145+ for (int i = 0 ; i < max_tokens ; i ++ ) {
146+ int next = bn_sampler_sample (sampler , logits );
147+
148+ if (next == tok -> eot_id || next == tok -> eos_id )
149+ break ;
150+
151+ // Ring buffer loop detection
152+ loop_buf [loop_idx ] = next ;
153+ loop_idx = (loop_idx + 1 ) % LOOP_BUF_SIZE ;
154+ gen_count ++ ;
155+
156+ if (gen_count >= 2 * LOOP_NGRAM ) {
157+ int looping = 1 ;
158+ for (int k = 0 ; k < LOOP_NGRAM ; k ++ ) {
159+ int a = loop_buf [((loop_idx - 1 - k ) % LOOP_BUF_SIZE + LOOP_BUF_SIZE ) % LOOP_BUF_SIZE ];
160+ int b = loop_buf [((loop_idx - 1 - k - LOOP_NGRAM ) % LOOP_BUF_SIZE + LOOP_BUF_SIZE ) % LOOP_BUF_SIZE ];
161+ if (a != b ) { looping = 0 ; break ; }
162+ }
163+ if (looping ) return -1 ;
164+ }
165+
166+ bn_sampler_accept (sampler , next );
167+
168+ const char * piece = bn_tokenizer_decode (tok , next );
169+ if (piece && cb ) {
170+ if (cb (piece , next , user_data ))
171+ break ;
172+ }
173+
174+ logits = bn_transformer_forward (model , next , * pos );
175+ (* pos )++ ;
176+ if (!logits ) return -2 ;
177+ }
178+
179+ return gen_count ;
180+ }
181+
// Token callback that writes each piece to stdout and flushes immediately so
// output appears as it is generated.
//
// piece:     text to print; a NULL piece is ignored (passing NULL to a "%s"
//            conversion would be undefined behaviour).
// token_id:  unused.
// user_data: unused.
//
// Always returns 0, i.e. never asks the generator to stop.
static int print_token(const char *piece, int token_id, void *user_data) {
    (void)token_id;
    (void)user_data;

    if (piece != NULL) {
        fputs(piece, stdout);
        fflush(stdout);
    }
    return 0;
}
189+
103190int main (int argc , char * * argv ) {
104191 sh_log_init (NULL );
105192 CLIArgs args = parse_args (argc , argv );
@@ -187,7 +274,6 @@ int main(int argc, char **argv) {
187274 SH_LOG_ERROR ("Failed to init tokenizer" );
188275 bn_model_free (& model );
189276 bn_gguf_free (gf );
190- bn_platform_unload_file (& mf );
191277 return 1 ;
192278 }
193279 {
@@ -220,7 +306,6 @@ int main(int argc, char **argv) {
220306 bn_tokenizer_free (& tokenizer );
221307 bn_model_free (& model );
222308 bn_gguf_free (gf );
223- bn_platform_unload_file (& mf );
224309 return 1 ;
225310 }
226311
@@ -287,48 +372,9 @@ int main(int argc, char **argv) {
287372 break ;
288373 }
289374
290- // Generate until eot_id, eos_id, or seq_len
291- // Loop detector: ring buffer of recent tokens, check for repeating n-grams
292- #define LOOP_BUF_SIZE 32
293- #define LOOP_NGRAM 4
294- int loop_buf [LOOP_BUF_SIZE ];
295- int loop_idx = 0 , gen_count = 0 ;
296- memset (loop_buf , -1 , sizeof (loop_buf ));
297-
298- for (int i = 0 ; i < args .n_tokens ; i ++ ) {
299- int next = bn_sampler_sample (& sampler , logits );
300-
301- if (next == tokenizer .eot_id || next == tokenizer .eos_id )
302- break ;
303-
304- // Record token in ring buffer and check for loops
305- loop_buf [loop_idx ] = next ;
306- loop_idx = (loop_idx + 1 ) % LOOP_BUF_SIZE ;
307- gen_count ++ ;
308-
309- if (gen_count >= 2 * LOOP_NGRAM ) {
310- // Check if last LOOP_NGRAM tokens match the LOOP_NGRAM before them
311- int looping = 1 ;
312- for (int k = 0 ; k < LOOP_NGRAM ; k ++ ) {
313- int a = loop_buf [((loop_idx - 1 - k ) % LOOP_BUF_SIZE + LOOP_BUF_SIZE ) % LOOP_BUF_SIZE ];
314- int b = loop_buf [((loop_idx - 1 - k - LOOP_NGRAM ) % LOOP_BUF_SIZE + LOOP_BUF_SIZE ) % LOOP_BUF_SIZE ];
315- if (a != b ) { looping = 0 ; break ; }
316- }
317- if (looping ) { gen_count = -1 ; break ; }
318- }
319-
320- bn_sampler_accept (& sampler , next );
321-
322- const char * piece = bn_tokenizer_decode (& tokenizer , next );
323- if (piece ) {
324- printf ("%s" , piece );
325- fflush (stdout );
326- }
327-
328- logits = bn_transformer_forward (& model , next , pos );
329- pos ++ ;
330- if (!logits ) break ;
331- }
375+ int gen_count = generate_response (& model , & tokenizer , & sampler ,
376+ args .n_tokens , & pos ,
377+ print_token , NULL );
332378
333379 // Feed EOT into KV cache to close the assistant turn
334380 bn_transformer_forward (& model , tokenizer .eot_id , pos );
@@ -337,7 +383,7 @@ int main(int argc, char **argv) {
337383 printf ("\n" );
338384
339385 turn_count ++ ;
340- int should_reset = (turn_count >= 2 ) || (gen_count == -1 );
386+ int should_reset = (turn_count >= 2 ) || (gen_count < 0 );
341387 if (should_reset ) {
342388 printf ("[auto-reset: starting fresh]\n" );
343389 pos = 0 ;
@@ -360,7 +406,6 @@ int main(int argc, char **argv) {
360406 bn_tokenizer_free (& tokenizer );
361407 bn_model_free (& model );
362408 bn_gguf_free (gf );
363- bn_platform_unload_file (& mf );
364409 return 1 ;
365410 }
366411 int n_prompt = bn_tokenizer_encode (& tokenizer , args .prompt , 1 , prompt_tokens ,
@@ -442,9 +487,8 @@ int main(int argc, char **argv) {
442487 // Cleanup
443488 bn_sampler_free (& sampler );
444489 bn_tokenizer_free (& tokenizer );
445- bn_model_free (& model );
490+ bn_model_free (& model ); // also unloads mmap'd file
446491 bn_gguf_free (gf );
447- bn_platform_unload_file (& mf );
448492
449493 return 0 ;
450494}
0 commit comments