
Commit f20fc4c

fix: assistant msg cropping might crop suffix special tokens
The complexity here: we need to crop the assistant messages so that they end on a block boundary. We cannot pad them out, because this is not what the model server will do when generating the assistant message. But it will cache the prefix of full blocks. Hence the need to crop. However, we cannot just crop at the end, as this would also crop off any "end of text" special tokens that the chat template adds to the end of the `self.assistant(m)` token sequence. Therefore, we need to crop just the message part. This logic tries to do all of that in a way that is agnostic to the chat template. However, the logic does currently assume that the chat template will never add special tokens *in the middle* of the given message `m`; it assumes special tokens are only ever added (if at all) to the beginning or end.

Signed-off-by: Nick Mitchell <[email protected]>
1 parent be37c59 commit f20fc4c
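
To make the commit's core move concrete, here is a minimal editorial sketch (not part of the diff; `crop_middle`, `msg_start`, `msg_end`, and `keep` are hypothetical names) of cropping only the middle, message portion of a templated token sequence while preserving the template's prefix and suffix special tokens:

// Editorial sketch: given [prefix specials | message | suffix specials],
// drop tokens only from the tail of the message so the total length can
// land on a block boundary, keeping the suffix specials intact.
fn crop_middle(toks: &[u32], msg_start: usize, msg_end: usize, keep: usize) -> Vec<u32> {
    // `keep` is how many message tokens survive (keep <= msg_end - msg_start)
    [
        &toks[..msg_start],                 // prefix special tokens
        &toks[msg_start..msg_start + keep], // cropped message
        &toks[msg_end..],                   // suffix special tokens
    ]
    .concat()
}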

File tree

1 file changed (+170 −10)

spnl/src/tokenize.rs

Lines changed: 170 additions & 10 deletions
@@ -32,11 +32,60 @@ impl Tokenizer {
         Ok(())
     }
 
+    /// The complexity here: we need to crop the assistant messages so
+    /// that they end on a block boundary. We cannot pad them out,
+    /// because this is not what the model server will do when
+    /// generating the assistant message. But it will cache the prefix
+    /// of full blocks. Hence the need to crop. However, we cannot
+    /// just crop at the end, as this would also crop off any "end of
+    /// text" special tokens that the chat template adds to the end of
+    /// the `self.assistant(m)` token sequence. Therefore, we need to
+    /// crop just the message part. This logic tries to do all of that
+    /// in a way that is agnostic to the chat template. However, the
+    /// logic does currently assume that the chat template will never
+    /// add special tokens *in the middle* of the given message `m`;
+    /// it assumes special tokens are only ever added (if at all) to
+    /// the beginning or end.
     fn assistanttok(&self, m: &str, tokens: &mut Vec<u32>) -> tokenizers::tokenizer::Result<()> {
-        self.extend_crop(
-            self.tok.encode_fast(self.assistant(m), false)?.get_ids(),
-            tokens,
-        );
+        let binding = self.tok.encode_fast(self.assistant(m), false)?;
+        let binding2 = self.tok.encode_fast(m, false)?;
+        let with_chat_template = binding.get_ids();
+        let without_chat_template = binding2.get_ids();
+
+        // TODO this is imperfect...
+        let start_of_message_idx = with_chat_template
+            .iter()
+            .position(|t| *t == without_chat_template[0]);
+        let end_of_message_idx = start_of_message_idx
+            .map(|start_of_message_idx| start_of_message_idx + without_chat_template.len());
+        // [pppppmmmmmmmmmss] <- ppppp are the prefix special tokens added by the chat template; ss the suffix special tokens
+        //       ^ start_of_message_idx
+        //                ^ end_of_message_idx
+
+        if with_chat_template.len() > self.block_size {
+            eprintln!(
+                "Warning (spnl): assistant message cannot be cropped due to the length of the chat template"
+            )
+        }
+
+        if without_chat_template.is_empty() {
+            self.extend(with_chat_template, tokens);
+        } else if let Some(start_of_message_idx) = start_of_message_idx
+            && let Some(end_of_message_idx) = end_of_message_idx
+        {
+            self.extend_crop(
+                with_chat_template,
+                start_of_message_idx,
+                end_of_message_idx,
+                tokens,
+            );
+        } else {
+            eprintln!(
+                "Warning (spnl): assistant message could not be cropped because the message was not found within the chat template"
+            );
+            self.extend(with_chat_template, tokens);
+        }
 
         Ok(())
     }
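
The localization step in `assistanttok` (the `// TODO this is imperfect...` part) pins the message span by matching only the first bare-message token. A standalone editorial sketch of that heuristic, with `locate_message` as a hypothetical helper name:

// Find where the bare message's tokens begin inside the templated tokens
// by matching on the first message token alone; returns (start, end).
// This can misfire if the chat template re-tokenizes the message boundary
// or if the first message token also occurs in the template's prefix.
fn locate_message(with_template: &[u32], without_template: &[u32]) -> Option<(usize, usize)> {
    // the caller must guard against an empty bare message (as assistanttok does)
    let start = with_template
        .iter()
        .position(|t| *t == without_template[0])?;
    Some((start, start + without_template.len()))
}

// With the token IDs from the tests added below, applying locate_message to
// [START_OF_ROLE, ASSISTANT, END_OF_ROLE, HELLO, END_OF_TEXT] and [HELLO]
// yields Some((3, 4)): the message occupies indices 3..4.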

@@ -67,18 +116,39 @@ impl Tokenizer {
         tokens.extend(extra);
     }
 
-    /// Extend with tokens, cropping to a block boundary
-    fn extend_crop(&self, extra: &[u32], tokens: &mut Vec<u32>) {
+    /// Extend with tokens, cropping to a block boundary, but only the `mmm` part in the middle, as follows:
+    /// [pppppmmmmmmmmmss] <- msg_with_chat_template
+    ///       ^ start_of_message_idx
+    ///                ^ end_of_message_idx
+    fn extend_crop(
+        &self,
+        msg_with_chat_template: &[u32],
+        start_of_message_idx: usize,
+        end_of_message_idx: usize,
+        tokens: &mut Vec<u32>,
+    ) {
         // Round down to nearest block boundary. Note: for future
         // reference, if we need to round up to nearest block
         // boundary, replace `tokens.len()` with
         // `tokens.len()+self.block_size-1`.
-        let end = extra.len() + tokens.len();
+        let end = msg_with_chat_template.len() + tokens.len();
         let nearest_block_boundary = end / self.block_size * self.block_size;
         let amount_to_crop = end - nearest_block_boundary;
-        let extra_end = extra.len() - amount_to_crop;
-
-        self.extend(&extra[0..extra_end], tokens);
+        let end_of_crop = if amount_to_crop > (end_of_message_idx - start_of_message_idx) {
+            start_of_message_idx
+        } else {
+            end_of_message_idx - amount_to_crop
+        };
+
+        let m = msg_with_chat_template;
+        let cropped = [
+            &m[0..start_of_message_idx],
+            &m[start_of_message_idx..end_of_crop],
+            &m[end_of_message_idx..],
+        ]
+        .concat();
+
+        self.extend(&cropped, tokens);
     }
 
     /// Pad to block boundary, then push
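
As a sanity check on `extend_crop`'s rounding, here is an editorial worked example with assumed numbers (nothing below is from the diff):

// Assumed numbers: 22 templated tokens (3 prefix + 18 message + 1 suffix)
// arriving into an empty buffer, with block_size = 16.
let block_size: usize = 16;
let end: usize = 22; // tokens.len() (0) + msg_with_chat_template.len() (22)
let nearest_block_boundary = end / block_size * block_size; // 22/16 = 1 -> 16
let amount_to_crop = end - nearest_block_boundary; // 6
// The message span is 18 tokens, so end_of_crop = end_of_message_idx - 6,
// and the result is prefix(3) + message(12) + suffix(1) = one full block.
assert_eq!((nearest_block_boundary, amount_to_crop), (16, 6));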
@@ -460,3 +530,93 @@ pub fn tokenize_prepare(
         _ => todo!(),
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use itertools::Itertools;
+
+    const PAD_TOKEN: u32 = 27;
+    const BLOCK_SIZE: usize = 16;
+
+    const MODEL: &str = "ibm-granite/granite-3.3-2b-instruct"; // TODO find smaller model with public tokenizers.json
+    const START_OF_ROLE: u32 = 49152;
+    const END_OF_ROLE: u32 = 49153;
+    const END_OF_TEXT: u32 = 0;
+    const USER: u32 = 496;
+    const ASSISTANT: u32 = 17594;
+    const HELLO: u32 = 7656;
+    const LONGER: u32 = 8928;
+
+    fn tok() -> Result<::std::sync::Arc<Tokenizer>, ::std::sync::Arc<tokenizers::tokenizer::Error>>
+    {
+        init(2).get_or_create(&MODEL.into(), PAD_TOKEN, None, None, BLOCK_SIZE)
+    }
+
+    #[test]
+    fn create_tokenizer() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>> {
+        tok().map(|_| ())
+    }
+
+    #[test]
+    fn user() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>> {
+        assert_eq!(
+            tok().map(|tok| tok.user("hello"))?,
+            "<|start_of_role|>user<|end_of_role|>hello<|end_of_text|>"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn usertok() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>> {
+        let mut tokens = vec![];
+        tok()?.usertok("hello", &mut tokens)?;
+        assert_eq!(
+            tokens,
+            [START_OF_ROLE, USER, END_OF_ROLE, HELLO, END_OF_TEXT]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn assistant() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>> {
+        assert_eq!(
+            tok().map(|tok| tok.assistant("hello"))?,
+            "<|start_of_role|>assistant<|end_of_role|>hello<|end_of_text|>"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn assistanttok_fully_cropped() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>> {
+        let mut tokens = vec![];
+        tok()?.assistanttok("hello", &mut tokens)?;
+        assert_eq!(tokens, [START_OF_ROLE, ASSISTANT, END_OF_ROLE, END_OF_TEXT]);
+        Ok(())
+    }
+
+    #[test]
+    fn assistanttok_partially_cropped() -> Result<(), ::std::sync::Arc<tokenizers::tokenizer::Error>>
+    {
+        let repeat_input = 17; // repeat this many times for the input message
+        let repeat_output = 11; // expect this many repetitions after cropping
+        let mut tokens = vec![];
+        tok()?.assistanttok(
+            format!(
+                "hello {}",
+                ::std::iter::repeat_n("longer", repeat_input).join(" ")
+            )
+            .as_str(),
+            &mut tokens,
+        )?;
+        assert_eq!(
+            tokens,
+            [START_OF_ROLE, ASSISTANT, END_OF_ROLE, HELLO]
+                .into_iter()
+                .chain(::std::iter::repeat_n(LONGER, repeat_output))
+                .chain([END_OF_TEXT])
+                .collect::<Vec<u32>>(),
+        );
+        Ok(())
+    }
+}
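
The expected `repeat_output` of 11 falls out of the arithmetic above: the templated sequence is 3 role tokens + HELLO + 17 × LONGER + END_OF_TEXT = 22 tokens, which rounds down to one 16-token block, so 6 LONGER tokens are cropped from the tail of the message (17 − 6 = 11) while END_OF_TEXT survives. Assuming the crate is the `spnl` package in the workspace (as the file path suggests), something like `cargo test -p spnl` should exercise these tests.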
