1
1
"""Function inferrer class definition."""
2
2
3
+ from __future__ import annotations
4
+
5
+ import dataclasses
3
6
import inspect
4
- from collections .abc import Callable
5
- from typing import Any
7
+ import typing
8
+ from enum import EnumMeta
9
+ from typing import TYPE_CHECKING , Any , get_args , get_origin , get_type_hints
6
10
from warnings import warn
7
11
8
12
from docstring_parser import Docstring , parser
9
13
10
14
from openai_function_calling .function import Function
11
15
from openai_function_calling .helper_functions import python_type_to_json_schema_type
16
+ from openai_function_calling .json_schema_type import JsonSchemaType
12
17
from openai_function_calling .parameter import Parameter
13
18
19
+ if TYPE_CHECKING : # pragma: no cover
20
+ from collections .abc import Callable
21
+
14
22
15
23
class FunctionInferrer :
16
24
"""Class to help inferring a function definition from a reference."""
@@ -26,6 +34,7 @@ def infer_from_function_reference(function_reference: Callable) -> Function:
26
34
27
35
Return:
28
36
An instance of Function with inferred values.
37
+
29
38
"""
30
39
inferred_from_annotations : Function = FunctionInferrer ._infer_from_annotations (
31
40
function_reference
@@ -51,6 +60,7 @@ def _infer_from_docstring(function_reference: Callable) -> Function:
51
60
52
61
Returns:
53
62
The inferred Function instance.
63
+
54
64
"""
55
65
function_definition = Function (
56
66
name = function_reference .__name__ ,
@@ -91,29 +101,53 @@ def _infer_from_annotations(function_reference: Callable) -> Function:
91
101
92
102
Returns:
93
103
The inferred Function instance.
94
- """
95
- function_definition = Function (
96
- name = function_reference .__name__ ,
97
- description = "" ,
98
- parameters = [],
99
- )
100
-
101
- if hasattr (function_reference , "__annotations__" ):
102
- annotations : dict [str , Any ] = function_reference .__annotations__
103
104
104
- for key in annotations :
105
- if key == "return" :
106
- continue
107
-
108
- parameter_type : str = python_type_to_json_schema_type (
109
- annotations [key ].__name__ ,
105
+ """
106
+ annotations : dict [str , Any ] = get_type_hints (function_reference )
107
+ parameters : list [Parameter ] = []
108
+
109
+ for param_name , annotation_type in annotations .items ():
110
+ if param_name == "return" :
111
+ continue
112
+
113
+ origin = get_origin (annotation_type ) or annotation_type
114
+ args : tuple [Any , ...] = get_args (annotation_type )
115
+
116
+ if origin in [list , typing .List ]: # noqa: UP006
117
+ if not args :
118
+ raise ValueError (
119
+ f"Expected array parameter '{ param_name } ' to have an item type."
120
+ )
121
+ item_type = args [0 ]
122
+ parameter_type = JsonSchemaType .ARRAY .value
123
+ array_item_type = python_type_to_json_schema_type (
124
+ item_type .__name__ if hasattr (item_type , "__name__" ) else "Any"
125
+ )
126
+ elif origin in [dict , typing .Dict ]: # noqa: UP006
127
+ parameter_type = JsonSchemaType .OBJECT .value
128
+ array_item_type = None
129
+ else :
130
+ parameter_type = python_type_to_json_schema_type (
131
+ annotation_type .__name__
132
+ if hasattr (annotation_type , "__name__" )
133
+ else "Any"
110
134
)
111
135
112
- function_definition .parameters .append (
113
- Parameter (name = key , type = parameter_type )
136
+ array_item_type = None
137
+
138
+ parameters .append (
139
+ Parameter (
140
+ name = param_name ,
141
+ type = parameter_type ,
142
+ array_item_type = array_item_type ,
114
143
)
144
+ )
115
145
116
- return function_definition
146
+ return Function (
147
+ name = function_reference .__name__ ,
148
+ description = "" ,
149
+ parameters = parameters ,
150
+ )
117
151
118
152
@staticmethod
119
153
def _infer_from_inspection (function_reference : Callable ) -> Function :
@@ -124,6 +158,7 @@ def _infer_from_inspection(function_reference: Callable) -> Function:
124
158
125
159
Returns:
126
160
The inferred Function instance.
161
+
127
162
"""
128
163
function_definition = Function (
129
164
name = function_reference .__name__ ,
@@ -135,9 +170,31 @@ def _infer_from_inspection(function_reference: Callable) -> Function:
135
170
136
171
for name , parameter in inspected_parameters .items ():
137
172
parameter_type : str = python_type_to_json_schema_type (parameter .kind .name )
173
+ enum_values : list [str ] | None = None
174
+
175
+ if parameter_type == "null" :
176
+ if isinstance (parameter .annotation , EnumMeta ):
177
+ enum_values = list (
178
+ parameter .annotation ._value2member_map_ .keys () # noqa: SLF001
179
+ )
180
+ parameter_type = FunctionInferrer ._infer_list_item_type (enum_values )
181
+ elif dataclasses .is_dataclass (parameter .annotation ):
182
+ parameter_type = JsonSchemaType .OBJECT .value
138
183
139
184
function_definition .parameters .append (
140
- Parameter (name = name , type = parameter_type )
185
+ Parameter (name = name , type = parameter_type , enum = enum_values )
141
186
)
142
187
143
188
return function_definition
189
+
190
+ @staticmethod
191
+ def _infer_list_item_type (list_of_items : list [Any ]) -> str :
192
+ if len (list_of_items ) == 0 :
193
+ return JsonSchemaType .NULL .value
194
+
195
+ # Check if all items are the same type.
196
+ if len ({type (item ).__name__ for item in list_of_items }) == 1 :
197
+ item : Any = type (list_of_items [0 ]).__name__
198
+ return python_type_to_json_schema_type (item )
199
+
200
+ return JsonSchemaType .ANY .value
0 commit comments