Commit eb6722e

feat: add chat_template_kwargs param to v1/chat/completion (#3016)
Signed-off-by: Chi McIsaac <[email protected]>
1 parent 9060ce1 commit eb6722e

11 files changed (+40, -3 lines)

lib/llm/src/entrypoint/input/batch.rs

Lines changed: 1 addition & 0 deletions
@@ -228,6 +228,7 @@ async fn evaluate(
         inner,
         common: Default::default(),
         nvext: None,
+        chat_template_args: None,
     };
     let mut stream = engine.generate(Context::new(req)).await?;
     let mut output = String::new();

lib/llm/src/entrypoint/input/text.rs

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ async fn main_loop(
         inner,
         common: Default::default(),
         nvext: None,
+        chat_template_args: None,
     };

     // Call the model

lib/llm/src/http/service/openai.rs

Lines changed: 8 additions & 0 deletions
@@ -1350,6 +1350,7 @@ mod tests {
             },
             common: Default::default(),
             nvext: None,
+            chat_template_args: None,
         };
         let result = validate_chat_completion_required_fields(&request);
         assert!(result.is_err());
@@ -1377,6 +1378,7 @@
             },
             common: Default::default(),
             nvext: None,
+            chat_template_args: None,
         };
         let result = validate_chat_completion_required_fields(&request);
         assert!(result.is_ok());
@@ -1549,6 +1551,7 @@
             },
             common: Default::default(),
             nvext: None,
+            chat_template_args: None,
         };

         let result = validate_chat_completion_fields_generic(&request);
@@ -1576,6 +1579,7 @@
             },
             common: Default::default(),
             nvext: None,
+            chat_template_args: None,
         };
         let result = validate_chat_completion_fields_generic(&request);
         assert!(result.is_err());
@@ -1602,6 +1606,7 @@
             },
             common: Default::default(),
             nvext: None,
+            chat_template_args: None,
         };
         let result = validate_chat_completion_fields_generic(&request);
         assert!(result.is_err());
@@ -1628,6 +1633,7 @@
             },
             common: Default::default(),
             nvext: None,
+            chat_template_args: None,
         };
         let result = validate_chat_completion_fields_generic(&request);
         assert!(result.is_err());
@@ -1656,6 +1662,7 @@
                 .build()
                 .unwrap(),
             nvext: None,
+            chat_template_args: None,
         };
         let result = validate_chat_completion_fields_generic(&request);
         assert!(result.is_err());
@@ -1682,6 +1689,7 @@
             },
             common: Default::default(),
             nvext: None,
+            chat_template_args: None,
         };
         let result = validate_chat_completion_fields_generic(&request);
         assert!(result.is_err());

lib/llm/src/preprocessor/prompt.rs

Lines changed: 6 additions & 0 deletions
@@ -20,6 +20,7 @@

 use anyhow::Result;
 use minijinja::value::Value;
+use std::collections::HashMap;
 use std::sync::Arc;

 mod template;
@@ -57,6 +58,11 @@ pub trait OAIChatLikeRequest {

     fn should_add_generation_prompt(&self) -> bool;

+    /// Optional additional args to merge into the chat template context
+    fn chat_template_args(&self) -> Option<&HashMap<String, serde_json::Value>> {
+        None
+    }
+
     /// Returns the type of input for the prompt. Default is Text.
     fn prompt_input_type(&self) -> PromptInput {
         PromptInput::Text(TextInput::Single(String::new()))

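The hook is added as a defaulted trait method, so existing OAIChatLikeRequest implementors compile unchanged and simply report no extra args. A minimal standalone sketch of that default behaviour (a stand-in trait rather than the crate's actual one, assuming serde_json is available as a dependency); only the chat-completion request type overrides the hook, in the next file:

use std::collections::HashMap;

// Stand-in for the defaulted hook added above.
trait OAIChatLikeRequest {
    fn chat_template_args(&self) -> Option<&HashMap<String, serde_json::Value>> {
        None
    }
}

// A request type that carries no template args needs no extra code.
struct CompletionLikeRequest;
impl OAIChatLikeRequest for CompletionLikeRequest {}

fn main() {
    assert!(CompletionLikeRequest.chat_template_args().is_none());
}
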
lib/llm/src/preprocessor/prompt/template/oai.rs

Lines changed: 11 additions & 3 deletions
@@ -114,6 +114,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
     fn extract_text(&self) -> Option<TextInput> {
         Some(TextInput::Single(String::new()))
     }
+
+    fn chat_template_args(&self) -> Option<&std::collections::HashMap<String, serde_json::Value>> {
+        self.chat_template_args.as_ref()
+    }
 }

 impl OAIChatLikeRequest for NvCreateCompletionRequest {
@@ -207,9 +211,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
             ..mixins
         };

-        let ctx = context! { ..ctx, ..context! {
-
-        }};
+        // Merge any additional args into the context last so they take precedence
+        let ctx = if let Some(args) = req.chat_template_args() {
+            let extra = Value::from_serialize(args);
+            context! { ..ctx, ..extra }
+        } else {
+            ctx
+        };

         let tmpl: minijinja::Template<'_, '_> = if has_tools {
             self.env.get_template("tool_use")?

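A self-contained sketch of the merge step (assuming the same minijinja version the crate uses, since the spread syntax in context! is what the new code relies on; the enable_thinking variable and the tiny template are illustrative only): the per-request map is serialized into a minijinja Value and spread into the render context, so the template can read those keys like any other variable.

use minijinja::{context, value::Value, Environment};
use std::collections::HashMap;

fn main() -> Result<(), minijinja::Error> {
    let mut env = Environment::new();
    env.add_template(
        "chat",
        "{% if enable_thinking %}<think>{% endif %}{{ prompt }}",
    )?;

    // Context the formatter would normally build from the request.
    let ctx = context! { prompt => "Hello" };

    // Per-request extras, as they would arrive via chat_template_args.
    let mut args: HashMap<String, serde_json::Value> = HashMap::new();
    args.insert("enable_thinking".to_string(), serde_json::json!(true));

    // Serialize the map and spread it into the existing context.
    let extra = Value::from_serialize(&args);
    let ctx = context! { ..ctx, ..extra };

    let rendered = env.get_template("chat")?.render(ctx)?;
    assert_eq!(rendered, "<think>Hello");
    Ok(())
}
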
lib/llm/src/protocols/openai/chat_completions.rs

Lines changed: 4 additions & 0 deletions
@@ -41,6 +41,10 @@ pub struct NvCreateChatCompletionRequest {

     #[serde(skip_serializing_if = "Option::is_none")]
     pub nvext: Option<NvExt>,
+
+    /// Extra args to pass to the chat template rendering context
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>,
 }

 /// A response structure for unary chat completion responses, embedding OpenAI's

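On the wire, this surfaces as an optional top-level object in the v1/chat/completions body; omitting it leaves the field at None via #[serde(default)]. A hedged sketch of such a body built with serde_json (the key name mirrors the serde field above; "example-model" and "enable_thinking" are placeholders, and the surrounding fields are just the usual OpenAI chat-completion fields):

fn main() {
    // Illustrative body for POST /v1/chat/completions; only "chat_template_args"
    // is new here, and its keys/values are forwarded to the chat template context.
    let body = serde_json::json!({
        "model": "example-model",
        "messages": [{ "role": "user", "content": "Hello" }],
        "chat_template_args": { "enable_thinking": false }
    });
    println!("{body}");
}
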
lib/llm/src/protocols/openai/responses.rs

Lines changed: 1 addition & 0 deletions
@@ -175,6 +175,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
             },
             common: Default::default(),
             nvext: resp.nvext,
+            chat_template_args: None,
         })
     }
 }

lib/llm/tests/http-service.rs

Lines changed: 3 additions & 0 deletions
@@ -768,6 +768,7 @@ async fn test_nv_custom_client() {
         inner: inner_request,
         common: Default::default(),
         nvext: None,
+        chat_template_args: None,
     };

     let result = nv_custom_client.chat_stream(request).await;
@@ -807,6 +808,7 @@ async fn test_nv_custom_client() {
         inner: inner_request,
         common: Default::default(),
         nvext: None,
+        chat_template_args: None,
     };

     let result = nv_custom_client.chat_stream(request).await;
@@ -847,6 +849,7 @@ async fn test_nv_custom_client() {
         inner: inner_request,
         common: Default::default(),
         nvext: None,
+        chat_template_args: None,
     };

     let result = nv_custom_client

lib/llm/tests/preprocessor.rs

Lines changed: 1 addition & 0 deletions
@@ -270,6 +270,7 @@ impl Request {
             inner,
             common: Default::default(),
             nvext: None,
+            chat_template_args: None,
         }
     }
 }

lib/llm/tests/test_common_ext.rs

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,7 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() {
             .build()
             .unwrap(),
         nvext: None,
+        chat_template_args: None,
     };

     let sampling = request.extract_sampling_options().unwrap();
@@ -327,6 +328,7 @@ fn test_serialization_preserves_structure() {
             ignore_eos: Some(false),
             ..Default::default()
         }),
+        chat_template_args: None,
     };

     let json = serde_json::to_value(&request).unwrap();
@@ -376,6 +378,7 @@ fn test_sampling_parameters_extraction() {
             .build()
             .unwrap(),
         nvext: None,
+        chat_template_args: None,
     };

     let sampling_options = request.extract_sampling_options().unwrap();
