From f979ff19651b89f8f8bb9c4f5a7d6d0cc6e69a7b Mon Sep 17 00:00:00 2001 From: Linus Bierhoff Date: Thu, 10 Oct 2024 18:50:32 +0200 Subject: [PATCH 01/15] add OpenAI like tool_choice for named choice --- docs/openapi.json | 18 ++++++++++++++++++ router/src/lib.rs | 12 ++++++++++++ 2 files changed, 30 insertions(+) diff --git a/docs/openapi.json b/docs/openapi.json index 903f742629f..280754bcaec 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2268,6 +2268,24 @@ "$ref": "#/components/schemas/FunctionName" } } + }, + { + "type": "object", + "required": [ + "type", + "function" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "function" + ] + }, + "function": { + "$ref": "#/components/schemas/FunctionName" + } + } } ], "description": "Controls which (if any) tool is called by the model.", diff --git a/router/src/lib.rs b/router/src/lib.rs index a5613f89237..52a8a64a9b0 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1011,9 +1011,18 @@ pub enum ToolType { NoTool, /// Forces the model to call a specific tool. #[schema(rename = "function")] + #[serde(alias = "function")] Function(FunctionName), } + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[serde(tag = "type")] +pub enum TypedChoice { + #[serde(rename = "function")] + Function{function: FunctionName}, +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] pub struct FunctionName { pub name: String, @@ -1029,8 +1038,10 @@ enum ToolTypeDeserializer { Null, String(String), ToolType(ToolType), + TypedChoice(TypedChoice) //this is the OpenAI schema } + impl From for ToolChoice { fn from(value: ToolTypeDeserializer) -> Self { match value { @@ -1040,6 +1051,7 @@ impl From for ToolChoice { "auto" => ToolChoice(Some(ToolType::OneOf)), _ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), }, + ToolTypeDeserializer::TypedChoice(TypedChoice::Function{function}) => ToolChoice(Some(ToolType::Function(function))), ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), } } From 151f950eea56c7449870024dc9f3c6f35f1f18ae Mon Sep 17 00:00:00 2001 From: Linus Bierhoff Date: Thu, 10 Oct 2024 19:41:51 +0200 Subject: [PATCH 02/15] add tests --- router/src/lib.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/router/src/lib.rs b/router/src/lib.rs index 52a8a64a9b0..c4189904d4d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1658,4 +1658,36 @@ mod tests { r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":{"format":"csv"}}}]}"# ); } + + #[test] + fn tool_choice_formats() { + + #[derive(Deserialize)] + struct TestRequest { + tool_choice: ToolChoice, + } + + let none = r#"{"tool_choice":"none"}"#; + let de_none: TestRequest = serde_json::from_str(none).unwrap(); + assert_eq!(de_none.tool_choice, ToolChoice(Some(ToolType::NoTool))); + + let auto = r#"{"tool_choice":"auto"}"#; + let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); + assert_eq!(de_auto.tool_choice, ToolChoice(Some(ToolType::OneOf))); + + let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName { name: "myfn".to_string() }))); + + let named = r#"{"tool_choice":"myfn"}"#; + let de_named: TestRequest = serde_json::from_str(named).unwrap(); + assert_eq!(de_named.tool_choice, ref_choice); + + let old_named = r#"{"tool_choice":{"function":{"name":"myfn"}}}"#; + let de_old_named: TestRequest = serde_json::from_str(old_named).unwrap(); + assert_eq!(de_old_named.tool_choice, ref_choice); + + let openai_named = r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#; + let de_openai_named: TestRequest = serde_json::from_str(openai_named).unwrap(); + + assert_eq!(de_openai_named.tool_choice, ref_choice); + } } From 2c172a2da72e4a4383d19c1a60b375f4c03a5aa2 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 14 Oct 2024 14:26:45 +0000 Subject: [PATCH 03/15] fix: run linter and bump api docs --- docs/openapi.json | 35 +++++++---------------------------- router/src/lib.rs | 15 ++++++++------- 2 files changed, 15 insertions(+), 35 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 280754bcaec..756afe1ef4a 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2245,18 +2245,12 @@ "ToolType": { "oneOf": [ { - "type": "string", - "description": "Means the model can pick between generating a message or calling one or more tools.", - "enum": [ - "auto" - ] + "type": "object", + "default": null, + "nullable": true }, { - "type": "string", - "description": "Means the model will not call any tool and instead generates a message.", - "enum": [ - "none" - ] + "type": "string" }, { "type": "object", @@ -2271,25 +2265,10 @@ }, { "type": "object", - "required": [ - "type", - "function" - ], - "properties": { - "type": { - "type": "string", - "enum": [ - "function" - ] - }, - "function": { - "$ref": "#/components/schemas/FunctionName" - } - } + "default": null, + "nullable": true } - ], - "description": "Controls which (if any) tool is called by the model.", - "example": "auto" + ] }, "Url": { "type": "object", diff --git a/router/src/lib.rs b/router/src/lib.rs index c4189904d4d..6d342903c95 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1015,12 +1015,11 @@ pub enum ToolType { Function(FunctionName), } - #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] #[serde(tag = "type")] pub enum TypedChoice { #[serde(rename = "function")] - Function{function: FunctionName}, + Function { function: FunctionName }, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] @@ -1038,10 +1037,9 @@ enum ToolTypeDeserializer { Null, String(String), ToolType(ToolType), - TypedChoice(TypedChoice) //this is the OpenAI schema + TypedChoice(TypedChoice), //this is the OpenAI schema } - impl From for ToolChoice { fn from(value: ToolTypeDeserializer) -> Self { match value { @@ -1051,7 +1049,9 @@ impl From for ToolChoice { "auto" => ToolChoice(Some(ToolType::OneOf)), _ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), }, - ToolTypeDeserializer::TypedChoice(TypedChoice::Function{function}) => ToolChoice(Some(ToolType::Function(function))), + ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => { + ToolChoice(Some(ToolType::Function(function))) + } ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), } } @@ -1661,7 +1661,6 @@ mod tests { #[test] fn tool_choice_formats() { - #[derive(Deserialize)] struct TestRequest { tool_choice: ToolChoice, @@ -1675,7 +1674,9 @@ mod tests { let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); assert_eq!(de_auto.tool_choice, ToolChoice(Some(ToolType::OneOf))); - let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName { name: "myfn".to_string() }))); + let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName { + name: "myfn".to_string(), + }))); let named = r#"{"tool_choice":"myfn"}"#; let de_named: TestRequest = serde_json::from_str(named).unwrap(); From 209f841767d0d39c503b27dfa3028234e89d07ef Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 14 Oct 2024 16:44:54 +0000 Subject: [PATCH 04/15] fix: consolidate changes and remove old tool type --- docs/openapi.json | 23 +++++++++++++---------- router/src/lib.rs | 10 ++-------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 756afe1ef4a..903f742629f 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2245,12 +2245,18 @@ "ToolType": { "oneOf": [ { - "type": "object", - "default": null, - "nullable": true + "type": "string", + "description": "Means the model can pick between generating a message or calling one or more tools.", + "enum": [ + "auto" + ] }, { - "type": "string" + "type": "string", + "description": "Means the model will not call any tool and instead generates a message.", + "enum": [ + "none" + ] }, { "type": "object", @@ -2262,13 +2268,10 @@ "$ref": "#/components/schemas/FunctionName" } } - }, - { - "type": "object", - "default": null, - "nullable": true } - ] + ], + "description": "Controls which (if any) tool is called by the model.", + "example": "auto" }, "Url": { "type": "object", diff --git a/router/src/lib.rs b/router/src/lib.rs index 6d342903c95..6ecc6b39b35 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1036,8 +1036,7 @@ pub struct ToolChoice(pub Option); enum ToolTypeDeserializer { Null, String(String), - ToolType(ToolType), - TypedChoice(TypedChoice), //this is the OpenAI schema + ToolType(TypedChoice), } impl From for ToolChoice { @@ -1049,10 +1048,9 @@ impl From for ToolChoice { "auto" => ToolChoice(Some(ToolType::OneOf)), _ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), }, - ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => { + ToolTypeDeserializer::ToolType(TypedChoice::Function { function }) => { ToolChoice(Some(ToolType::Function(function))) } - ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), } } } @@ -1682,10 +1680,6 @@ mod tests { let de_named: TestRequest = serde_json::from_str(named).unwrap(); assert_eq!(de_named.tool_choice, ref_choice); - let old_named = r#"{"tool_choice":{"function":{"name":"myfn"}}}"#; - let de_old_named: TestRequest = serde_json::from_str(old_named).unwrap(); - assert_eq!(de_old_named.tool_choice, ref_choice); - let openai_named = r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#; let de_openai_named: TestRequest = serde_json::from_str(openai_named).unwrap(); From f53c8059e9ea14059279334ba87a51221bfa3fff Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 14 Oct 2024 17:40:45 +0000 Subject: [PATCH 05/15] feat: improve, simplify and rename tool choice struct add required support and refactor --- docs/openapi.json | 78 ++++++++++++----------- router/src/infer/tool_grammar.rs | 57 +++++++++-------- router/src/lib.rs | 103 ++++++++++++++++++++----------- router/src/server.rs | 9 ++- 4 files changed, 139 insertions(+), 108 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 903f742629f..06c1f144156 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -921,6 +921,42 @@ } } }, + "ChatCompletionToolChoiceOption": { + "oneOf": [ + { + "type": "string", + "description": "Means the model can pick between generating a message or calling one or more tools.", + "enum": [ + "auto" + ] + }, + { + "type": "string", + "description": "Means the model will not call any tool and instead generates a message.", + "enum": [ + "none" + ] + }, + { + "type": "string", + "description": "Means the model must call one or more tools.", + "enum": [ + "required" + ] + }, + { + "type": "object", + "required": [ + "function" + ], + "properties": { + "function": { + "$ref": "#/components/schemas/FunctionName" + } + } + } + ] + }, "ChatCompletionTopLogprob": { "type": "object", "required": [ @@ -1055,9 +1091,10 @@ "tool_choice": { "allOf": [ { - "$ref": "#/components/schemas/ToolChoice" + "$ref": "#/components/schemas/ChatCompletionToolChoiceOption" } ], + "default": "null", "nullable": true }, "tool_prompt": { @@ -2234,45 +2271,6 @@ } } }, - "ToolChoice": { - "allOf": [ - { - "$ref": "#/components/schemas/ToolType" - } - ], - "nullable": true - }, - "ToolType": { - "oneOf": [ - { - "type": "string", - "description": "Means the model can pick between generating a message or calling one or more tools.", - "enum": [ - "auto" - ] - }, - { - "type": "string", - "description": "Means the model will not call any tool and instead generates a message.", - "enum": [ - "none" - ] - }, - { - "type": "object", - "required": [ - "function" - ], - "properties": { - "function": { - "$ref": "#/components/schemas/FunctionName" - } - } - } - ], - "description": "Controls which (if any) tool is called by the model.", - "example": "auto" - }, "Url": { "type": "object", "required": [ diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index f86205fb532..b9070812a76 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -1,7 +1,7 @@ use crate::infer::InferError; use crate::{ - FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, - ToolType, + ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, + Properties, Tool, }; use serde_json::{json, Map, Value}; use std::collections::HashMap; @@ -20,44 +20,47 @@ impl ToolGrammar { pub fn apply( tools: Vec, - tool_choice: ToolChoice, + tool_choice: ChatCompletionToolChoiceOption, ) -> Result<(Vec, Option), InferError> { // if no tools are provided, we return None if tools.is_empty() { return Ok((tools, None)); } - let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); - let mut tools = tools.clone(); - // add the no_tool function to the tools - let no_tool = Tool { - r#type: "function".to_string(), - function: FunctionDefinition { - name: "no_tool".to_string(), - description: Some("Open ened response with no specific tool selected".to_string()), - arguments: json!({ - "type": "object", - "properties": { - "content": { - "type": "string", - "description": "The response content", - } - }, - "required": ["content"] - }), - }, - }; - tools.push(no_tool); + // add the no_tool function to the tools as long as we are not required to use a specific tool + if tool_choice != ChatCompletionToolChoiceOption::Required { + let no_tool = Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "no_tool".to_string(), + description: Some( + "Open ended response with no specific tool selected".to_string(), + ), + arguments: json!({ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The response content", + } + }, + "required": ["content"] + }), + }, + }; + tools.push(no_tool); + } // if tools are provided and no tool_choice we default to the OneOf let tools_to_use = match tool_choice { - ToolType::Function(function) => { + ChatCompletionToolChoiceOption::Function(function) => { vec![Self::find_tool_by_name(&tools, &function.name)?] } - ToolType::OneOf => tools.clone(), - ToolType::NoTool => return Ok((tools, None)), + ChatCompletionToolChoiceOption::Required => tools.clone(), + ChatCompletionToolChoiceOption::Auto => tools.clone(), + ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)), }; let functions: HashMap = tools_to_use diff --git a/router/src/lib.rs b/router/src/lib.rs index 6ecc6b39b35..15d202a891d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -892,8 +892,8 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] - #[schema(nullable = true, example = "null")] - pub tool_choice: ToolChoice, + #[schema(nullable = true, default = "null", example = "null")] + pub tool_choice: Option, /// Response format constraints for the generation. /// @@ -949,8 +949,18 @@ impl ChatRequest { let (inputs, grammar, using_tools) = prepare_chat_input( infer, response_format, - tools, - tool_choice, + tools.clone(), + // unwrap or default (use "auto" if tools are present, and "none" if not) + tool_choice.map_or_else( + || { + if tools.is_some() { + ChatCompletionToolChoiceOption::Auto + } else { + ChatCompletionToolChoiceOption::NoTool + } + }, + |t| t, + ), &tool_prompt, guideline, messages, @@ -999,22 +1009,6 @@ pub fn default_tool_prompt() -> String { "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string() } -#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] -#[schema(example = "auto")] -/// Controls which (if any) tool is called by the model. -pub enum ToolType { - /// Means the model can pick between generating a message or calling one or more tools. - #[schema(rename = "auto")] - OneOf, - /// Means the model will not call any tool and instead generates a message. - #[schema(rename = "none")] - NoTool, - /// Forces the model to call a specific tool. - #[schema(rename = "function")] - #[serde(alias = "function")] - Function(FunctionName), -} - #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] #[serde(tag = "type")] pub enum TypedChoice { @@ -1027,29 +1021,59 @@ pub struct FunctionName { pub name: String, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)] #[serde(from = "ToolTypeDeserializer")] -pub struct ToolChoice(pub Option); +pub enum ChatCompletionToolChoiceOption { + /// Means the model can pick between generating a message or calling one or more tools. + #[schema(rename = "auto")] + Auto, + /// Means the model will not call any tool and instead generates a message. + #[schema(rename = "none")] + #[default] + NoTool, + /// Means the model must call one or more tools. + #[schema(rename = "required")] + Required, + /// Forces the model to call a specific tool. + #[schema(rename = "function")] + #[serde(alias = "function")] + Function(FunctionName), +} -#[derive(Deserialize)] +#[derive(Deserialize, ToSchema)] #[serde(untagged)] +/// Controls which (if any) tool is called by the model. +/// - `none` means the model will not call any tool and instead generates a message. +/// - `auto` means the model can pick between generating a message or calling one or more tools. +/// - `required` means the model must call one or more tools. +/// - Specifying a particular tool via `{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}` forces the model to call that tool. +/// +/// `none` is the default when no tools are present. `auto` is the default if tools are present." enum ToolTypeDeserializer { + /// `none` means the model will not call any tool and instead generates a message. Null, + + /// `auto` means the model can pick between generating a message or calling one or more tools. + #[schema(example = "auto")] String(String), - ToolType(TypedChoice), + + /// Specifying a particular tool forces the model to call that tool, with structured function details. + #[schema(example = r#"{"type": "function", "function": {"name": "my_function"}}"#)] + TypedChoice(TypedChoice), } -impl From for ToolChoice { +impl From for ChatCompletionToolChoiceOption { fn from(value: ToolTypeDeserializer) -> Self { match value { - ToolTypeDeserializer::Null => ToolChoice(None), + ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool, ToolTypeDeserializer::String(s) => match s.as_str() { - "none" => ToolChoice(Some(ToolType::NoTool)), - "auto" => ToolChoice(Some(ToolType::OneOf)), - _ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), + "none" => ChatCompletionToolChoiceOption::NoTool, + "auto" => ChatCompletionToolChoiceOption::Auto, + "required" => ChatCompletionToolChoiceOption::Required, + _ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }), }, - ToolTypeDeserializer::ToolType(TypedChoice::Function { function }) => { - ToolChoice(Some(ToolType::Function(function))) + ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => { + ChatCompletionToolChoiceOption::Function(function) } } } @@ -1661,20 +1685,27 @@ mod tests { fn tool_choice_formats() { #[derive(Deserialize)] struct TestRequest { - tool_choice: ToolChoice, + tool_choice: ChatCompletionToolChoiceOption, } let none = r#"{"tool_choice":"none"}"#; let de_none: TestRequest = serde_json::from_str(none).unwrap(); - assert_eq!(de_none.tool_choice, ToolChoice(Some(ToolType::NoTool))); + assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool); let auto = r#"{"tool_choice":"auto"}"#; let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); - assert_eq!(de_auto.tool_choice, ToolChoice(Some(ToolType::OneOf))); + assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto); - let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName { + let auto = r#"{"tool_choice":"required"}"#; + let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); + assert_eq!( + de_auto.tool_choice, + ChatCompletionToolChoiceOption::Required + ); + + let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName { name: "myfn".to_string(), - }))); + }); let named = r#"{"tool_choice":"myfn"}"#; let de_named: TestRequest = serde_json::from_str(named).unwrap(); diff --git a/router/src/server.rs b/router/src/server.rs index 863607b185c..26a43f0a829 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -28,7 +28,7 @@ use crate::{ ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall}; use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::Extension; @@ -1551,12 +1551,11 @@ GrammarType, Usage, StreamOptions, DeltaToolCall, -ToolType, Tool, ToolCall, Function, FunctionDefinition, -ToolChoice, +ChatCompletionToolChoiceOption, ModelInfo, ) ), @@ -2522,7 +2521,7 @@ pub(crate) fn prepare_chat_input( infer: &Infer, response_format: Option, tools: Option>, - tool_choice: ToolChoice, + tool_choice: ChatCompletionToolChoiceOption, tool_prompt: &str, guideline: Option, messages: Vec, @@ -2660,7 +2659,7 @@ mod tests { &infer, response_format, tools, - ToolChoice(None), + ChatCompletionToolChoiceOption::Auto, tool_prompt, guideline, messages, From b2db1075e4381c2d97f9146823d2417a41b87c2f Mon Sep 17 00:00:00 2001 From: David Holtz Date: Tue, 15 Oct 2024 14:01:02 +0000 Subject: [PATCH 06/15] fix: simplify tool choice logic, improve tests, openapi and rust docs --- docs/openapi.json | 3 +- ..._sea_creatures_stream_function_object.json | 27 +++++ ...r_tools_sea_creatures_stream_required.json | 28 +++++ integration-tests/models/test_tools_llama.py | 103 +++++++++++++++++- router/src/infer/tool_grammar.rs | 58 +++++----- router/src/lib.rs | 65 +++++------ router/src/server.rs | 2 +- 7 files changed, 222 insertions(+), 64 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json diff --git a/docs/openapi.json b/docs/openapi.json index 06c1f144156..167bb3fb270 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -955,7 +955,8 @@ } } } - ] + ], + "description": "" }, "ChatCompletionTopLogprob": { "type": "object", diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json new file mode 100644 index 00000000000..cf3f1fcc926 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json @@ -0,0 +1,27 @@ +{ + "choices": [ + { + "delta": { + "role": "assistant", + "tool_calls": { + "function": { + "arguments": "<|eot_id|>", + "name": null + }, + "id": "", + "index": 0, + "type": "function" + } + }, + "finish_reason": "stop", + "index": 0, + "logprobs": null + } + ], + "created": 1729000499, + "id": "", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.3.2-dev0-native", + "usage": null +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json new file mode 100644 index 00000000000..fea26690b5b --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json @@ -0,0 +1,28 @@ +{ + "choices": [ + { + "delta": { + "content": null, + "role": "assistant", + "tool_calls": { + "function": { + "arguments": "<|eot_id|>", + "name": null + }, + "id": "", + "index": 0, + "type": "function" + } + }, + "finish_reason": "stop", + "index": 0, + "logprobs": null + } + ], + "created": 1728998230, + "id": "", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.3.2-dev0-native", + "usage": null +} diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 98e75bb4942..ce9eb4eb164 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -1,4 +1,6 @@ import pytest +import requests +import json @pytest.fixture(scope="module") @@ -174,7 +176,7 @@ async def test_flash_llama_grammar_tools_choice( "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, New York"}, }, } ] @@ -327,3 +329,102 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream( == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans" ) assert last_response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_sea_creatures_stream_required( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=24, + tools=tools, + tool_choice="required", + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.", + }, + { + "role": "user", + "content": "Tell me a story about 3 sea creatures", + }, + ], + stream=True, + ) + + count = 0 + tool_calls_generated = "" + last_response = None + async for response in responses: + count += 1 + assert response.choices[0].delta.content is None + tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments + last_response = response + + assert count == 29 + assert ( + tool_calls_generated + == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>' + ) + assert last_response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object( + flash_llama_grammar_tools, response_snapshot +): + # using `requests` to send the request until the client library supports tool_choice as a function object + responses = requests.post( + f"{flash_llama_grammar_tools.base_url}/v1/chat/completions", + headers=flash_llama_grammar_tools.headers, + json={ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.", + }, + { + "role": "user", + "content": "Tell me a story about 3 sea creatures", + }, + ], + "tools": tools, + "tool_choice": { + "type": "function", + "function": {"name": "get_current_weather"}, + }, + "seed": 24, + "max_tokens": 100, + "stream": True, + }, + stream=True, + ) + # iterate over the response in chunks + count = 0 + tool_calls_generated = "" + last_response = None + for chunk in responses.iter_content(chunk_size=1024): + if chunk: + count += 1 + # remove the "data: " prefix, trailing newline, and split the chunk into individual lines + lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n") + for line in lines: + if line == "[DONE]": + break + response = json.loads(line) + tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][ + "function" + ]["arguments"] + last_response = response + + assert count == 30 + print(tool_calls_generated) + assert ( + tool_calls_generated + == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Tokyo, JP"}}<|eot_id|>' + ) + assert last_response == response_snapshot diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index b9070812a76..772485590cc 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -27,39 +27,37 @@ impl ToolGrammar { return Ok((tools, None)); } - let mut tools = tools.clone(); - - // add the no_tool function to the tools as long as we are not required to use a specific tool - if tool_choice != ChatCompletionToolChoiceOption::Required { - let no_tool = Tool { - r#type: "function".to_string(), - function: FunctionDefinition { - name: "no_tool".to_string(), - description: Some( - "Open ended response with no specific tool selected".to_string(), - ), - arguments: json!({ - "type": "object", - "properties": { - "content": { - "type": "string", - "description": "The response content", - } - }, - "required": ["content"] - }), - }, - }; - tools.push(no_tool); - } - - // if tools are provided and no tool_choice we default to the OneOf let tools_to_use = match tool_choice { ChatCompletionToolChoiceOption::Function(function) => { vec![Self::find_tool_by_name(&tools, &function.name)?] } - ChatCompletionToolChoiceOption::Required => tools.clone(), - ChatCompletionToolChoiceOption::Auto => tools.clone(), + ChatCompletionToolChoiceOption::Required => tools, + ChatCompletionToolChoiceOption::Auto => { + // only add the no_tool function if the user has selected the auto option + tools + .iter() + .cloned() + .chain(std::iter::once(Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "no_tool".to_string(), + description: Some( + "Open ended response with no specific tool selected".to_string(), + ), + arguments: json!({ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The response content", + } + }, + "required": ["content"] + }), + }, + })) + .collect::>() + } ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)), }; @@ -121,6 +119,6 @@ impl ToolGrammar { }, }; - Ok((tools, Some(tool_schema))) + Ok((tools_to_use, Some(tool_schema))) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 15d202a891d..df70f827f66 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -946,21 +946,19 @@ impl ChatRequest { Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; + // unwrap or default (use "auto" if tools are present, and "none" if not) + let choice = tool_choice.unwrap_or_else(|| { + if tools.is_some() { + ChatCompletionToolChoiceOption::Auto + } else { + ChatCompletionToolChoiceOption::NoTool + } + }); let (inputs, grammar, using_tools) = prepare_chat_input( infer, response_format, - tools.clone(), - // unwrap or default (use "auto" if tools are present, and "none" if not) - tool_choice.map_or_else( - || { - if tools.is_some() { - ChatCompletionToolChoiceOption::Auto - } else { - ChatCompletionToolChoiceOption::NoTool - } - }, - |t| t, - ), + tools, + choice, &tool_prompt, guideline, messages, @@ -1023,6 +1021,7 @@ pub struct FunctionName { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)] #[serde(from = "ToolTypeDeserializer")] +/// pub enum ChatCompletionToolChoiceOption { /// Means the model can pick between generating a message or calling one or more tools. #[schema(rename = "auto")] @@ -1034,7 +1033,7 @@ pub enum ChatCompletionToolChoiceOption { /// Means the model must call one or more tools. #[schema(rename = "required")] Required, - /// Forces the model to call a specific tool. + /// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool. #[schema(rename = "function")] #[serde(alias = "function")] Function(FunctionName), @@ -1688,32 +1687,36 @@ mod tests { tool_choice: ChatCompletionToolChoiceOption, } - let none = r#"{"tool_choice":"none"}"#; - let de_none: TestRequest = serde_json::from_str(none).unwrap(); + let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap(); assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool); - let auto = r#"{"tool_choice":"auto"}"#; - let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); + let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap(); assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto); - let auto = r#"{"tool_choice":"required"}"#; - let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); + let de_required: TestRequest = + serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap(); assert_eq!( - de_auto.tool_choice, + de_required.tool_choice, ChatCompletionToolChoiceOption::Required ); - let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName { - name: "myfn".to_string(), - }); - - let named = r#"{"tool_choice":"myfn"}"#; - let de_named: TestRequest = serde_json::from_str(named).unwrap(); - assert_eq!(de_named.tool_choice, ref_choice); - - let openai_named = r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#; - let de_openai_named: TestRequest = serde_json::from_str(openai_named).unwrap(); + let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap(); + assert_eq!( + de_named.tool_choice, + ChatCompletionToolChoiceOption::Function(FunctionName { + name: "myfn".to_string(), + }) + ); - assert_eq!(de_openai_named.tool_choice, ref_choice); + let de_openai_named: TestRequest = serde_json::from_str( + r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#, + ) + .unwrap(); + assert_eq!( + de_openai_named.tool_choice, + ChatCompletionToolChoiceOption::Function(FunctionName { + name: "myfn".to_string(), + }) + ); } } diff --git a/router/src/server.rs b/router/src/server.rs index 26a43f0a829..c46351e5032 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -2668,6 +2668,6 @@ mod tests { assert!(result.is_ok()); let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input"); assert_eq!(using_tools, true); - assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); + assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); } } From daa1c6280a122249cccc7baac3b7b1bbeb62ebcf Mon Sep 17 00:00:00 2001 From: David Holtz Date: Tue, 15 Oct 2024 15:00:24 +0000 Subject: [PATCH 07/15] fix: refactor away prepare_chat_input and improve tool grammar apply control flow --- router/src/infer/tool_grammar.rs | 28 +++--- router/src/lib.rs | 152 ++++++++++++++++++++++++++--- router/src/server.rs | 160 +------------------------------ 3 files changed, 158 insertions(+), 182 deletions(-) diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 772485590cc..982535535aa 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -58,7 +58,7 @@ impl ToolGrammar { })) .collect::>() } - ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)), + ChatCompletionToolChoiceOption::NoTool => Vec::with_capacity(0), }; let functions: HashMap = tools_to_use @@ -107,18 +107,22 @@ impl ToolGrammar { }) .collect(); - let tool_schema = JsonSchemaTool { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .collect(), - }, + let tool_schema = if tools_to_use.is_empty() { + None + } else { + Some(JsonSchemaTool { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .collect(), + }, + }) }; - Ok((tools_to_use, Some(tool_schema))) + Ok((tools_to_use, tool_schema)) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index df70f827f66..f76e440b6e9 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -12,8 +12,8 @@ mod sagemaker; pub mod usage_stats; mod vertex; +use crate::infer::tool_grammar::ToolGrammar; use crate::infer::{Infer, InferError}; -use crate::server::prepare_chat_input; use pyo3::prelude::*; use pyo3::types::IntoPyDict; use serde::{Deserialize, Serialize}; @@ -947,22 +947,46 @@ impl ChatRequest { other => (true, other), }; // unwrap or default (use "auto" if tools are present, and "none" if not) - let choice = tool_choice.unwrap_or_else(|| { + let tool_choice = tool_choice.unwrap_or_else(|| { if tools.is_some() { ChatCompletionToolChoiceOption::Auto } else { ChatCompletionToolChoiceOption::NoTool } }); - let (inputs, grammar, using_tools) = prepare_chat_input( - infer, - response_format, - tools, - choice, - &tool_prompt, - guideline, - messages, - )?; + + if response_format.is_some() && tools.is_some() { + return Err(InferError::ToolError( + "Grammar and tools are mutually exclusive".into(), + )); + } + + let (inputs, grammar, using_tools) = match response_format { + Some(format) => { + let inputs = infer.apply_chat_template(guideline, messages, None)?; + (inputs, Some(format), false) + } + None => { + if let Some(tools) = tools { + let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?; + + let grammar = tool_schema + .as_ref() + .map(|t| GrammarType::Json(serde_json::json!(t))); + + let inputs: String = infer.apply_chat_template( + guideline, + messages, + Some((updated_tools, tool_prompt)), + )?; + (inputs, grammar, tool_schema.is_some()) + } else { + // if no response_format or tools are set simply apply the chat template to generate inputs + let inputs = infer.apply_chat_template(guideline, messages, None)?; + (inputs, None, false) + } + } + }; Ok(( GenerateRequest { @@ -1239,6 +1263,7 @@ pub(crate) enum OutputMessage { } #[derive(Clone, Debug, Deserialize, ToSchema)] +#[cfg_attr(test, derive(PartialEq))] pub(crate) struct GenerateRequest { #[schema(example = "My name is Olivier and I")] pub inputs: String, @@ -1719,4 +1744,109 @@ mod tests { }) ); } + + #[test] + fn test_try_into_generate_with_tools_and_template() { + use crate::infer::Backend; + use crate::infer::InferStreamResponse; + use crate::validation::ValidGenerateRequest; + use crate::ChatTemplateVersions; + use crate::HubTokenizerConfig; + use crate::TokenizerConfigToken; + use crate::Tool; + + use core::future::Future; + use std::pin::Pin; + use tokio_stream::wrappers::UnboundedReceiverStream; + + // Mock Backend to avoid network requests. This is never used since we only test the conversion. It is mocked to satisfy the Backend trait. + struct MockBackend; + impl Backend for MockBackend { + fn schedule( + &self, + _: ValidGenerateRequest, + ) -> Result>, InferError> + { + unimplemented!("Never called in this test"); + } + fn health<'a, 't>(&'a self, _: bool) -> Pin + Send + 't>> + where + 'a: 't, + Self: 't, + { + unimplemented!("Never called in this test"); + } + } + + let mut tokenizer_config = HubTokenizerConfig::default(); + tokenizer_config.bos_token = Some(TokenizerConfigToken::String("".to_string())); + tokenizer_config.eos_token = Some(TokenizerConfigToken::String("".to_string())); + tokenizer_config.chat_template = Some( + ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) + ); + + let infer = Infer::new( + MockBackend {}, // never used; just to satisfy Infer::new signature + Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false), + 1, + tokenizer_config, + HubProcessorConfig::default(), + ); + + let tools = Some(vec![Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "get_current_weather".to_string(), + description: Some("Get the current weather".to_string()), + arguments: json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location." + } + }, + "required": ["location", "format"] + }), + }, + }]); + + let request = ChatRequest { + model: None, + max_tokens: None, + logit_bias: None, + logprobs: None, + n: None, + messages: vec![Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "What is the weather like in New York?".to_string(), + ), + }], + seed: None, + stop: None, + stream: false, + tools: tools, + tool_choice: None, + tool_prompt: Some("Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.".to_string()), + temperature: None, + response_format: None, + guideline: None, + presence_penalty: None, + frequency_penalty: None, + top_p: None, + top_logprobs: None, + stream_options: None, + }; + + let (generate, using_tools) = request.try_into_generate(&infer).unwrap(); + assert_eq!(using_tools, true); + assert_eq!(generate.inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); + } } diff --git a/router/src/server.rs b/router/src/server.rs index c46351e5032..0f970391d05 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,6 +1,5 @@ /// HTTP Server logic use crate::config::Config; -use crate::infer::tool_grammar::ToolGrammar; use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; #[cfg(feature = "kserve")] use crate::kserve::{ @@ -2513,161 +2512,4 @@ impl From for Event { pub enum WebServerError { #[error("Axum error: {0}")] Axum(#[from] axum::BoxError), -} - -type PreparedInput = (String, Option, bool); - -pub(crate) fn prepare_chat_input( - infer: &Infer, - response_format: Option, - tools: Option>, - tool_choice: ChatCompletionToolChoiceOption, - tool_prompt: &str, - guideline: Option, - messages: Vec, -) -> Result { - if response_format.is_some() && tools.is_some() { - return Err(InferError::ToolError( - "Grammar and tools are mutually exclusive".into(), - )); - } - - // when response_format is set, tools are not included when applying the chat template to generate inputs - if let Some(format) = response_format { - let inputs = infer.apply_chat_template(guideline, messages, None)?; - return Ok((inputs, Some(format), false)); - } - - // when no response_format is set and tools are included, apply the chat template with the tools - // to generate inputs - if let Some(tools) = tools { - let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?; - - let grammar = tool_schema - .as_ref() - .map(|t| GrammarType::Json(serde_json::json!(t))); - - let inputs: String = infer.apply_chat_template( - guideline, - messages, - Some((updated_tools, tool_prompt.into())), - )?; - return Ok((inputs, grammar, tool_schema.is_some())); - } - - // if no response_format or tools are set simply apply the chat template to generate inputs - let inputs = infer.apply_chat_template(guideline, messages, None)?; - Ok((inputs, None, false)) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::ChatTemplateVersions; - use crate::HubTokenizerConfig; - use crate::TokenizerConfigToken; - use crate::Tool; - - use crate::tests::get_tokenizer; - use serde_json::json; - - #[tokio::test] - async fn test_prepare_chat_input() { - // Mock Backend to avoid network requests - struct MockBackend; - - impl Backend for MockBackend { - fn schedule( - &self, - _request: crate::validation::ValidGenerateRequest, - ) -> Result< - tokio_stream::wrappers::UnboundedReceiverStream< - Result, - >, - InferError, - > { - unimplemented!("Never called in this test"); - } - fn health<'a, 'async_trait>( - &'a self, - _current_health: bool, - ) -> core::pin::Pin< - Box + core::marker::Send + 'async_trait>, - > - where - 'a: 'async_trait, - Self: 'async_trait, - { - unimplemented!("Never called in this test"); - } - } - - let backend = MockBackend {}; - - let mut tokenizer_config = HubTokenizerConfig::default(); - - // mock tokenizer config values - tokenizer_config.bos_token = Some(TokenizerConfigToken::String("".to_string())); - tokenizer_config.eos_token = Some(TokenizerConfigToken::String("".to_string())); - tokenizer_config.chat_template = Some( - ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) - ); - - let tokenizer = get_tokenizer(); - - let infer = Infer::new( - backend, - Validation::new(1, tokenizer, None, None, 1, 1, 1, 1, 1, false), - 1, - tokenizer_config, - HubProcessorConfig::default(), - ); - let response_format = None; - let tools = Some(vec![Tool { - r#type: "function".to_string(), - function: FunctionDefinition { - name: "get_current_weather".to_string(), - description: Some("Get the current weather".to_string()), - arguments: json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "format": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The temperature unit to use. Infer this from the users location." - } - }, - "required": ["location", "format"] - }), - }, - }]); - let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."; - let guideline = None; - let messages = vec![Message { - name: None, - role: "user".to_string(), - content: MessageContent::SingleText( - "What is the weather like in New York?".to_string(), - ), - }]; - - let result = prepare_chat_input( - &infer, - response_format, - tools, - ChatCompletionToolChoiceOption::Auto, - tool_prompt, - guideline, - messages, - ); - - assert!(result.is_ok()); - let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input"); - assert_eq!(using_tools, true); - assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); - } -} +} \ No newline at end of file From dd759e79145788f34900bbc9bf60dd15c3000745 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Tue, 15 Oct 2024 17:13:16 +0000 Subject: [PATCH 08/15] feat: update docs and add tool choice configuration section --- docs/source/basic_tutorials/using_guidance.md | 60 ++++++++++++++++++- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/docs/source/basic_tutorials/using_guidance.md b/docs/source/basic_tutorials/using_guidance.md index dfa3f0e49b1..2d55c9528c1 100644 --- a/docs/source/basic_tutorials/using_guidance.md +++ b/docs/source/basic_tutorials/using_guidance.md @@ -315,8 +315,6 @@ print(chat.choices[0].message.tool_calls) TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. -However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary. - ```python from openai import OpenAI @@ -362,3 +360,61 @@ print(called) # }, # } ``` + +### Tool Choice Configuration + +When configuring how the model interacts with tools during a chat completion, there are several options for determining if or how a tool should be called. These options are controlled by the `tool_choice` parameter, which specifies the behavior of the model in relation to tool usage. The following modes are supported: + +1. **`auto`**: + + - The model decides whether to call a tool or generate a response message based on the user's input. + - If tools are provided, this is the default mode. + - Example usage: + ```python + tool_choice="auto" + ``` + +2. **`none`**: + + - The model will never call any tools and will only generate a response message. + - If no tools are provided, this is the default mode. + - Example usage: + ```python + tool_choice="none" + ``` + +3. **`required`**: + + - The model must call one or more tools and will not generate a response message on its own. + - Example usage: + ```python + tool_choice="required" + ``` + +4. **Specific Tool Call by Function Name**: + - You can force the model to call a specific tool either by specifying the tool function directly or by using an object definition. + - Two ways to do this: + 1. Provide the function name as a string: + ```python + tool_choice="get_current_weather" + ``` + 2. Use the function object format: + ```python + tool_choice={ + "type": "function", + "function": { + "name": "get_current_weather" + } + } + ``` + +These options allow flexibility when integrating tools with the chat completions endpoint. You can configure the model to either rely on tools automatically or force it to follow a predefined behavior, based on the needs of the task at hand. + +--- + +| **Tool Choice Option** | **Description** | **When to Use** | +| ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------- | +| `auto` | The model decides whether to call a tool or generate a message. This is the default if tools are provided. | Use when you want the model to decide when a tool is necessary. | +| `none` | The model generates a message without calling any tools. This is the default if no tools are provided. | Use when you do not want the model to call any tools. | +| `required` | The model must call one or more tools and will not generate a message on its own. | Use when a tool call is mandatory, and you do not want a regular message generated. | +| Specific Tool Call (`name` or object) | Force the model to call a specific tool either by specifying its name (`tool_choice="get_current_weather"`) or using an object. | Use when you want to restrict the model to calling a particular tool for the response. | From b5bf5b32ad48c08cef5eea1f515fab303d597fe7 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Wed, 16 Oct 2024 13:49:46 +0000 Subject: [PATCH 09/15] fix: simplify naming, tool choice default and improve test --- docs/openapi.json | 78 +++++++++---------- ..._sea_creatures_stream_function_object.json | 2 +- ...r_tools_sea_creatures_stream_required.json | 2 +- integration-tests/models/test_tools_llama.py | 7 +- router/src/infer/tool_grammar.rs | 17 ++-- router/src/lib.rs | 52 +++++-------- router/src/server.rs | 4 +- 7 files changed, 74 insertions(+), 88 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 167bb3fb270..ba53f7ee1f9 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -921,43 +921,6 @@ } } }, - "ChatCompletionToolChoiceOption": { - "oneOf": [ - { - "type": "string", - "description": "Means the model can pick between generating a message or calling one or more tools.", - "enum": [ - "auto" - ] - }, - { - "type": "string", - "description": "Means the model will not call any tool and instead generates a message.", - "enum": [ - "none" - ] - }, - { - "type": "string", - "description": "Means the model must call one or more tools.", - "enum": [ - "required" - ] - }, - { - "type": "object", - "required": [ - "function" - ], - "properties": { - "function": { - "$ref": "#/components/schemas/FunctionName" - } - } - } - ], - "description": "" - }, "ChatCompletionTopLogprob": { "type": "object", "required": [ @@ -1052,7 +1015,7 @@ "$ref": "#/components/schemas/GrammarType" } ], - "default": "null", + "default": "auto", "nullable": true }, "seed": { @@ -1092,7 +1055,7 @@ "tool_choice": { "allOf": [ { - "$ref": "#/components/schemas/ChatCompletionToolChoiceOption" + "$ref": "#/components/schemas/ToolChoice" } ], "default": "null", @@ -2272,6 +2235,43 @@ } } }, + "ToolChoice": { + "oneOf": [ + { + "type": "string", + "description": "Means the model can pick between generating a message or calling one or more tools.", + "enum": [ + "auto" + ] + }, + { + "type": "string", + "description": "Means the model will not call any tool and instead generates a message.", + "enum": [ + "none" + ] + }, + { + "type": "string", + "description": "Means the model must call one or more tools.", + "enum": [ + "required" + ] + }, + { + "type": "object", + "required": [ + "function" + ], + "properties": { + "function": { + "$ref": "#/components/schemas/FunctionName" + } + } + } + ], + "description": "" + }, "Url": { "type": "object", "required": [ diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json index cf3f1fcc926..e64dd49d9df 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json @@ -18,7 +18,7 @@ "logprobs": null } ], - "created": 1729000499, + "created": 1729084854, "id": "", "model": "meta-llama/Llama-3.1-8B-Instruct", "object": "chat.completion.chunk", diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json index fea26690b5b..d8d538d6d6f 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json @@ -19,7 +19,7 @@ "logprobs": null } ], - "created": 1728998230, + "created": 1729084850, "id": "", "model": "meta-llama/Llama-3.1-8B-Instruct", "object": "chat.completion.chunk", diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index ce9eb4eb164..9fa993bd63a 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -395,7 +395,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object( "tools": tools, "tool_choice": { "type": "function", - "function": {"name": "get_current_weather"}, + "function": {"name": "get_n_day_weather_forecast"}, }, "seed": 24, "max_tokens": 100, @@ -421,10 +421,9 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object( ]["arguments"] last_response = response - assert count == 30 - print(tool_calls_generated) + assert count == 39 assert ( tool_calls_generated - == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Tokyo, JP"}}<|eot_id|>' + == '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>' ) assert last_response == response_snapshot diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 982535535aa..9c5ce2d83e1 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -1,7 +1,6 @@ use crate::infer::InferError; use crate::{ - ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, - Properties, Tool, + FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, }; use serde_json::{json, Map, Value}; use std::collections::HashMap; @@ -20,19 +19,19 @@ impl ToolGrammar { pub fn apply( tools: Vec, - tool_choice: ChatCompletionToolChoiceOption, + tool_choice: ToolChoice, ) -> Result<(Vec, Option), InferError> { - // if no tools are provided, we return None + // if no tools are provided, we return None and an empty vec if tools.is_empty() { - return Ok((tools, None)); + return Ok((Vec::with_capacity(0), None)); } let tools_to_use = match tool_choice { - ChatCompletionToolChoiceOption::Function(function) => { + ToolChoice::Function(function) => { vec![Self::find_tool_by_name(&tools, &function.name)?] } - ChatCompletionToolChoiceOption::Required => tools, - ChatCompletionToolChoiceOption::Auto => { + ToolChoice::Required => tools, + ToolChoice::Auto => { // only add the no_tool function if the user has selected the auto option tools .iter() @@ -58,7 +57,7 @@ impl ToolGrammar { })) .collect::>() } - ChatCompletionToolChoiceOption::NoTool => Vec::with_capacity(0), + ToolChoice::NoTool => Vec::with_capacity(0), }; let functions: HashMap = tools_to_use diff --git a/router/src/lib.rs b/router/src/lib.rs index f76e440b6e9..59b300dd4e1 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -893,13 +893,13 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] #[schema(nullable = true, default = "null", example = "null")] - pub tool_choice: Option, + pub tool_choice: Option, /// Response format constraints for the generation. /// /// NOTE: A request can use `response_format` OR `tools` but not both. #[serde(default)] - #[schema(nullable = true, default = "null", example = "null")] + #[schema(nullable = true, default = "auto", example = "auto")] pub response_format: Option, /// A guideline to be used in the chat_template @@ -946,14 +946,8 @@ impl ChatRequest { Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; - // unwrap or default (use "auto" if tools are present, and "none" if not) - let tool_choice = tool_choice.unwrap_or_else(|| { - if tools.is_some() { - ChatCompletionToolChoiceOption::Auto - } else { - ChatCompletionToolChoiceOption::NoTool - } - }); + // if no tool_choice is set, set default (Auto) + let tool_choice = tool_choice.unwrap_or_default(); if response_format.is_some() && tools.is_some() { return Err(InferError::ToolError( @@ -1045,21 +1039,18 @@ pub struct FunctionName { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)] #[serde(from = "ToolTypeDeserializer")] +#[serde(rename_all = "snake_case")] /// -pub enum ChatCompletionToolChoiceOption { +pub enum ToolChoice { /// Means the model can pick between generating a message or calling one or more tools. - #[schema(rename = "auto")] + #[default] Auto, /// Means the model will not call any tool and instead generates a message. #[schema(rename = "none")] - #[default] NoTool, /// Means the model must call one or more tools. - #[schema(rename = "required")] Required, /// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool. - #[schema(rename = "function")] - #[serde(alias = "function")] Function(FunctionName), } @@ -1085,18 +1076,18 @@ enum ToolTypeDeserializer { TypedChoice(TypedChoice), } -impl From for ChatCompletionToolChoiceOption { +impl From for ToolChoice { fn from(value: ToolTypeDeserializer) -> Self { match value { - ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool, + ToolTypeDeserializer::Null => ToolChoice::NoTool, ToolTypeDeserializer::String(s) => match s.as_str() { - "none" => ChatCompletionToolChoiceOption::NoTool, - "auto" => ChatCompletionToolChoiceOption::Auto, - "required" => ChatCompletionToolChoiceOption::Required, - _ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }), + "none" => ToolChoice::NoTool, + "auto" => ToolChoice::Auto, + "required" => ToolChoice::Required, + _ => ToolChoice::Function(FunctionName { name: s }), }, ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => { - ChatCompletionToolChoiceOption::Function(function) + ToolChoice::Function(function) } } } @@ -1709,26 +1700,23 @@ mod tests { fn tool_choice_formats() { #[derive(Deserialize)] struct TestRequest { - tool_choice: ChatCompletionToolChoiceOption, + tool_choice: ToolChoice, } let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap(); - assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool); + assert_eq!(de_none.tool_choice, ToolChoice::NoTool); let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap(); - assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto); + assert_eq!(de_auto.tool_choice, ToolChoice::Auto); let de_required: TestRequest = serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap(); - assert_eq!( - de_required.tool_choice, - ChatCompletionToolChoiceOption::Required - ); + assert_eq!(de_required.tool_choice, ToolChoice::Required); let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap(); assert_eq!( de_named.tool_choice, - ChatCompletionToolChoiceOption::Function(FunctionName { + ToolChoice::Function(FunctionName { name: "myfn".to_string(), }) ); @@ -1739,7 +1727,7 @@ mod tests { .unwrap(); assert_eq!( de_openai_named.tool_choice, - ChatCompletionToolChoiceOption::Function(FunctionName { + ToolChoice::Function(FunctionName { name: "myfn".to_string(), }) ); diff --git a/router/src/server.rs b/router/src/server.rs index 0f970391d05..911c77d8681 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -27,7 +27,7 @@ use crate::{ ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, }; -use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::Extension; @@ -1554,7 +1554,7 @@ Tool, ToolCall, Function, FunctionDefinition, -ChatCompletionToolChoiceOption, +ToolChoice, ModelInfo, ) ), From 407531708e492abc74ae35bd871437174631ae01 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Fri, 18 Oct 2024 15:36:40 +0000 Subject: [PATCH 10/15] fix: adjust tool choice none logic, add test and small refactors --- docs/openapi.json | 4 +- integration-tests/models/test_tools_llama.py | 41 ++++++++++++++++++++ router/src/infer/tool_grammar.rs | 40 +++++++++---------- router/src/lib.rs | 38 +++++++++--------- 4 files changed, 81 insertions(+), 42 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index ba53f7ee1f9..08aab6c5647 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1015,7 +1015,7 @@ "$ref": "#/components/schemas/GrammarType" } ], - "default": "auto", + "default": "null", "nullable": true }, "seed": { @@ -1058,7 +1058,7 @@ "$ref": "#/components/schemas/ToolChoice" } ], - "default": "null", + "default": "auto", "nullable": true }, "tool_prompt": { diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 9fa993bd63a..b5821945b58 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -371,6 +371,47 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required( assert last_response == response_snapshot +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_sea_creatures_stream_none( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=24, + tools=tools, + tool_choice="none", + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.", + }, + { + "role": "user", + "content": "Tell me a story about 3 sea creatures", + }, + ], + stream=True, + ) + + count = 0 + content_generated = "" + last_response = None + async for response in responses: + count += 1 + content_generated += response.choices[0].delta.content + last_response = response + assert response.choices[0].delta.tool_calls is None + + assert count == 100 + print(content_generated) + assert ( + content_generated + == "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep" + ) + assert last_response == response_snapshot + + @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object( diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 9c5ce2d83e1..7770cd9d708 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -20,12 +20,7 @@ impl ToolGrammar { pub fn apply( tools: Vec, tool_choice: ToolChoice, - ) -> Result<(Vec, Option), InferError> { - // if no tools are provided, we return None and an empty vec - if tools.is_empty() { - return Ok((Vec::with_capacity(0), None)); - } - + ) -> Result, JsonSchemaTool)>, InferError> { let tools_to_use = match tool_choice { ToolChoice::Function(function) => { vec![Self::find_tool_by_name(&tools, &function.name)?] @@ -57,9 +52,14 @@ impl ToolGrammar { })) .collect::>() } - ToolChoice::NoTool => Vec::with_capacity(0), + ToolChoice::NoTool => vec![], }; + // if no tools are provided or if the user has selected the no_tool option, return None + if tools_to_use.is_empty() { + return Ok(None); + } + let functions: HashMap = tools_to_use .iter() .map(|tool| { @@ -106,22 +106,18 @@ impl ToolGrammar { }) .collect(); - let tool_schema = if tools_to_use.is_empty() { - None - } else { - Some(JsonSchemaTool { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .collect(), - }, - }) + let tool_schema = JsonSchemaTool { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .collect(), + }, }; - Ok((tools_to_use, tool_schema)) + Ok(Some((tools_to_use, tool_schema))) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 59b300dd4e1..009b57cffc8 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -892,14 +892,14 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] - #[schema(nullable = true, default = "null", example = "null")] - pub tool_choice: Option, + #[schema(nullable = true, default = "auto", example = "auto")] + pub tool_choice: ToolChoice, /// Response format constraints for the generation. /// /// NOTE: A request can use `response_format` OR `tools` but not both. #[serde(default)] - #[schema(nullable = true, default = "auto", example = "auto")] + #[schema(nullable = true, default = "null", example = "null")] pub response_format: Option, /// A guideline to be used in the chat_template @@ -946,8 +946,6 @@ impl ChatRequest { Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; - // if no tool_choice is set, set default (Auto) - let tool_choice = tool_choice.unwrap_or_default(); if response_format.is_some() && tools.is_some() { return Err(InferError::ToolError( @@ -962,18 +960,22 @@ impl ChatRequest { } None => { if let Some(tools) = tools { - let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?; - - let grammar = tool_schema - .as_ref() - .map(|t| GrammarType::Json(serde_json::json!(t))); - - let inputs: String = infer.apply_chat_template( - guideline, - messages, - Some((updated_tools, tool_prompt)), - )?; - (inputs, grammar, tool_schema.is_some()) + match ToolGrammar::apply(tools, tool_choice)? { + Some((updated_tools, tool_schema)) => { + let grammar = GrammarType::Json(serde_json::json!(tool_schema)); + let inputs: String = infer.apply_chat_template( + guideline, + messages, + Some((updated_tools, tool_prompt)), + )?; + (inputs, Some(grammar), true) + } + None => { + // same as if no response_format or tools are set + let inputs = infer.apply_chat_template(guideline, messages, None)?; + (inputs, None, false) + } + } } else { // if no response_format or tools are set simply apply the chat template to generate inputs let inputs = infer.apply_chat_template(guideline, messages, None)?; @@ -1046,7 +1048,7 @@ pub enum ToolChoice { #[default] Auto, /// Means the model will not call any tool and instead generates a message. - #[schema(rename = "none")] + #[serde(rename = "none")] NoTool, /// Means the model must call one or more tools. Required, From 1ce1cf28624a87116fbcf8bfb9568becd3a1e09c Mon Sep 17 00:00:00 2001 From: David Holtz Date: Fri, 18 Oct 2024 15:39:05 +0000 Subject: [PATCH 11/15] fix: add missing snapshot file --- ...ammar_tools_sea_creatures_stream_none.json | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json new file mode 100644 index 00000000000..2ccab4a9dcb --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json @@ -0,0 +1,20 @@ +{ + "choices": [ + { + "delta": { + "content": " deep", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": "length", + "index": 0, + "logprobs": null + } + ], + "created": 1729262528, + "id": "", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.3.2-dev0-native", + "usage": null +} From 905d50397110fc53cebedbdb1da94c7062e1b43a Mon Sep 17 00:00:00 2001 From: David Holtz Date: Fri, 18 Oct 2024 16:11:58 +0000 Subject: [PATCH 12/15] fix: adjust tool choice type in test --- router/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 009b57cffc8..d8762206834 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1823,7 +1823,7 @@ mod tests { stop: None, stream: false, tools: tools, - tool_choice: None, + tool_choice: ToolChoice::Auto, tool_prompt: Some("Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.".to_string()), temperature: None, response_format: None, From c1eab6cbb3516921deae53a013e902a51b8cb194 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Sun, 20 Oct 2024 21:57:47 +0000 Subject: [PATCH 13/15] fix: adjust default when json tool choice is --- router/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index d8762206834..18d1b1b3de0 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1066,7 +1066,7 @@ pub enum ToolChoice { /// /// `none` is the default when no tools are present. `auto` is the default if tools are present." enum ToolTypeDeserializer { - /// `none` means the model will not call any tool and instead generates a message. + /// None means `null` was passed in the JSON, and the default choice is applied based on the presence of tools. Null, /// `auto` means the model can pick between generating a message or calling one or more tools. @@ -1081,7 +1081,7 @@ enum ToolTypeDeserializer { impl From for ToolChoice { fn from(value: ToolTypeDeserializer) -> Self { match value { - ToolTypeDeserializer::Null => ToolChoice::NoTool, + ToolTypeDeserializer::Null => ToolChoice::Auto, ToolTypeDeserializer::String(s) => match s.as_str() { "none" => ToolChoice::NoTool, "auto" => ToolChoice::Auto, From daf7d979d0a2193995be85efd7ace9c64cc7046d Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 29 Oct 2024 11:49:42 -0400 Subject: [PATCH 14/15] fix: remove trailing space lint after rebase --- router/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/src/server.rs b/router/src/server.rs index 911c77d8681..599395e7841 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -2512,4 +2512,4 @@ impl From for Event { pub enum WebServerError { #[error("Axum error: {0}")] Axum(#[from] axum::BoxError), -} \ No newline at end of file +} From f9f34a5e2006ddec13c66cb2b27e392c260e970c Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 29 Oct 2024 13:10:22 -0400 Subject: [PATCH 15/15] fix: remove mostly mocked unit test --- router/src/lib.rs | 105 ---------------------------------------------- 1 file changed, 105 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 18d1b1b3de0..473d33e0bf6 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1734,109 +1734,4 @@ mod tests { }) ); } - - #[test] - fn test_try_into_generate_with_tools_and_template() { - use crate::infer::Backend; - use crate::infer::InferStreamResponse; - use crate::validation::ValidGenerateRequest; - use crate::ChatTemplateVersions; - use crate::HubTokenizerConfig; - use crate::TokenizerConfigToken; - use crate::Tool; - - use core::future::Future; - use std::pin::Pin; - use tokio_stream::wrappers::UnboundedReceiverStream; - - // Mock Backend to avoid network requests. This is never used since we only test the conversion. It is mocked to satisfy the Backend trait. - struct MockBackend; - impl Backend for MockBackend { - fn schedule( - &self, - _: ValidGenerateRequest, - ) -> Result>, InferError> - { - unimplemented!("Never called in this test"); - } - fn health<'a, 't>(&'a self, _: bool) -> Pin + Send + 't>> - where - 'a: 't, - Self: 't, - { - unimplemented!("Never called in this test"); - } - } - - let mut tokenizer_config = HubTokenizerConfig::default(); - tokenizer_config.bos_token = Some(TokenizerConfigToken::String("".to_string())); - tokenizer_config.eos_token = Some(TokenizerConfigToken::String("".to_string())); - tokenizer_config.chat_template = Some( - ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) - ); - - let infer = Infer::new( - MockBackend {}, // never used; just to satisfy Infer::new signature - Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false), - 1, - tokenizer_config, - HubProcessorConfig::default(), - ); - - let tools = Some(vec![Tool { - r#type: "function".to_string(), - function: FunctionDefinition { - name: "get_current_weather".to_string(), - description: Some("Get the current weather".to_string()), - arguments: json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "format": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The temperature unit to use. Infer this from the users location." - } - }, - "required": ["location", "format"] - }), - }, - }]); - - let request = ChatRequest { - model: None, - max_tokens: None, - logit_bias: None, - logprobs: None, - n: None, - messages: vec![Message { - name: None, - role: "user".to_string(), - content: MessageContent::SingleText( - "What is the weather like in New York?".to_string(), - ), - }], - seed: None, - stop: None, - stream: false, - tools: tools, - tool_choice: ToolChoice::Auto, - tool_prompt: Some("Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.".to_string()), - temperature: None, - response_format: None, - guideline: None, - presence_penalty: None, - frequency_penalty: None, - top_p: None, - top_logprobs: None, - stream_options: None, - }; - - let (generate, using_tools) = request.try_into_generate(&infer).unwrap(); - assert_eq!(using_tools, true); - assert_eq!(generate.inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); - } }