Skip to content

Commit 0f2daad

Browse files
feat: add chat template struct to avoid tuple ordering errors (#1570)
1 parent 9946165 commit 0f2daad

File tree

1 file changed

+46
-33
lines changed

1 file changed

+46
-33
lines changed

router/src/infer.rs

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,10 @@ pub struct Infer {
3131
queue: Queue,
3232
/// Shared state
3333
shared: Arc<Shared>,
34+
/// Chat template
35+
chat_template: Option<ChatTemplate>,
3436
/// Inference limit
3537
limit_concurrent_requests: Arc<Semaphore>,
36-
/// Chat template (template, bos_token, eos_token)
37-
template: (
38-
Option<Template<'static, 'static>>,
39-
Option<String>,
40-
Option<String>,
41-
),
4238
}
4339

4440
/// Infer shared state
@@ -88,32 +84,19 @@ impl Infer {
8884
generation_health,
8985
));
9086

87+
let chat_template = tokenizer_config
88+
.chat_template
89+
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
90+
9191
// Inference limit with a semaphore
9292
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
9393

94-
let template = tokenizer_config.chat_template.map(|t| {
95-
let mut env = Box::new(Environment::new());
96-
let template_str = t.into_boxed_str();
97-
env.add_function("raise_exception", raise_exception);
98-
// leaking env and template_str as read-only, static resources for performance.
99-
Box::leak(env)
100-
.template_from_str(Box::leak(template_str))
101-
.unwrap()
102-
});
103-
let eos_token = tokenizer_config
104-
.eos_token
105-
.map_or_else(String::new, |t| t)
106-
.into();
107-
let bos_token = tokenizer_config
108-
.bos_token
109-
.map_or_else(String::new, |t| t)
110-
.into();
11194
Self {
11295
validation,
11396
queue,
11497
shared,
98+
chat_template,
11599
limit_concurrent_requests: semaphore,
116-
template: (template, bos_token, eos_token),
117100
}
118101
}
119102

@@ -192,20 +175,14 @@ impl Infer {
192175
/// Apply the chat template to the chat request
193176
#[instrument(skip_all)]
194177
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
195-
let (template, bos_token, eos_token) = &self.template;
196-
template
178+
self.chat_template
197179
.as_ref()
198180
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
199-
.render(ChatTemplateInputs {
200-
messages,
201-
eos_token: eos_token.as_deref(),
202-
bos_token: bos_token.as_deref(),
203-
add_generation_prompt: true,
204-
})
181+
.apply(messages)
205182
.map_err(|e| {
206183
metrics::increment_counter!("tgi_request_failure", "err" => "template");
207184
tracing::error!("{e}");
208-
InferError::TemplateError(e)
185+
e
209186
})
210187
}
211188

@@ -329,6 +306,42 @@ impl Infer {
329306
}
330307
}
331308

309+
#[derive(Clone)]
310+
struct ChatTemplate {
311+
template: Template<'static, 'static>,
312+
bos_token: Option<String>,
313+
eos_token: Option<String>,
314+
}
315+
316+
impl ChatTemplate {
317+
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
318+
let mut env = Box::new(Environment::new());
319+
let template_str = template.into_boxed_str();
320+
env.add_function("raise_exception", raise_exception);
321+
// leaking env and template_str as read-only, static resources for performance.
322+
let template = Box::leak(env)
323+
.template_from_str(Box::leak(template_str))
324+
.unwrap();
325+
326+
Self {
327+
template,
328+
bos_token,
329+
eos_token,
330+
}
331+
}
332+
333+
fn apply(&self, messages: Vec<Message>) -> Result<String, InferError> {
334+
self.template
335+
.render(ChatTemplateInputs {
336+
messages,
337+
bos_token: self.bos_token.as_deref(),
338+
eos_token: self.eos_token.as_deref(),
339+
add_generation_prompt: true,
340+
})
341+
.map_err(InferError::TemplateError)
342+
}
343+
}
344+
332345
/// Batching logic
333346
/// Will be launched in a background Tokio task
334347
///

0 commit comments

Comments
 (0)