Skip to content

Commit 37a1239

Browse files
committed
feat: improve, simplify and rename tool choice struct add required support and refactor
1 parent 72a0305 commit 37a1239

File tree

5 files changed

+143
-112
lines changed

5 files changed

+143
-112
lines changed

docs/openapi.json

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,42 @@
829829
}
830830
}
831831
},
832+
"ChatCompletionToolChoiceOption": {
833+
"oneOf": [
834+
{
835+
"type": "string",
836+
"description": "Means the model can pick between generating a message or calling one or more tools.",
837+
"enum": [
838+
"auto"
839+
]
840+
},
841+
{
842+
"type": "string",
843+
"description": "Means the model will not call any tool and instead generates a message.",
844+
"enum": [
845+
"none"
846+
]
847+
},
848+
{
849+
"type": "string",
850+
"description": "Means the model must call one or more tools.",
851+
"enum": [
852+
"required"
853+
]
854+
},
855+
{
856+
"type": "object",
857+
"required": [
858+
"function"
859+
],
860+
"properties": {
861+
"function": {
862+
"$ref": "#/components/schemas/FunctionName"
863+
}
864+
}
865+
}
866+
]
867+
},
832868
"ChatCompletionTopLogprob": {
833869
"type": "object",
834870
"required": [
@@ -963,9 +999,10 @@
963999
"tool_choice": {
9641000
"allOf": [
9651001
{
966-
"$ref": "#/components/schemas/ToolChoice"
1002+
"$ref": "#/components/schemas/ChatCompletionToolChoiceOption"
9671003
}
9681004
],
1005+
"default": "null",
9691006
"nullable": true
9701007
},
9711008
"tool_prompt": {
@@ -2103,45 +2140,6 @@
21032140
}
21042141
}
21052142
},
2106-
"ToolChoice": {
2107-
"allOf": [
2108-
{
2109-
"$ref": "#/components/schemas/ToolType"
2110-
}
2111-
],
2112-
"nullable": true
2113-
},
2114-
"ToolType": {
2115-
"oneOf": [
2116-
{
2117-
"type": "string",
2118-
"description": "Means the model can pick between generating a message or calling one or more tools.",
2119-
"enum": [
2120-
"auto"
2121-
]
2122-
},
2123-
{
2124-
"type": "string",
2125-
"description": "Means the model will not call any tool and instead generates a message.",
2126-
"enum": [
2127-
"none"
2128-
]
2129-
},
2130-
{
2131-
"type": "object",
2132-
"required": [
2133-
"function"
2134-
],
2135-
"properties": {
2136-
"function": {
2137-
"$ref": "#/components/schemas/FunctionName"
2138-
}
2139-
}
2140-
}
2141-
],
2142-
"description": "Controls which (if any) tool is called by the model.",
2143-
"example": "auto"
2144-
},
21452143
"Url": {
21462144
"type": "object",
21472145
"required": [

router/src/infer/tool_grammar.rs

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::infer::InferError;
22
use crate::{
3-
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
4-
ToolType,
3+
ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool,
4+
Properties, Tool,
55
};
66
use serde_json::{json, Map, Value};
77
use std::collections::HashMap;
@@ -20,44 +20,47 @@ impl ToolGrammar {
2020

2121
pub fn apply(
2222
tools: Vec<Tool>,
23-
tool_choice: ToolChoice,
23+
tool_choice: ChatCompletionToolChoiceOption,
2424
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
2525
// if no tools are provided, we return None
2626
if tools.is_empty() {
2727
return Ok((tools, None));
2828
}
2929

30-
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
31-
3230
let mut tools = tools.clone();
3331

34-
// add the no_tool function to the tools
35-
let no_tool = Tool {
36-
r#type: "function".to_string(),
37-
function: FunctionDefinition {
38-
name: "no_tool".to_string(),
39-
description: Some("Open ened response with no specific tool selected".to_string()),
40-
arguments: json!({
41-
"type": "object",
42-
"properties": {
43-
"content": {
44-
"type": "string",
45-
"description": "The response content",
46-
}
47-
},
48-
"required": ["content"]
49-
}),
50-
},
51-
};
52-
tools.push(no_tool);
32+
// add the no_tool function to the tools as long as we are not required to use a specific tool
33+
if tool_choice != ChatCompletionToolChoiceOption::Required {
34+
let no_tool = Tool {
35+
r#type: "function".to_string(),
36+
function: FunctionDefinition {
37+
name: "no_tool".to_string(),
38+
description: Some(
39+
"Open ended response with no specific tool selected".to_string(),
40+
),
41+
arguments: json!({
42+
"type": "object",
43+
"properties": {
44+
"content": {
45+
"type": "string",
46+
"description": "The response content",
47+
}
48+
},
49+
"required": ["content"]
50+
}),
51+
},
52+
};
53+
tools.push(no_tool);
54+
}
5355

5456
// if tools are provided and no tool_choice we default to the OneOf
5557
let tools_to_use = match tool_choice {
56-
ToolType::Function(function) => {
58+
ChatCompletionToolChoiceOption::Function(function) => {
5759
vec![Self::find_tool_by_name(&tools, &function.name)?]
5860
}
59-
ToolType::OneOf => tools.clone(),
60-
ToolType::NoTool => return Ok((tools, None)),
61+
ChatCompletionToolChoiceOption::Required => tools.clone(),
62+
ChatCompletionToolChoiceOption::Auto => tools.clone(),
63+
ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)),
6164
};
6265

6366
let functions: HashMap<String, serde_json::Value> = tools_to_use

router/src/lib.rs

Lines changed: 67 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -849,8 +849,8 @@ pub(crate) struct ChatRequest {
849849

850850
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
851851
#[serde(default)]
852-
#[schema(nullable = true, example = "null")]
853-
pub tool_choice: ToolChoice,
852+
#[schema(nullable = true, default = "null", example = "null")]
853+
pub tool_choice: Option<ChatCompletionToolChoiceOption>,
854854

855855
/// Response format constraints for the generation.
856856
///
@@ -906,8 +906,18 @@ impl ChatRequest {
906906
let (inputs, grammar, using_tools) = prepare_chat_input(
907907
infer,
908908
response_format,
909-
tools,
910-
tool_choice,
909+
tools.clone(),
910+
// unwrap or default (use "auto" if tools are present, and "none" if not)
911+
tool_choice.map_or_else(
912+
|| {
913+
if tools.is_some() {
914+
ChatCompletionToolChoiceOption::Auto
915+
} else {
916+
ChatCompletionToolChoiceOption::NoTool
917+
}
918+
},
919+
|t| t,
920+
),
911921
&tool_prompt,
912922
guideline,
913923
messages,
@@ -956,22 +966,6 @@ pub fn default_tool_prompt() -> String {
956966
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
957967
}
958968

959-
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
960-
#[schema(example = "auto")]
961-
/// Controls which (if any) tool is called by the model.
962-
pub enum ToolType {
963-
/// Means the model can pick between generating a message or calling one or more tools.
964-
#[schema(rename = "auto")]
965-
OneOf,
966-
/// Means the model will not call any tool and instead generates a message.
967-
#[schema(rename = "none")]
968-
NoTool,
969-
/// Forces the model to call a specific tool.
970-
#[schema(rename = "function")]
971-
#[serde(alias = "function")]
972-
Function(FunctionName),
973-
}
974-
975969
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
976970
#[serde(tag = "type")]
977971
pub enum TypedChoice {
@@ -984,29 +978,59 @@ pub struct FunctionName {
984978
pub name: String,
985979
}
986980

987-
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)]
981+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
988982
#[serde(from = "ToolTypeDeserializer")]
989-
pub struct ToolChoice(pub Option<ToolType>);
983+
pub enum ChatCompletionToolChoiceOption {
984+
/// Means the model can pick between generating a message or calling one or more tools.
985+
#[schema(rename = "auto")]
986+
Auto,
987+
/// Means the model will not call any tool and instead generates a message.
988+
#[schema(rename = "none")]
989+
#[default]
990+
NoTool,
991+
/// Means the model must call one or more tools.
992+
#[schema(rename = "required")]
993+
Required,
994+
/// Forces the model to call a specific tool.
995+
#[schema(rename = "function")]
996+
#[serde(alias = "function")]
997+
Function(FunctionName),
998+
}
990999

991-
#[derive(Deserialize)]
1000+
#[derive(Deserialize, ToSchema)]
9921001
#[serde(untagged)]
1002+
/// Controls which (if any) tool is called by the model.
1003+
/// - `none` means the model will not call any tool and instead generates a message.
1004+
/// - `auto` means the model can pick between generating a message or calling one or more tools.
1005+
/// - `required` means the model must call one or more tools.
1006+
/// - Specifying a particular tool via `{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}` forces the model to call that tool.
1007+
///
1008+
/// `none` is the default when no tools are present. `auto` is the default if tools are present."
9931009
enum ToolTypeDeserializer {
1010+
/// `none` means the model will not call any tool and instead generates a message.
9941011
Null,
1012+
1013+
/// `auto` means the model can pick between generating a message or calling one or more tools.
1014+
#[schema(example = "auto")]
9951015
String(String),
996-
ToolType(TypedChoice),
1016+
1017+
/// Specifying a particular tool forces the model to call that tool, with structured function details.
1018+
#[schema(example = r#"{"type": "function", "function": {"name": "my_function"}}"#)]
1019+
TypedChoice(TypedChoice),
9971020
}
9981021

999-
impl From<ToolTypeDeserializer> for ToolChoice {
1022+
impl From<ToolTypeDeserializer> for ChatCompletionToolChoiceOption {
10001023
fn from(value: ToolTypeDeserializer) -> Self {
10011024
match value {
1002-
ToolTypeDeserializer::Null => ToolChoice(None),
1025+
ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool,
10031026
ToolTypeDeserializer::String(s) => match s.as_str() {
1004-
"none" => ToolChoice(Some(ToolType::NoTool)),
1005-
"auto" => ToolChoice(Some(ToolType::OneOf)),
1006-
_ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))),
1027+
"none" => ChatCompletionToolChoiceOption::NoTool,
1028+
"auto" => ChatCompletionToolChoiceOption::Auto,
1029+
"required" => ChatCompletionToolChoiceOption::Required,
1030+
_ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }),
10071031
},
1008-
ToolTypeDeserializer::ToolType(TypedChoice::Function { function }) => {
1009-
ToolChoice(Some(ToolType::Function(function)))
1032+
ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => {
1033+
ChatCompletionToolChoiceOption::Function(function)
10101034
}
10111035
}
10121036
}
@@ -1619,20 +1643,27 @@ mod tests {
16191643
fn tool_choice_formats() {
16201644
#[derive(Deserialize)]
16211645
struct TestRequest {
1622-
tool_choice: ToolChoice,
1646+
tool_choice: ChatCompletionToolChoiceOption,
16231647
}
16241648

16251649
let none = r#"{"tool_choice":"none"}"#;
16261650
let de_none: TestRequest = serde_json::from_str(none).unwrap();
1627-
assert_eq!(de_none.tool_choice, ToolChoice(Some(ToolType::NoTool)));
1651+
assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);
16281652

16291653
let auto = r#"{"tool_choice":"auto"}"#;
16301654
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
1631-
assert_eq!(de_auto.tool_choice, ToolChoice(Some(ToolType::OneOf)));
1655+
assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);
16321656

1633-
let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName {
1657+
let auto = r#"{"tool_choice":"required"}"#;
1658+
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
1659+
assert_eq!(
1660+
de_auto.tool_choice,
1661+
ChatCompletionToolChoiceOption::Required
1662+
);
1663+
1664+
let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName {
16341665
name: "myfn".to_string(),
1635-
})));
1666+
});
16361667

16371668
let named = r#"{"tool_choice":"myfn"}"#;
16381669
let de_named: TestRequest = serde_json::from_str(named).unwrap();

router/src/server.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use crate::{
2323
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
2424
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
2525
};
26-
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
26+
use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall};
2727
use crate::{ModelInfo, ModelsInfo};
2828
use async_stream::__private::AsyncStream;
2929
use axum::extract::Extension;
@@ -1560,12 +1560,11 @@ GrammarType,
15601560
Usage,
15611561
StreamOptions,
15621562
DeltaToolCall,
1563-
ToolType,
15641563
Tool,
15651564
ToolCall,
15661565
Function,
15671566
FunctionDefinition,
1568-
ToolChoice,
1567+
ChatCompletionToolChoiceOption,
15691568
ModelInfo,
15701569
)
15711570
),
@@ -2518,7 +2517,7 @@ pub(crate) fn prepare_chat_input(
25182517
infer: &Infer,
25192518
response_format: Option<GrammarType>,
25202519
tools: Option<Vec<Tool>>,
2521-
tool_choice: ToolChoice,
2520+
tool_choice: ChatCompletionToolChoiceOption,
25222521
tool_prompt: &str,
25232522
guideline: Option<String>,
25242523
messages: Vec<Message>,
@@ -2653,7 +2652,7 @@ mod tests {
26532652
&infer,
26542653
response_format,
26552654
tools,
2656-
ToolChoice(None),
2655+
ChatCompletionToolChoiceOption::Auto,
26572656
tool_prompt,
26582657
guideline,
26592658
messages,

0 commit comments

Comments
 (0)