Skip to content

Commit

Permalink
feat: improve, simplify and rename tool choice struct add required su…
Browse files Browse the repository at this point in the history
…pport and refactor
  • Loading branch information
drbh committed Oct 14, 2024
1 parent 72a0305 commit 37a1239
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 112 deletions.
78 changes: 38 additions & 40 deletions docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,42 @@
}
}
},
"ChatCompletionToolChoiceOption": {
"oneOf": [
{
"type": "string",
"description": "Means the model can pick between generating a message or calling one or more tools.",
"enum": [
"auto"
]
},
{
"type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
},
{
"type": "string",
"description": "Means the model must call one or more tools.",
"enum": [
"required"
]
},
{
"type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
}
]
},
"ChatCompletionTopLogprob": {
"type": "object",
"required": [
Expand Down Expand Up @@ -963,9 +999,10 @@
"tool_choice": {
"allOf": [
{
"$ref": "#/components/schemas/ToolChoice"
"$ref": "#/components/schemas/ChatCompletionToolChoiceOption"
}
],
"default": "null",
"nullable": true
},
"tool_prompt": {
Expand Down Expand Up @@ -2103,45 +2140,6 @@
}
}
},
"ToolChoice": {
"allOf": [
{
"$ref": "#/components/schemas/ToolType"
}
],
"nullable": true
},
"ToolType": {
"oneOf": [
{
"type": "string",
"description": "Means the model can pick between generating a message or calling one or more tools.",
"enum": [
"auto"
]
},
{
"type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
},
{
"type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
}
],
"description": "Controls which (if any) tool is called by the model.",
"example": "auto"
},
"Url": {
"type": "object",
"required": [
Expand Down
57 changes: 30 additions & 27 deletions router/src/infer/tool_grammar.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::infer::InferError;
use crate::{
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
ToolType,
ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool,
Properties, Tool,
};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
Expand All @@ -20,44 +20,47 @@ impl ToolGrammar {

pub fn apply(
tools: Vec<Tool>,
tool_choice: ToolChoice,
tool_choice: ChatCompletionToolChoiceOption,
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
// if no tools are provided, we return None
if tools.is_empty() {
return Ok((tools, None));
}

let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);

let mut tools = tools.clone();

// add the no_tool function to the tools
let no_tool = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "no_tool".to_string(),
description: Some("Open ened response with no specific tool selected".to_string()),
arguments: json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The response content",
}
},
"required": ["content"]
}),
},
};
tools.push(no_tool);
// add the no_tool function to the tools as long as we are not required to use a specific tool
if tool_choice != ChatCompletionToolChoiceOption::Required {
let no_tool = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "no_tool".to_string(),
description: Some(
"Open ended response with no specific tool selected".to_string(),
),
arguments: json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The response content",
}
},
"required": ["content"]
}),
},
};
tools.push(no_tool);
}

// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::Function(function) => {
ChatCompletionToolChoiceOption::Function(function) => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools.clone(),
ToolType::NoTool => return Ok((tools, None)),
ChatCompletionToolChoiceOption::Required => tools.clone(),
ChatCompletionToolChoiceOption::Auto => tools.clone(),
ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)),
};

let functions: HashMap<String, serde_json::Value> = tools_to_use
Expand Down
103 changes: 67 additions & 36 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -849,8 +849,8 @@ pub(crate) struct ChatRequest {

/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tool_choice: ToolChoice,
#[schema(nullable = true, default = "null", example = "null")]
pub tool_choice: Option<ChatCompletionToolChoiceOption>,

/// Response format constraints for the generation.
///
Expand Down Expand Up @@ -906,8 +906,18 @@ impl ChatRequest {
let (inputs, grammar, using_tools) = prepare_chat_input(
infer,
response_format,
tools,
tool_choice,
tools.clone(),
// unwrap or default (use "auto" if tools are present, and "none" if not)
tool_choice.map_or_else(
|| {
if tools.is_some() {
ChatCompletionToolChoiceOption::Auto
} else {
ChatCompletionToolChoiceOption::NoTool
}
},
|t| t,
),
&tool_prompt,
guideline,
messages,
Expand Down Expand Up @@ -956,22 +966,6 @@ pub fn default_tool_prompt() -> String {
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
}

#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
#[schema(example = "auto")]
/// Controls which (if any) tool is called by the model.
pub enum ToolType {
/// Means the model can pick between generating a message or calling one or more tools.
#[schema(rename = "auto")]
OneOf,
/// Means the model will not call any tool and instead generates a message.
#[schema(rename = "none")]
NoTool,
/// Forces the model to call a specific tool.
#[schema(rename = "function")]
#[serde(alias = "function")]
Function(FunctionName),
}

#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(tag = "type")]
pub enum TypedChoice {
Expand All @@ -984,29 +978,59 @@ pub struct FunctionName {
pub name: String,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
#[serde(from = "ToolTypeDeserializer")]
pub struct ToolChoice(pub Option<ToolType>);
pub enum ChatCompletionToolChoiceOption {
/// Means the model can pick between generating a message or calling one or more tools.
#[schema(rename = "auto")]
Auto,
/// Means the model will not call any tool and instead generates a message.
#[schema(rename = "none")]
#[default]
NoTool,
/// Means the model must call one or more tools.
#[schema(rename = "required")]
Required,
/// Forces the model to call a specific tool.
#[schema(rename = "function")]
#[serde(alias = "function")]
Function(FunctionName),
}

#[derive(Deserialize)]
#[derive(Deserialize, ToSchema)]
#[serde(untagged)]
/// Controls which (if any) tool is called by the model.
/// - `none` means the model will not call any tool and instead generates a message.
/// - `auto` means the model can pick between generating a message or calling one or more tools.
/// - `required` means the model must call one or more tools.
/// - Specifying a particular tool via `{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}` forces the model to call that tool.
///
/// `none` is the default when no tools are present. `auto` is the default if tools are present."
enum ToolTypeDeserializer {
/// `none` means the model will not call any tool and instead generates a message.
Null,

/// `auto` means the model can pick between generating a message or calling one or more tools.
#[schema(example = "auto")]
String(String),
ToolType(TypedChoice),

/// Specifying a particular tool forces the model to call that tool, with structured function details.
#[schema(example = r#"{"type": "function", "function": {"name": "my_function"}}"#)]
TypedChoice(TypedChoice),
}

impl From<ToolTypeDeserializer> for ToolChoice {
impl From<ToolTypeDeserializer> for ChatCompletionToolChoiceOption {
fn from(value: ToolTypeDeserializer) -> Self {
match value {
ToolTypeDeserializer::Null => ToolChoice(None),
ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool,
ToolTypeDeserializer::String(s) => match s.as_str() {
"none" => ToolChoice(Some(ToolType::NoTool)),
"auto" => ToolChoice(Some(ToolType::OneOf)),
_ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))),
"none" => ChatCompletionToolChoiceOption::NoTool,
"auto" => ChatCompletionToolChoiceOption::Auto,
"required" => ChatCompletionToolChoiceOption::Required,
_ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }),
},
ToolTypeDeserializer::ToolType(TypedChoice::Function { function }) => {
ToolChoice(Some(ToolType::Function(function)))
ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => {
ChatCompletionToolChoiceOption::Function(function)
}
}
}
Expand Down Expand Up @@ -1619,20 +1643,27 @@ mod tests {
fn tool_choice_formats() {
#[derive(Deserialize)]
struct TestRequest {
tool_choice: ToolChoice,
tool_choice: ChatCompletionToolChoiceOption,
}

let none = r#"{"tool_choice":"none"}"#;
let de_none: TestRequest = serde_json::from_str(none).unwrap();
assert_eq!(de_none.tool_choice, ToolChoice(Some(ToolType::NoTool)));
assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);

let auto = r#"{"tool_choice":"auto"}"#;
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
assert_eq!(de_auto.tool_choice, ToolChoice(Some(ToolType::OneOf)));
assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);

let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName {
let auto = r#"{"tool_choice":"required"}"#;
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
assert_eq!(
de_auto.tool_choice,
ChatCompletionToolChoiceOption::Required
);

let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName {
name: "myfn".to_string(),
})));
});

let named = r#"{"tool_choice":"myfn"}"#;
let de_named: TestRequest = serde_json::from_str(named).unwrap();
Expand Down
9 changes: 4 additions & 5 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::{
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall};
use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
Expand Down Expand Up @@ -1560,12 +1560,11 @@ GrammarType,
Usage,
StreamOptions,
DeltaToolCall,
ToolType,
Tool,
ToolCall,
Function,
FunctionDefinition,
ToolChoice,
ChatCompletionToolChoiceOption,
ModelInfo,
)
),
Expand Down Expand Up @@ -2518,7 +2517,7 @@ pub(crate) fn prepare_chat_input(
infer: &Infer,
response_format: Option<GrammarType>,
tools: Option<Vec<Tool>>,
tool_choice: ToolChoice,
tool_choice: ChatCompletionToolChoiceOption,
tool_prompt: &str,
guideline: Option<String>,
messages: Vec<Message>,
Expand Down Expand Up @@ -2653,7 +2652,7 @@ mod tests {
&infer,
response_format,
tools,
ToolChoice(None),
ChatCompletionToolChoiceOption::Auto,
tool_prompt,
guideline,
messages,
Expand Down
Loading

0 comments on commit 37a1239

Please sign in to comment.