From 79a67aa4f953605e0ca2bfbc16742bc8ae335cea Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Tue, 15 Oct 2024 15:00:24 +0000
Subject: [PATCH] fix: refactor away prepare_chat_input and improve tool
 grammar apply control flow

---
 router/src/infer/tool_grammar.rs |  28 +++---
 router/src/lib.rs                | 152 +++++++++++++++++++++++++++---
 router/src/server.rs             | 155 -------------------------------
 3 files changed, 157 insertions(+), 178 deletions(-)

diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs
index 772485590cc..982535535aa 100644
--- a/router/src/infer/tool_grammar.rs
+++ b/router/src/infer/tool_grammar.rs
@@ -58,7 +58,7 @@ impl ToolGrammar {
                     }))
                     .collect::<Vec<_>>()
             }
-            ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)),
+            ChatCompletionToolChoiceOption::NoTool => Vec::with_capacity(0),
         };
 
         let functions: HashMap<String, serde_json::Value> = tools_to_use
@@ -107,18 +107,22 @@
             })
             .collect();
 
-        let tool_schema = JsonSchemaTool {
-            functions_map: FunctionsMap { functions },
-            properties: Properties {
-                function: tools_to_use
-                    .iter()
-                    .map(|tool| FunctionRef {
-                        ref_path: format!("#/$functions/{}", tool.function.name.clone()),
-                    })
-                    .collect(),
-            },
+        let tool_schema = if tools_to_use.is_empty() {
+            None
+        } else {
+            Some(JsonSchemaTool {
+                functions_map: FunctionsMap { functions },
+                properties: Properties {
+                    function: tools_to_use
+                        .iter()
+                        .map(|tool| FunctionRef {
+                            ref_path: format!("#/$functions/{}", tool.function.name.clone()),
+                        })
+                        .collect(),
+                },
+            })
         };
 
-        Ok((tools_to_use, Some(tool_schema)))
+        Ok((tools_to_use, tool_schema))
     }
 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 6a54f568144..572bbd4ea62 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -11,8 +11,8 @@ pub mod logging;
 pub mod usage_stats;
 mod vertex;
 
+use crate::infer::tool_grammar::ToolGrammar;
 use crate::infer::{Infer, InferError};
-use crate::server::prepare_chat_input;
 use serde::{Deserialize, Serialize};
 use tracing::warn;
 use utoipa::ToSchema;
@@ -904,22 +904,46 @@ impl ChatRequest {
             other => (true, other),
         };
         // unwrap or default (use "auto" if tools are present, and "none" if not)
-        let choice = tool_choice.unwrap_or_else(|| {
+        let tool_choice = tool_choice.unwrap_or_else(|| {
             if tools.is_some() {
                 ChatCompletionToolChoiceOption::Auto
             } else {
                 ChatCompletionToolChoiceOption::NoTool
             }
         });
-        let (inputs, grammar, using_tools) = prepare_chat_input(
-            infer,
-            response_format,
-            tools,
-            choice,
-            &tool_prompt,
-            guideline,
-            messages,
-        )?;
+
+        if response_format.is_some() && tools.is_some() {
+            return Err(InferError::ToolError(
+                "Grammar and tools are mutually exclusive".into(),
+            ));
+        }
+
+        let (inputs, grammar, using_tools) = match response_format {
+            Some(format) => {
+                let inputs = infer.apply_chat_template(guideline, messages, None)?;
+                (inputs, Some(format), false)
+            }
+            None => {
+                if let Some(tools) = tools {
+                    let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
+
+                    let grammar = tool_schema
+                        .as_ref()
+                        .map(|t| GrammarType::Json(serde_json::json!(t)));
+
+                    let inputs: String = infer.apply_chat_template(
+                        guideline,
+                        messages,
+                        Some((updated_tools, tool_prompt)),
+                    )?;
+                    (inputs, grammar, tool_schema.is_some())
+                } else {
+                    // if no response_format or tools are set simply apply the chat template to generate inputs
+                    let inputs = infer.apply_chat_template(guideline, messages, None)?;
+                    (inputs, None, false)
+                }
+            }
+        };
 
         Ok((
             GenerateRequest {
@@ -1196,6 +1220,7 @@ pub(crate) enum OutputMessage {
 }
 
 #[derive(Clone, Debug, 
Deserialize, ToSchema)] +#[cfg_attr(test, derive(PartialEq))] pub(crate) struct GenerateRequest { #[schema(example = "My name is Olivier and I")] pub inputs: String, @@ -1677,4 +1702,109 @@ mod tests { }) ); } + + #[test] + fn test_try_into_generate_with_tools_and_template() { + use crate::infer::Backend; + use crate::infer::InferStreamResponse; + use crate::validation::ValidGenerateRequest; + use crate::ChatTemplateVersions; + use crate::HubTokenizerConfig; + use crate::TokenizerConfigToken; + use crate::Tool; + + use core::future::Future; + use std::pin::Pin; + use tokio_stream::wrappers::UnboundedReceiverStream; + + // Mock Backend to avoid network requests. This is never used since we only test the conversion. It is mocked to satisfy the Backend trait. + struct MockBackend; + impl Backend for MockBackend { + fn schedule( + &self, + _: ValidGenerateRequest, + ) -> Result>, InferError> + { + unimplemented!("Never called in this test"); + } + fn health<'a, 't>(&'a self, _: bool) -> Pin + Send + 't>> + where + 'a: 't, + Self: 't, + { + unimplemented!("Never called in this test"); + } + } + + let mut tokenizer_config = HubTokenizerConfig::default(); + tokenizer_config.bos_token = Some(TokenizerConfigToken::String("".to_string())); + tokenizer_config.eos_token = Some(TokenizerConfigToken::String("".to_string())); + tokenizer_config.chat_template = Some( + ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = 
tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) + ); + + let infer = Infer::new( + MockBackend {}, // never used; just to satisfy Infer::new signature + Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false), + 1, + tokenizer_config, + HubProcessorConfig::default(), + ); + + let tools = Some(vec![Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "get_current_weather".to_string(), + description: Some("Get the current weather".to_string()), + arguments: json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location." + } + }, + "required": ["location", "format"] + }), + }, + }]); + + let request = ChatRequest { + model: None, + max_tokens: None, + logit_bias: None, + logprobs: None, + n: None, + messages: vec![Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "What is the weather like in New York?".to_string(), + ), + }], + seed: None, + stop: None, + stream: false, + tools: tools, + tool_choice: None, + tool_prompt: Some("Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.".to_string()), + temperature: None, + response_format: None, + guideline: None, + presence_penalty: None, + frequency_penalty: None, + top_p: None, + top_logprobs: None, + stream_options: None, + }; + + let (generate, using_tools) = request.try_into_generate(&infer).unwrap(); + assert_eq!(using_tools, true); + assert_eq!(generate.inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. 
San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); + } } diff --git a/router/src/server.rs b/router/src/server.rs index 1d60f8aeeff..b2ab8adfe82 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,6 +1,5 @@ /// HTTP Server logic use crate::config::Config; -use crate::infer::tool_grammar::ToolGrammar; use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; #[cfg(feature = "kserve")] use crate::kserve::{ @@ -2510,157 +2509,3 @@ pub enum WebServerError { #[error("Axum error: {0}")] Axum(#[from] axum::BoxError), } - -type PreparedInput = (String, Option, bool); - -pub(crate) fn prepare_chat_input( - infer: &Infer, - response_format: Option, - tools: Option>, - tool_choice: ChatCompletionToolChoiceOption, - tool_prompt: &str, - guideline: Option, - messages: Vec, -) -> Result { - if response_format.is_some() && tools.is_some() { - return Err(InferError::ToolError( - "Grammar and tools are mutually exclusive".into(), - )); - } - - // when response_format is set, tools are not included when applying the chat template to generate inputs - if let Some(format) = response_format { - let inputs = infer.apply_chat_template(guideline, messages, None)?; - return Ok((inputs, Some(format), false)); - } - - // when no response_format is set and tools are included, apply the chat template with the tools - // to generate inputs - if let Some(tools) = tools { - let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?; - - let grammar = tool_schema - .as_ref() - .map(|t| GrammarType::Json(serde_json::json!(t))); - - let inputs: String = infer.apply_chat_template( - guideline, - messages, - Some((updated_tools, tool_prompt.into())), - )?; - return Ok((inputs, grammar, tool_schema.is_some())); - } - - // if no response_format or tools are set simply apply the chat template to generate inputs - let inputs = infer.apply_chat_template(guideline, messages, None)?; - Ok((inputs, None, false)) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::ChatTemplateVersions; - use crate::HubTokenizerConfig; - use crate::TokenizerConfigToken; - use crate::Tool; - - use serde_json::json; - - #[test] - fn test_prepare_chat_input() { - // Mock Backend to avoid network requests - struct MockBackend; - - impl Backend for MockBackend { - fn schedule( - &self, - _request: crate::validation::ValidGenerateRequest, - ) -> Result< - tokio_stream::wrappers::UnboundedReceiverStream< - Result, - >, - InferError, - > { - unimplemented!("Never called in this test"); - } - fn health<'a, 'async_trait>( - &'a self, - _current_health: bool, - ) -> core::pin::Pin< - Box + core::marker::Send + 'async_trait>, - > - where - 'a: 'async_trait, - Self: 'async_trait, - { - unimplemented!("Never called in this test"); - } 
- } - - let backend = MockBackend {}; - - let mut tokenizer_config = HubTokenizerConfig::default(); - - // mock tokenizer config values - tokenizer_config.bos_token = Some(TokenizerConfigToken::String("".to_string())); - tokenizer_config.eos_token = Some(TokenizerConfigToken::String("".to_string())); - tokenizer_config.chat_template = Some( - ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": 
\"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) - ); - - let infer = Infer::new( - backend, - Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false), - 1, - tokenizer_config, - HubProcessorConfig::default(), - ); - let response_format = None; - let tools = Some(vec![Tool { - r#type: "function".to_string(), - function: FunctionDefinition { - name: "get_current_weather".to_string(), - description: Some("Get the current weather".to_string()), - arguments: json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "format": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The temperature unit to use. Infer this from the users location." - } - }, - "required": ["location", "format"] - }), - }, - }]); - let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."; - let guideline = None; - let messages = vec![Message { - name: None, - role: "user".to_string(), - content: MessageContent::SingleText( - "What is the weather like in New York?".to_string(), - ), - }]; - - let result = prepare_chat_input( - &infer, - response_format, - tools, - ChatCompletionToolChoiceOption::Auto, - tool_prompt, - guideline, - messages, - ); - - assert!(result.is_ok()); - let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input"); - assert_eq!(using_tools, true); - assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); - } -}