Skip to content

Commit

Permalink
fix: adjust tool choice none logic, add test and small refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Oct 18, 2024
1 parent 6837c5b commit c1fac74
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 42 deletions.
4 changes: 2 additions & 2 deletions docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@
"$ref": "#/components/schemas/GrammarType"
}
],
"default": "auto",
"default": "null",
"nullable": true
},
"seed": {
Expand Down Expand Up @@ -966,7 +966,7 @@
"$ref": "#/components/schemas/ToolChoice"
}
],
"default": "null",
"default": "auto",
"nullable": true
},
"tool_prompt": {
Expand Down
41 changes: 41 additions & 0 deletions integration-tests/models/test_tools_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,47 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_none(
    flash_llama_grammar_tools, response_snapshot
):
    """Streaming with tool_choice="none": even though tools are supplied,
    the model must emit plain text content only — no chunk in the stream
    may carry a tool call delta."""
    stream = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="none",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        stream=True,
    )

    chunk_count = 0
    generated_text = ""
    final_chunk = None
    async for chunk in stream:
        chunk_count += 1
        generated_text += chunk.choices[0].delta.content
        final_chunk = chunk
        # tool_choice="none" must suppress tool calls on every delta.
        assert chunk.choices[0].delta.tool_calls is None

    # The stream should run to the max_tokens budget.
    assert chunk_count == 100
    print(generated_text)
    assert (
        generated_text
        == "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep"
    )
    assert final_chunk == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
Expand Down
40 changes: 18 additions & 22 deletions router/src/infer/tool_grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,7 @@ impl ToolGrammar {
pub fn apply(
tools: Vec<Tool>,
tool_choice: ToolChoice,
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
// if no tools are provided, we return None and an empty vec
if tools.is_empty() {
return Ok((Vec::with_capacity(0), None));
}

) -> Result<Option<(Vec<Tool>, JsonSchemaTool)>, InferError> {
let tools_to_use = match tool_choice {
ToolChoice::Function(function) => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
Expand Down Expand Up @@ -57,9 +52,14 @@ impl ToolGrammar {
}))
.collect::<Vec<_>>()
}
ToolChoice::NoTool => Vec::with_capacity(0),
ToolChoice::NoTool => vec![],
};

// if no tools are provided or if the user has selected the no_tool option, return None
if tools_to_use.is_empty() {
return Ok(None);
}

let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
Expand Down Expand Up @@ -106,22 +106,18 @@ impl ToolGrammar {
})
.collect();

let tool_schema = if tools_to_use.is_empty() {
None
} else {
Some(JsonSchemaTool {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.collect(),
},
})
let tool_schema = JsonSchemaTool {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.collect(),
},
};

Ok((tools_to_use, tool_schema))
Ok(Some((tools_to_use, tool_schema)))
}
}
38 changes: 20 additions & 18 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -849,14 +849,14 @@ pub(crate) struct ChatRequest {

/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub tool_choice: Option<ToolChoice>,
#[schema(nullable = true, default = "auto", example = "auto")]
pub tool_choice: ToolChoice,

/// Response format constraints for the generation.
///
/// NOTE: A request can use `response_format` OR `tools` but not both.
#[serde(default)]
#[schema(nullable = true, default = "auto", example = "auto")]
#[schema(nullable = true, default = "null", example = "null")]
pub response_format: Option<GrammarType>,

/// A guideline to be used in the chat_template
Expand Down Expand Up @@ -903,8 +903,6 @@ impl ChatRequest {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
// if no tool_choice is set, set default (Auto)
let tool_choice = tool_choice.unwrap_or_default();

if response_format.is_some() && tools.is_some() {
return Err(InferError::ToolError(
Expand All @@ -919,18 +917,22 @@ impl ChatRequest {
}
None => {
if let Some(tools) = tools {
let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;

let grammar = tool_schema
.as_ref()
.map(|t| GrammarType::Json(serde_json::json!(t)));

let inputs: String = infer.apply_chat_template(
guideline,
messages,
Some((updated_tools, tool_prompt)),
)?;
(inputs, grammar, tool_schema.is_some())
match ToolGrammar::apply(tools, tool_choice)? {
Some((updated_tools, tool_schema)) => {
let grammar = GrammarType::Json(serde_json::json!(tool_schema));
let inputs: String = infer.apply_chat_template(
guideline,
messages,
Some((updated_tools, tool_prompt)),
)?;
(inputs, Some(grammar), true)
}
None => {
// same as if no response_format or tools are set
let inputs = infer.apply_chat_template(guideline, messages, None)?;
(inputs, None, false)
}
}
} else {
// if no response_format or tools are set simply apply the chat template to generate inputs
let inputs = infer.apply_chat_template(guideline, messages, None)?;
Expand Down Expand Up @@ -1003,7 +1005,7 @@ pub enum ToolChoice {
#[default]
Auto,
/// Means the model will not call any tool and instead generates a message.
#[schema(rename = "none")]
#[serde(rename = "none")]
NoTool,
/// Means the model must call one or more tools.
Required,
Expand Down

0 comments on commit c1fac74

Please sign in to comment.