Skip to content

Commit

Permalink
fix: simplify tool choice logic, improve tests, openapi and rust docs
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Oct 15, 2024
1 parent 970d8dc commit eefd725
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 79 deletions.
5 changes: 3 additions & 2 deletions docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,8 @@
}
}
}
]
],
"description": "<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>"
},
"ChatCompletionTopLogprob": {
"type": "object",
Expand Down Expand Up @@ -2184,4 +2185,4 @@
"description": "Hugging Face Text Generation Inference API"
}
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"choices": [
{
"delta": {
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "<|eot_id|>",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1729000499,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "<|eot_id|>",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1728998230,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}
103 changes: 102 additions & 1 deletion integration-tests/models/test_tools_llama.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest
import requests
import json


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -174,7 +176,7 @@ async def test_flash_llama_grammar_tools_choice(
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
},
}
]
Expand Down Expand Up @@ -327,3 +329,102 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
== "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
)
assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
    flash_llama_grammar_tools, response_snapshot
):
    # With tool_choice="required" the model must emit a tool call; the story
    # prompt is intentionally un-toolable, so we expect a forced weather call.
    stream = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="required",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        stream=True,
    )

    chunk_count = 0
    arguments_accumulated = ""
    final_chunk = None
    async for chunk in stream:
        chunk_count += 1
        delta = chunk.choices[0].delta
        # Required tool choice means content stays empty; everything streams
        # through the tool_calls delta instead.
        assert delta.content is None
        arguments_accumulated += delta.tool_calls.function.arguments
        final_chunk = chunk

    assert chunk_count == 29
    assert (
        arguments_accumulated
        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>'
    )
    assert final_chunk == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
    flash_llama_grammar_tools, response_snapshot
):
    """Force a specific tool via the OpenAI-style ``tool_choice`` function
    object and verify the streamed tool-call arguments and final chunk.

    Uses ``requests`` directly because the client library does not yet
    support ``tool_choice`` as a function object.
    """
    responses = requests.post(
        f"{flash_llama_grammar_tools.base_url}/v1/chat/completions",
        headers=flash_llama_grammar_tools.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
                },
                {
                    "role": "user",
                    "content": "Tell me a story about 3 sea creatures",
                },
            ],
            "tools": tools,
            "tool_choice": {
                "type": "function",
                "function": {"name": "get_current_weather"},
            },
            "seed": 24,
            "max_tokens": 100,
            "stream": True,
        },
        stream=True,
    )
    # iterate over the response in chunks
    count = 0
    tool_calls_generated = ""
    last_response = None
    for chunk in responses.iter_content(chunk_size=1024):
        if chunk:
            count += 1
            # remove the "data: " prefix, trailing newline, and split the chunk into individual lines
            lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
            for line in lines:
                # SSE events are separated by blank lines; skip them so
                # json.loads never sees an empty string.
                if not line:
                    continue
                if line == "[DONE]":
                    # note: only exits the inner (per-line) loop; the outer
                    # chunk loop ends naturally when the stream closes
                    break
                response = json.loads(line)
                tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][
                    "function"
                ]["arguments"]
                last_response = response

    assert count == 30
    assert (
        tool_calls_generated
        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Tokyo, JP"}}<|eot_id|>'
    )
    assert last_response == response_snapshot
58 changes: 28 additions & 30 deletions router/src/infer/tool_grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,37 @@ impl ToolGrammar {
return Ok((tools, None));
}

let mut tools = tools.clone();

// add the no_tool function to the tools as long as we are not required to use a specific tool
if tool_choice != ChatCompletionToolChoiceOption::Required {
let no_tool = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "no_tool".to_string(),
description: Some(
"Open ended response with no specific tool selected".to_string(),
),
arguments: json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The response content",
}
},
"required": ["content"]
}),
},
};
tools.push(no_tool);
}

// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ChatCompletionToolChoiceOption::Function(function) => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ChatCompletionToolChoiceOption::Required => tools.clone(),
ChatCompletionToolChoiceOption::Auto => tools.clone(),
ChatCompletionToolChoiceOption::Required => tools,
ChatCompletionToolChoiceOption::Auto => {
// only add the no_tool function if the user has selected the auto option
tools
.iter()
.cloned()
.chain(std::iter::once(Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "no_tool".to_string(),
description: Some(
"Open ended response with no specific tool selected".to_string(),
),
arguments: json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The response content",
}
},
"required": ["content"]
}),
},
}))
.collect::<Vec<_>>()
}
ChatCompletionToolChoiceOption::NoTool => return Ok((tools, None)),
};

Expand Down Expand Up @@ -121,6 +119,6 @@ impl ToolGrammar {
},
};

Ok((tools, Some(tool_schema)))
Ok((tools_to_use, Some(tool_schema)))
}
}
65 changes: 34 additions & 31 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,21 +903,19 @@ impl ChatRequest {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
// unwrap or default (use "auto" if tools are present, and "none" if not)
let choice = tool_choice.unwrap_or_else(|| {
if tools.is_some() {
ChatCompletionToolChoiceOption::Auto
} else {
ChatCompletionToolChoiceOption::NoTool
}
});
let (inputs, grammar, using_tools) = prepare_chat_input(
infer,
response_format,
tools.clone(),
// unwrap or default (use "auto" if tools are present, and "none" if not)
tool_choice.map_or_else(
|| {
if tools.is_some() {
ChatCompletionToolChoiceOption::Auto
} else {
ChatCompletionToolChoiceOption::NoTool
}
},
|t| t,
),
tools,
choice,
&tool_prompt,
guideline,
messages,
Expand Down Expand Up @@ -980,6 +978,7 @@ pub struct FunctionName {

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
#[serde(from = "ToolTypeDeserializer")]
/// <https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>
pub enum ChatCompletionToolChoiceOption {
/// Means the model can pick between generating a message or calling one or more tools.
#[schema(rename = "auto")]
Expand All @@ -991,7 +990,7 @@ pub enum ChatCompletionToolChoiceOption {
/// Means the model must call one or more tools.
#[schema(rename = "required")]
Required,
/// Forces the model to call a specific tool.
/// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool.
#[schema(rename = "function")]
#[serde(alias = "function")]
Function(FunctionName),
Expand Down Expand Up @@ -1646,32 +1645,36 @@ mod tests {
tool_choice: ChatCompletionToolChoiceOption,
}

let none = r#"{"tool_choice":"none"}"#;
let de_none: TestRequest = serde_json::from_str(none).unwrap();
let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap();
assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);

let auto = r#"{"tool_choice":"auto"}"#;
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap();
assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);

let auto = r#"{"tool_choice":"required"}"#;
let de_auto: TestRequest = serde_json::from_str(auto).unwrap();
let de_required: TestRequest =
serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap();
assert_eq!(
de_auto.tool_choice,
de_required.tool_choice,
ChatCompletionToolChoiceOption::Required
);

let ref_choice = ChatCompletionToolChoiceOption::Function(FunctionName {
name: "myfn".to_string(),
});

let named = r#"{"tool_choice":"myfn"}"#;
let de_named: TestRequest = serde_json::from_str(named).unwrap();
assert_eq!(de_named.tool_choice, ref_choice);

let openai_named = r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#;
let de_openai_named: TestRequest = serde_json::from_str(openai_named).unwrap();
let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap();
assert_eq!(
de_named.tool_choice,
ChatCompletionToolChoiceOption::Function(FunctionName {
name: "myfn".to_string(),
})
);

assert_eq!(de_openai_named.tool_choice, ref_choice);
let de_openai_named: TestRequest = serde_json::from_str(
r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#,
)
.unwrap();
assert_eq!(
de_openai_named.tool_choice,
ChatCompletionToolChoiceOption::Function(FunctionName {
name: "myfn".to_string(),
})
);
}
}
2 changes: 1 addition & 1 deletion router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2661,6 +2661,6 @@ mod tests {
assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
}
}
Loading

0 comments on commit eefd725

Please sign in to comment.