@@ -32,11 +32,60 @@ impl Tokenizer {
         Ok(())
     }
 
+    /// The complexity here: we need to crop the assistant messages so
+    /// that they end on a block boundary. We cannot pad them out,
+    /// because this is not what the model server will do when
+    /// generating the assistant message. But it will cache the prefix
+    /// of full blocks. Hence the need to crop. However, we cannot
+    /// just crop at the end, as this would also crop off any "end of
+    /// text" special tokens that the chat template adds to the end of
+    /// the `self.assistant(m)` token sequence. Therefore, we need to
+    /// crop just the message part. This logic tries to do all of that
+    /// in a way that is agnostic to the chat template. However, the
+    /// logic does currently assume that the chat template will never
+    /// add special tokens *in the middle* of the given message `m`;
+    /// it assumes special tokens are only ever added (if at all) to
+    /// the beginning or end.
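+    ///
+    /// For intuition, a worked example with purely illustrative
+    /// numbers (not taken from a real tokenizer): suppose
+    /// `block_size = 4`, the chat template adds 3 prefix tokens and
+    /// 2 suffix tokens, and the message is 9 tokens. The full
+    /// sequence is 14 tokens; the nearest block boundary at or below
+    /// 14 is 12, so 2 tokens are cropped from the *end of the
+    /// message*, not from the suffix, leaving 3 + 7 + 2 = 12 tokens.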
     fn assistanttok(&self, m: &str, tokens: &mut Vec<u32>) -> tokenizers::tokenizer::Result<()> {
-        self.extend_crop(
-            self.tok.encode_fast(self.assistant(m), false)?.get_ids(),
-            tokens,
-        );
+        let encoded_with_template = self.tok.encode_fast(self.assistant(m), false)?;
+        let encoded_message = self.tok.encode_fast(m, false)?;
+        let with_chat_template = encoded_with_template.get_ids();
+        let without_chat_template = encoded_message.get_ids();
+
+        // TODO this is imperfect: we anchor on the first occurrence of
+        // the message's first token, which could falsely match a token
+        // inside the chat template's prefix.
+        let start_of_message_idx = with_chat_template
+            .iter()
+            .position(|t| *t == without_chat_template[0]);
+        let end_of_message_idx = start_of_message_idx
+            .map(|start_of_message_idx| start_of_message_idx + without_chat_template.len());
+        // [pppppmmmmmmmmmss] <- ppppp are the prefix special tokens added by the chat template; ss the suffix special tokens
+        //       ^ start_of_message_idx
+        //                ^ end_of_message_idx
+
+        if with_chat_template.len() > self.block_size {
+            eprintln!(
+                "Warning (spnl): assistant message cannot be cropped due to length of chat template"
+            )
+        }
+
+        if without_chat_template.is_empty() {
+            self.extend(with_chat_template, tokens);
+        } else if let Some(start_of_message_idx) = start_of_message_idx
+            && let Some(end_of_message_idx) = end_of_message_idx
+        {
+            self.extend_crop(
+                with_chat_template,
+                start_of_message_idx,
+                end_of_message_idx,
+                tokens,
+            );
+        } else {
+            eprintln!(
+                "Warning (spnl): assistant message could not be cropped because the message was not found within the chat template"
+            );
+            self.extend(with_chat_template, tokens);
+        }
+
         Ok(())
     }
 
@@ -67,18 +116,39 @@ impl Tokenizer {
         tokens.extend(extra);
     }
 
-    /// Extend with tokens, cropping to a block boundary
-    fn extend_crop(&self, extra: &[u32], tokens: &mut Vec<u32>) {
+    /// Extend with tokens, cropping to a block boundary, but only the `mmm` part in the middle, as follows:
+    /// [pppppmmmmmmmmmss] <- msg_with_chat_template
+    ///       ^ start_of_message_idx
+    ///                ^ end_of_message_idx
+    fn extend_crop(
+        &self,
+        msg_with_chat_template: &[u32],
+        start_of_message_idx: usize,
+        end_of_message_idx: usize,
+        tokens: &mut Vec<u32>,
+    ) {
         // Round down to nearest block boundary. Note: for future
         // reference, if we need to round up to nearest block
         // boundary, replace `tokens.len()` with
         // `tokens.len()+self.block_size-1`.
-        let end = extra.len() + tokens.len();
+        let end = msg_with_chat_template.len() + tokens.len();
         let nearest_block_boundary = end / self.block_size * self.block_size;
         let amount_to_crop = end - nearest_block_boundary;
-        let extra_end = extra.len() - amount_to_crop;
-
-        self.extend(&extra[0..extra_end], tokens);
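+        // If more must be cropped than the message itself contains,
+        // drop the whole message and keep only the template's special
+        // tokens (the `assistanttok_fully_cropped` test below
+        // exercises exactly this case).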
+        let end_of_crop = if amount_to_crop > (end_of_message_idx - start_of_message_idx) {
+            start_of_message_idx
+        } else {
+            end_of_message_idx - amount_to_crop
+        };
+
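+        // Reassemble: template prefix, the (possibly cropped) message,
+        // then the template suffix, so that trailing special tokens
+        // such as end-of-text survive the crop.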
+        let m = msg_with_chat_template;
+        let cropped = [
+            &m[0..start_of_message_idx],
+            &m[start_of_message_idx..end_of_crop],
+            &m[end_of_message_idx..],
+        ]
+        .concat();
+
+        self.extend(&cropped, tokens);
     }
 
     /// Pad to block boundary, then push
@@ -460,3 +530,93 @@ pub fn tokenize_prepare(
         _ => todo!(),
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use itertools::Itertools;
+
+    const PAD_TOKEN: u32 = 27;
+    const BLOCK_SIZE: usize = 16;
+
+    const MODEL: &str = "ibm-granite/granite-3.3-2b-instruct"; // TODO find smaller model with public tokenizers.json
+    const START_OF_ROLE: u32 = 49152;
+    const END_OF_ROLE: u32 = 49153;
+    const END_OF_TEXT: u32 = 0;
+    const USER: u32 = 496;
+    const ASSISTANT: u32 = 17594;
+    const HELLO: u32 = 7656;
+    const LONGER: u32 = 8928;
+
+    fn tok() -> Result<::std::sync::Arc<Tokenizer>, ::std::sync::Arc<tokenizers::tokenizer::Error>>
+    {
+        init(2).get_or_create(&MODEL.into(), PAD_TOKEN, None, None, BLOCK_SIZE)
+    }
+
+    #[test]
+    fn create_tokenizer() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>> {
+        tok().map(|_| ())
+    }
+
+    #[test]
+    fn user() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>> {
+        assert_eq!(
+            tok().map(|tok| tok.user("hello"))?,
+            "<|start_of_role|>user<|end_of_role|>hello<|end_of_text|>"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn usertok() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>> {
+        let mut tokens = vec![];
+        tok()?.usertok("hello", &mut tokens)?;
+        assert_eq!(
+            tokens,
+            [START_OF_ROLE, USER, END_OF_ROLE, HELLO, END_OF_TEXT]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn assistant() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>> {
+        assert_eq!(
+            tok().map(|tok| tok.assistant("hello"))?,
+            "<|start_of_role|>assistant<|end_of_role|>hello<|end_of_text|>"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn assistanttok_fully_cropped() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>> {
+        let mut tokens = vec![];
+        tok()?.assistanttok("hello", &mut tokens)?;
+        assert_eq!(tokens, [START_OF_ROLE, ASSISTANT, END_OF_ROLE, END_OF_TEXT]);
+        Ok(())
+    }
+
+    #[test]
+    fn assistanttok_partially_cropped() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>>
+    {
+        let repeat_input = 17; // repeat this many times for the input message
+        let repeat_output = 11; // expect this many repetitions after cropping
+        let mut tokens = vec![];
+        tok()?.assistanttok(
+            format!(
+                "hello {}",
+                ::std::iter::repeat_n("longer", repeat_input).join(" ")
+            )
+            .as_str(),
+            &mut tokens,
+        )?;
+        assert_eq!(
+            tokens,
+            [START_OF_ROLE, ASSISTANT, END_OF_ROLE, HELLO]
+                .into_iter()
+                .chain(::std::iter::repeat_n(LONGER, repeat_output))
+                .chain([END_OF_TEXT])
+                .collect::<Vec<u32>>(),
+        );
+        Ok(())
+    }
+}