diff --git a/docs/openapi.json b/docs/openapi.json index d1b60f4d410..a8269181a66 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -829,6 +829,42 @@ } } }, + "ChatCompletionToolChoiceOption": { + "oneOf": [ + { + "type": "string", + "description": "Means the model can pick between generating a message or calling one or more tools.", + "enum": [ + "auto" + ] + }, + { + "type": "string", + "description": "Means the model will not call any tool and instead generates a message.", + "enum": [ + "none" + ] + }, + { + "type": "string", + "description": "Means the model must call one or more tools.", + "enum": [ + "required" + ] + }, + { + "type": "object", + "required": [ + "function" + ], + "properties": { + "function": { + "$ref": "#/components/schemas/FunctionName" + } + } + } + ] + }, "ChatCompletionTopLogprob": { "type": "object", "required": [ @@ -963,9 +999,10 @@ "tool_choice": { "allOf": [ { - "$ref": "#/components/schemas/ToolChoice" + "$ref": "#/components/schemas/ChatCompletionToolChoiceOption" } ], + "default": "null", "nullable": true }, "tool_prompt": { @@ -2103,45 +2140,6 @@ } } }, - "ToolChoice": { - "allOf": [ - { - "$ref": "#/components/schemas/ToolType" - } - ], - "nullable": true - }, - "ToolType": { - "oneOf": [ - { - "type": "string", - "description": "Means the model can pick between generating a message or calling one or more tools.", - "enum": [ - "auto" - ] - }, - { - "type": "string", - "description": "Means the model will not call any tool and instead generates a message.", - "enum": [ - "none" - ] - }, - { - "type": "object", - "required": [ - "function" - ], - "properties": { - "function": { - "$ref": "#/components/schemas/FunctionName" - } - } - } - ], - "description": "Controls which (if any) tool is called by the model.", - "example": "auto" - }, "Url": { "type": "object", "required": [ diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index f86205fb532..b9070812a76 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -1,7 +1,7 @@ use crate::infer::InferError; use crate::{ - FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, - ToolType, + ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, + Properties, Tool, }; use serde_json::{json, Map, Value}; use std::collections::HashMap; @@ -20,44 +20,47 @@ impl ToolGrammar { pub fn apply( tools: Vec, - tool_choice: ToolChoice, + tool_choice: ChatCompletionToolChoiceOption, ) -> Result<(Vec, Option), InferError> { // if no tools are provided, we return None if tools.is_empty() { return Ok((tools, None)); } - let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); - let mut tools = tools.clone(); - // add the no_tool function to the tools - let no_tool = Tool { - r#type: "function".to_string(), - function: FunctionDefinition { - name: "no_tool".to_string(), - description: Some("Open ened response with no specific tool selected".to_string()), - arguments: json!({ - "type": "object", - "properties": { - "content": { - "type": "string", - "description": "The response content", - } - }, - "required": ["content"] - }), - }, - }; - tools.push(no_tool); + // add the no_tool function to the tools as long as we are not required to use a specific tool + if tool_choice != ChatCompletionToolChoiceOption::Required { + let no_tool = Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "no_tool".to_string(), + description: Some( + "Open ended response with no specific tool selected".to_string(), + ), + arguments: json!({ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The response content", + } + }, + "required": ["content"] + }), + }, + }; + tools.push(no_tool); + } // if tools are provided and no tool_choice we default to the OneOf let tools_to_use = match tool_choice { - ToolType::Function(function) => { + ChatCompletionToolChoiceOption::Function(function) => { vec![Self::find_tool_by_name(&tools, &function.name)?] } - ToolType::OneOf => tools.clone(), - ToolType::NoTool => return Ok((tools, None)), + ChatCompletionToolChoiceOption::Required => tools.clone(), + ChatCompletionToolChoiceOption::Auto => tools.clone(), + ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)), }; let functions: HashMap = tools_to_use diff --git a/router/src/lib.rs b/router/src/lib.rs index 1f57df904b2..635e44d296d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -849,8 +849,8 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] - #[schema(nullable = true, example = "null")] - pub tool_choice: ToolChoice, + #[schema(nullable = true, default = "null", example = "null")] + pub tool_choice: Option, /// Response format constraints for the generation. /// @@ -906,8 +906,18 @@ impl ChatRequest { let (inputs, grammar, using_tools) = prepare_chat_input( infer, response_format, - tools, - tool_choice, + tools.clone(), + // unwrap or default (use "auto" if tools are present, and "none" if not) + tool_choice.map_or_else( + || { + if tools.is_some() { + ChatCompletionToolChoiceOption::Auto + } else { + ChatCompletionToolChoiceOption::NoTool + } + }, + |t| t, + ), &tool_prompt, guideline, messages, @@ -956,22 +966,6 @@ pub fn default_tool_prompt() -> String { "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string() } -#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] -#[schema(example = "auto")] -/// Controls which (if any) tool is called by the model. -pub enum ToolType { - /// Means the model can pick between generating a message or calling one or more tools. - #[schema(rename = "auto")] - OneOf, - /// Means the model will not call any tool and instead generates a message. - #[schema(rename = "none")] - NoTool, - /// Forces the model to call a specific tool. - #[schema(rename = "function")] - #[serde(alias = "function")] - Function(FunctionName), -} - #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] #[serde(tag = "type")] pub enum TypedChoice { @@ -984,29 +978,59 @@ pub struct FunctionName { pub name: String, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)] #[serde(from = "ToolTypeDeserializer")] -pub struct ToolChoice(pub Option); +pub enum ChatCompletionToolChoiceOption { + /// Means the model can pick between generating a message or calling one or more tools. + #[schema(rename = "auto")] + Auto, + /// Means the model will not call any tool and instead generates a message. + #[schema(rename = "none")] + #[default] + NoTool, + /// Means the model must call one or more tools. + #[schema(rename = "required")] + Required, + /// Forces the model to call a specific tool. + #[schema(rename = "function")] + #[serde(alias = "function")] + Function(FunctionName), +} -#[derive(Deserialize)] +#[derive(Deserialize, ToSchema)] #[serde(untagged)] +/// Controls which (if any) tool is called by the model. +/// - `none` means the model will not call any tool and instead generates a message. +/// - `auto` means the model can pick between generating a message or calling one or more tools. +/// - `required` means the model must call one or more tools. +/// - Specifying a particular tool via `{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}` forces the model to call that tool. +/// +/// `none` is the default when no tools are present. `auto` is the default if tools are present." enum ToolTypeDeserializer { + /// `none` means the model will not call any tool and instead generates a message. Null, + + /// `auto` means the model can pick between generating a message or calling one or more tools. + #[schema(example = "auto")] String(String), - ToolType(TypedChoice), + + /// Specifying a particular tool forces the model to call that tool, with structured function details. + #[schema(example = r#"{"type": "function", "function": {"name": "my_function"}}"#)] + TypedChoice(TypedChoice), } -impl From for ToolChoice { +impl From for ChatCompletionToolChoiceOption { fn from(value: ToolTypeDeserializer) -> Self { match value { - ToolTypeDeserializer::Null => ToolChoice(None), + ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool, ToolTypeDeserializer::String(s) => match s.as_str() { - "none" => ToolChoice(Some(ToolType::NoTool)), - "auto" => ToolChoice(Some(ToolType::OneOf)), - _ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), + "none" => ChatCompletionToolChoiceOption::NoTool, + "auto" => ChatCompletionToolChoiceOption::Auto, + "required" => ChatCompletionToolChoiceOption::Required, + _ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }), }, - ToolTypeDeserializer::ToolType(TypedChoice::Function { function }) => { - ToolChoice(Some(ToolType::Function(function))) + ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => { + ChatCompletionToolChoiceOption::Function(function) } } } @@ -1619,20 +1643,27 @@ mod tests { fn tool_choice_formats() { #[derive(Deserialize)] struct TestRequest { - tool_choice: ToolChoice, + tool_choice: ChatCompletionToolChoiceOption, } let none = r#"{"tool_choice":"none"}"#; let de_none: TestRequest = serde_json::from_str(none).unwrap(); - assert_eq!(de_none.tool_choice, ToolChoice(Some(ToolType::NoTool))); + assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool); let auto = r#"{"tool_choice":"auto"}"#; let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); - assert_eq!(de_auto.tool_choice, ToolChoice(Some(ToolType::OneOf))); + assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto); - let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName { + let auto = r#"{"tool_choice":"required"}"#; + let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); + assert_eq!( + de_auto.tool_choice, + ChatCompletionToolChoiceOption::Required + ); + + let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName { name: "myfn".to_string(), - }))); + }); let named = r#"{"tool_choice":"myfn"}"#; let de_named: TestRequest = serde_json::from_str(named).unwrap(); diff --git a/router/src/server.rs b/router/src/server.rs index 5e6e696037e..4b571663338 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -23,7 +23,7 @@ use crate::{ ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall}; use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::Extension; @@ -1560,12 +1560,11 @@ GrammarType, Usage, StreamOptions, DeltaToolCall, -ToolType, Tool, ToolCall, Function, FunctionDefinition, -ToolChoice, +ChatCompletionToolChoiceOption, ModelInfo, ) ), @@ -2518,7 +2517,7 @@ pub(crate) fn prepare_chat_input( infer: &Infer, response_format: Option, tools: Option>, - tool_choice: ToolChoice, + tool_choice: ChatCompletionToolChoiceOption, tool_prompt: &str, guideline: Option, messages: Vec, @@ -2653,7 +2652,7 @@ mod tests { &infer, response_format, tools, - ToolChoice(None), + ChatCompletionToolChoiceOption::Auto, tool_prompt, guideline, messages, diff --git a/router/src/vertex.rs b/router/src/vertex.rs index 0c1467fe373..a54603bd6fe 100644 --- a/router/src/vertex.rs +++ b/router/src/vertex.rs @@ -1,8 +1,8 @@ use crate::infer::Infer; use crate::server::{generate_internal, ComputeType}; use crate::{ - ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message, - StreamOptions, Tool, ToolChoice, + ChatCompletionToolChoiceOption, ChatRequest, ErrorResponse, GenerateParameters, + GenerateRequest, GrammarType, Message, StreamOptions, Tool, }; use axum::extract::Extension; use axum::http::{HeaderMap, StatusCode}; @@ -124,7 +124,7 @@ pub(crate) struct VertexParameters { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] #[schema(nullable = true, example = "null")] - pub tool_choice: ToolChoice, + pub tool_choice: ChatCompletionToolChoiceOption, /// Response format constraints for the generation. /// @@ -162,7 +162,7 @@ impl From for ChatRequest { stream_options: val.parameters.stream_options, stream: val.parameters.stream, temperature: val.parameters.temperature, - tool_choice: val.parameters.tool_choice, + tool_choice: Some(val.parameters.tool_choice), tool_prompt: val.parameters.tool_prompt, tools: val.parameters.tools, top_logprobs: val.parameters.top_logprobs,