@@ -849,8 +849,8 @@ pub(crate) struct ChatRequest {
849849
850850 /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
851851 #[ serde( default ) ]
852- #[ schema( nullable = true , example = "null" ) ]
853- pub tool_choice : ToolChoice ,
852+ #[ schema( nullable = true , default = "null" , example = "null" ) ]
853+ pub tool_choice : Option < ChatCompletionToolChoiceOption > ,
854854
855855 /// Response format constraints for the generation.
856856 ///
@@ -906,8 +906,18 @@ impl ChatRequest {
906906 let ( inputs, grammar, using_tools) = prepare_chat_input (
907907 infer,
908908 response_format,
909- tools,
910- tool_choice,
909+ tools. clone ( ) ,
910+ // unwrap or default (use "auto" if tools are present, and "none" if not)
911+ tool_choice. map_or_else (
912+ || {
913+ if tools. is_some ( ) {
914+ ChatCompletionToolChoiceOption :: Auto
915+ } else {
916+ ChatCompletionToolChoiceOption :: NoTool
917+ }
918+ } ,
919+ |t| t,
920+ ) ,
911921 & tool_prompt,
912922 guideline,
913923 messages,
@@ -956,22 +966,6 @@ pub fn default_tool_prompt() -> String {
956966 "\n Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n " . to_string ( )
957967}
958968
959- #[ derive( Clone , Debug , Deserialize , PartialEq , Serialize , ToSchema ) ]
960- #[ schema( example = "auto" ) ]
961- /// Controls which (if any) tool is called by the model.
962- pub enum ToolType {
963- /// Means the model can pick between generating a message or calling one or more tools.
964- #[ schema( rename = "auto" ) ]
965- OneOf ,
966- /// Means the model will not call any tool and instead generates a message.
967- #[ schema( rename = "none" ) ]
968- NoTool ,
969- /// Forces the model to call a specific tool.
970- #[ schema( rename = "function" ) ]
971- #[ serde( alias = "function" ) ]
972- Function ( FunctionName ) ,
973- }
974-
975969#[ derive( Clone , Debug , Deserialize , PartialEq , Serialize ) ]
976970#[ serde( tag = "type" ) ]
977971pub enum TypedChoice {
@@ -984,29 +978,59 @@ pub struct FunctionName {
984978 pub name : String ,
985979}
986980
987- #[ derive( Debug , Clone , PartialEq , Serialize , Deserialize , Default , ToSchema ) ]
981+ #[ derive( Debug , Clone , PartialEq , Serialize , Deserialize , ToSchema , Default ) ]
988982#[ serde( from = "ToolTypeDeserializer" ) ]
989- pub struct ToolChoice ( pub Option < ToolType > ) ;
983+ pub enum ChatCompletionToolChoiceOption {
984+ /// Means the model can pick between generating a message or calling one or more tools.
985+ #[ schema( rename = "auto" ) ]
986+ Auto ,
987+ /// Means the model will not call any tool and instead generates a message.
988+ #[ schema( rename = "none" ) ]
989+ #[ default]
990+ NoTool ,
991+ /// Means the model must call one or more tools.
992+ #[ schema( rename = "required" ) ]
993+ Required ,
994+ /// Forces the model to call a specific tool.
995+ #[ schema( rename = "function" ) ]
996+ #[ serde( alias = "function" ) ]
997+ Function ( FunctionName ) ,
998+ }
990999
991- #[ derive( Deserialize ) ]
1000+ #[ derive( Deserialize , ToSchema ) ]
9921001#[ serde( untagged) ]
1002+ /// Controls which (if any) tool is called by the model.
1003+ /// - `none` means the model will not call any tool and instead generates a message.
1004+ /// - `auto` means the model can pick between generating a message or calling one or more tools.
1005+ /// - `required` means the model must call one or more tools.
1006+ /// - Specifying a particular tool via `{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}` forces the model to call that tool.
1007+ ///
1008+ /// `none` is the default when no tools are present. `auto` is the default if tools are present."
9931009enum ToolTypeDeserializer {
1010+ /// `none` means the model will not call any tool and instead generates a message.
9941011 Null ,
1012+
1013+ /// `auto` means the model can pick between generating a message or calling one or more tools.
1014+ #[ schema( example = "auto" ) ]
9951015 String ( String ) ,
996- ToolType ( TypedChoice ) ,
1016+
1017+ /// Specifying a particular tool forces the model to call that tool, with structured function details.
1018+ #[ schema( example = r#"{"type": "function", "function": {"name": "my_function"}}"# ) ]
1019+ TypedChoice ( TypedChoice ) ,
9971020}
9981021
999- impl From < ToolTypeDeserializer > for ToolChoice {
1022+ impl From < ToolTypeDeserializer > for ChatCompletionToolChoiceOption {
10001023 fn from ( value : ToolTypeDeserializer ) -> Self {
10011024 match value {
1002- ToolTypeDeserializer :: Null => ToolChoice ( None ) ,
1025+ ToolTypeDeserializer :: Null => ChatCompletionToolChoiceOption :: NoTool ,
10031026 ToolTypeDeserializer :: String ( s) => match s. as_str ( ) {
1004- "none" => ToolChoice ( Some ( ToolType :: NoTool ) ) ,
1005- "auto" => ToolChoice ( Some ( ToolType :: OneOf ) ) ,
1006- _ => ToolChoice ( Some ( ToolType :: Function ( FunctionName { name : s } ) ) ) ,
1027+ "none" => ChatCompletionToolChoiceOption :: NoTool ,
1028+ "auto" => ChatCompletionToolChoiceOption :: Auto ,
1029+ "required" => ChatCompletionToolChoiceOption :: Required ,
1030+ _ => ChatCompletionToolChoiceOption :: Function ( FunctionName { name : s } ) ,
10071031 } ,
1008- ToolTypeDeserializer :: ToolType ( TypedChoice :: Function { function } ) => {
1009- ToolChoice ( Some ( ToolType :: Function ( function) ) )
1032+ ToolTypeDeserializer :: TypedChoice ( TypedChoice :: Function { function } ) => {
1033+ ChatCompletionToolChoiceOption :: Function ( function)
10101034 }
10111035 }
10121036 }
@@ -1619,20 +1643,27 @@ mod tests {
16191643 fn tool_choice_formats ( ) {
16201644 #[ derive( Deserialize ) ]
16211645 struct TestRequest {
1622- tool_choice : ToolChoice ,
1646+ tool_choice : ChatCompletionToolChoiceOption ,
16231647 }
16241648
16251649 let none = r#"{"tool_choice":"none"}"# ;
16261650 let de_none: TestRequest = serde_json:: from_str ( none) . unwrap ( ) ;
1627- assert_eq ! ( de_none. tool_choice, ToolChoice ( Some ( ToolType :: NoTool ) ) ) ;
1651+ assert_eq ! ( de_none. tool_choice, ChatCompletionToolChoiceOption :: NoTool ) ;
16281652
16291653 let auto = r#"{"tool_choice":"auto"}"# ;
16301654 let de_auto: TestRequest = serde_json:: from_str ( auto) . unwrap ( ) ;
1631- assert_eq ! ( de_auto. tool_choice, ToolChoice ( Some ( ToolType :: OneOf ) ) ) ;
1655+ assert_eq ! ( de_auto. tool_choice, ChatCompletionToolChoiceOption :: Auto ) ;
16321656
1633- let ref_choice = ToolChoice ( Some ( ToolType :: Function ( FunctionName {
1657+ let auto = r#"{"tool_choice":"required"}"# ;
1658+ let de_auto: TestRequest = serde_json:: from_str ( auto) . unwrap ( ) ;
1659+ assert_eq ! (
1660+ de_auto. tool_choice,
1661+ ChatCompletionToolChoiceOption :: Required
1662+ ) ;
1663+
1664+ let ref_choice = ChatCompletionToolChoiceOption :: Function ( FunctionName {
16341665 name : "myfn" . to_string ( ) ,
1635- } ) ) ) ;
1666+ } ) ;
16361667
16371668 let named = r#"{"tool_choice":"myfn"}"# ;
16381669 let de_named: TestRequest = serde_json:: from_str ( named) . unwrap ( ) ;
0 commit comments