Skip to content

Commit 06fb0d7

Browse files
authored
Fix handling of list parameters (#16)
* Check if parameter is enum * Check if parameter is enum * Check array type * Fix ruff errors * Fix ruff and black * Fix handling of primitive and typing list parameters * Update gh workflow to use ruff format * Remove unecessary include * Remove useless change * Fix ruff lint and format * Make compatible with python 3.9
1 parent 2569fb3 commit 06fb0d7

14 files changed

+544
-366
lines changed

.github/workflows/test-application.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ jobs:
2424
- name: Install dependencies
2525
run: |
2626
python3 -m pip install poetry
27-
poetry install --with dev
27+
poetry install
2828
- name: Check formatting with ruff
2929
run: |
3030
poetry run ruff check .
3131
- name: Check formatting with black
3232
run: |
33-
poetry run black --check .
33+
poetry run ruff format --check .
3434
- name: Test Coverage
3535
run: |
3636
poetry run coverage run --source openai_function_calling -m pytest

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ dist
77
.coverage
88
htmlcov
99
.mypy_cache
10+
.python-version

examples/weather_functions_infer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import json
88
from collections.abc import Callable
9+
from enum import Enum
910
from typing import Any
1011

1112
from openai import OpenAI
@@ -21,8 +22,15 @@
2122
from openai_function_calling.tool_helpers import ToolHelpers
2223

2324

25+
class TemperatureUnit(Enum):
26+
"""Temperature units."""
27+
28+
CELSIUS = "celsius"
29+
FAHRENHEIT = "fahrenheit"
30+
31+
2432
# Define our functions.
25-
def get_current_weather(location: str, unit: str) -> str:
33+
def get_current_weather(location: str, unit: TemperatureUnit) -> str:
2634
"""Get the current weather and return a summary."""
2735
return f"It is currently sunny in {location} and 75 degrees {unit}."
2836

openai_function_calling/function.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
parameters: A list of parameters.
4747
required_parameters: A list of parameter names that are required to run the\
4848
function.
49+
4950
"""
5051
self.name: str = name
5152
self.description: str = description
@@ -80,6 +81,7 @@ def to_dict(self) -> FunctionDict:
8081
8182
Returns:
8283
A JSON schema representation of the function.
84+
8385
"""
8486
return self.to_json_schema()
8587

@@ -91,6 +93,7 @@ def to_json_schema(self) -> FunctionDict:
9193
9294
Returns:
9395
A JSON schema representation of the function.
96+
9497
"""
9598
self.validate()
9699

@@ -119,6 +122,7 @@ def merge(self, other_function: Function) -> None:
119122
120123
Args:
121124
other_function: The other function to merge into the current.
125+
122126
"""
123127
if not self.name:
124128
self.name = other_function.name

openai_function_calling/function_inferrer.py

Lines changed: 78 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
11
"""Function inferrer class definition."""
22

3+
from __future__ import annotations
4+
5+
import dataclasses
36
import inspect
4-
from collections.abc import Callable
5-
from typing import Any
7+
import typing
8+
from enum import EnumMeta
9+
from typing import TYPE_CHECKING, Any, get_args, get_origin, get_type_hints
610
from warnings import warn
711

812
from docstring_parser import Docstring, parser
913

1014
from openai_function_calling.function import Function
1115
from openai_function_calling.helper_functions import python_type_to_json_schema_type
16+
from openai_function_calling.json_schema_type import JsonSchemaType
1217
from openai_function_calling.parameter import Parameter
1318

19+
if TYPE_CHECKING: # pragma: no cover
20+
from collections.abc import Callable
21+
1422

1523
class FunctionInferrer:
1624
"""Class to help inferring a function definition from a reference."""
@@ -26,6 +34,7 @@ def infer_from_function_reference(function_reference: Callable) -> Function:
2634
2735
Return:
2836
An instance of Function with inferred values.
37+
2938
"""
3039
inferred_from_annotations: Function = FunctionInferrer._infer_from_annotations(
3140
function_reference
@@ -51,6 +60,7 @@ def _infer_from_docstring(function_reference: Callable) -> Function:
5160
5261
Returns:
5362
The inferred Function instance.
63+
5464
"""
5565
function_definition = Function(
5666
name=function_reference.__name__,
@@ -91,29 +101,53 @@ def _infer_from_annotations(function_reference: Callable) -> Function:
91101
92102
Returns:
93103
The inferred Function instance.
94-
"""
95-
function_definition = Function(
96-
name=function_reference.__name__,
97-
description="",
98-
parameters=[],
99-
)
100-
101-
if hasattr(function_reference, "__annotations__"):
102-
annotations: dict[str, Any] = function_reference.__annotations__
103104
104-
for key in annotations:
105-
if key == "return":
106-
continue
107-
108-
parameter_type: str = python_type_to_json_schema_type(
109-
annotations[key].__name__,
105+
"""
106+
annotations: dict[str, Any] = get_type_hints(function_reference)
107+
parameters: list[Parameter] = []
108+
109+
for param_name, annotation_type in annotations.items():
110+
if param_name == "return":
111+
continue
112+
113+
origin = get_origin(annotation_type) or annotation_type
114+
args: tuple[Any, ...] = get_args(annotation_type)
115+
116+
if origin in [list, typing.List]: # noqa: UP006
117+
if not args:
118+
raise ValueError(
119+
f"Expected array parameter '{param_name}' to have an item type."
120+
)
121+
item_type = args[0]
122+
parameter_type = JsonSchemaType.ARRAY.value
123+
array_item_type = python_type_to_json_schema_type(
124+
item_type.__name__ if hasattr(item_type, "__name__") else "Any"
125+
)
126+
elif origin in [dict, typing.Dict]: # noqa: UP006
127+
parameter_type = JsonSchemaType.OBJECT.value
128+
array_item_type = None
129+
else:
130+
parameter_type = python_type_to_json_schema_type(
131+
annotation_type.__name__
132+
if hasattr(annotation_type, "__name__")
133+
else "Any"
110134
)
111135

112-
function_definition.parameters.append(
113-
Parameter(name=key, type=parameter_type)
136+
array_item_type = None
137+
138+
parameters.append(
139+
Parameter(
140+
name=param_name,
141+
type=parameter_type,
142+
array_item_type=array_item_type,
114143
)
144+
)
115145

116-
return function_definition
146+
return Function(
147+
name=function_reference.__name__,
148+
description="",
149+
parameters=parameters,
150+
)
117151

118152
@staticmethod
119153
def _infer_from_inspection(function_reference: Callable) -> Function:
@@ -124,6 +158,7 @@ def _infer_from_inspection(function_reference: Callable) -> Function:
124158
125159
Returns:
126160
The inferred Function instance.
161+
127162
"""
128163
function_definition = Function(
129164
name=function_reference.__name__,
@@ -135,9 +170,31 @@ def _infer_from_inspection(function_reference: Callable) -> Function:
135170

136171
for name, parameter in inspected_parameters.items():
137172
parameter_type: str = python_type_to_json_schema_type(parameter.kind.name)
173+
enum_values: list[str] | None = None
174+
175+
if parameter_type == "null":
176+
if isinstance(parameter.annotation, EnumMeta):
177+
enum_values = list(
178+
parameter.annotation._value2member_map_.keys() # noqa: SLF001
179+
)
180+
parameter_type = FunctionInferrer._infer_list_item_type(enum_values)
181+
elif dataclasses.is_dataclass(parameter.annotation):
182+
parameter_type = JsonSchemaType.OBJECT.value
138183

139184
function_definition.parameters.append(
140-
Parameter(name=name, type=parameter_type)
185+
Parameter(name=name, type=parameter_type, enum=enum_values)
141186
)
142187

143188
return function_definition
189+
190+
@staticmethod
191+
def _infer_list_item_type(list_of_items: list[Any]) -> str:
192+
if len(list_of_items) == 0:
193+
return JsonSchemaType.NULL.value
194+
195+
# Check if all items are the same type.
196+
if len({type(item).__name__ for item in list_of_items}) == 1:
197+
item: Any = type(list_of_items[0]).__name__
198+
return python_type_to_json_schema_type(item)
199+
200+
return JsonSchemaType.ANY.value

openai_function_calling/helper_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def python_type_to_json_schema_type(python_type: str | None) -> str:
1313
1414
Returns:
1515
A JSON schema type value.
16+
1617
"""
1718
json_schema_type: str = JsonSchemaType.NULL.value
1819

openai_function_calling/parameter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(
5252
is not set.
5353
ValueError: If the 'array_item_type' argument is set, but the 'type' is not\
5454
'array'.
55+
5556
"""
5657
self.name: str = name
5758
self.type: str = type
@@ -67,6 +68,7 @@ def validate(self) -> None:
6768
Raises:
6869
ValueError: If 'array_item_type' is not set, but 'type' is array.
6970
ValueError: If 'array_item_type' is set, but 'type' is not array.
71+
7072
"""
7173
if self.type == JsonSchemaType.ARRAY and self.array_item_type is None:
7274
raise ValueError(
@@ -86,6 +88,7 @@ def to_json_schema(self) -> ParameterDict:
8688
8789
Raises:
8890
ValueError: If there are validation errors. See the validate method.
91+
8992
"""
9093
self.validate()
9194

@@ -111,6 +114,7 @@ def merge(self, other_parameter: Parameter) -> None:
111114
112115
Args:
113116
other_parameter: The other parameter instance to merge into the current.
117+
114118
"""
115119
if not isinstance(other_parameter, Parameter):
116120
raise TypeError("Cannot merge non-parameter type into parameter.")
@@ -138,6 +142,7 @@ def __eq__(self, other: object) -> bool:
138142
139143
Returns:
140144
If the other object is equal to the current instance.
145+
141146
"""
142147
if not isinstance(other, Parameter):
143148
return False

openai_function_calling/tool_helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def from_functions(functions: list[Function]) -> list[ChatCompletionToolParam]:
2626
2727
Returns:
2828
A list of OpenAI chat completion tool parameters.
29+
2930
"""
3031
json_schemas: list[FunctionDict] = [f.to_json_schema() for f in functions]
3132
tool_params: list[ChatCompletionToolParam] = [
@@ -45,6 +46,7 @@ def infer_from_function_refs(
4546
4647
Returns:
4748
A list of OpenAI chat completion tool parameters.
49+
4850
"""
4951
functions: list[Function] = [
5052
FunctionInferrer.infer_from_function_reference(f) for f in function_refs
@@ -60,5 +62,6 @@ def json_schema_to_tool_param(json_schema: FunctionDict) -> ChatCompletionToolPa
6062
6163
Returns:
6264
An OpenAI chat completion tool parameter.
65+
6366
"""
6467
return {"type": "function", "function": cast(FunctionDefinition, json_schema)}

0 commit comments

Comments
 (0)