Skip to content

Commit a7d1afe

Browse files
committed
Fix tool parameters issue and add tests
1 parent 651f0f2 commit a7d1afe

File tree

2 files changed

+202
-0
lines changed

2 files changed

+202
-0
lines changed

src/neo4j_graphrag/tool.py

+30
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ class ArrayParameter(ToolParameter):
111111
min_items: Optional[int] = None
112112
max_items: Optional[int] = None
113113

114+
@model_validator(mode="before")
115+
@classmethod
116+
def _preprocess_items(cls, values):
117+
# Convert items from dict to ToolParameter if needed
118+
items = values.get("items")
119+
if isinstance(items, dict):
120+
values["items"] = ToolParameter.from_dict(items)
121+
return values
122+
114123
def model_dump_tool(self) -> Dict[str, Any]:
115124
result = super().model_dump_tool()
116125
result["items"] = self.items.model_dump_tool()
@@ -129,6 +138,9 @@ def validate_items(self) -> "ArrayParameter":
129138
raise ValueError(
130139
f"Items must be a ToolParameter or dict, got {type(self.items)}"
131140
)
141+
elif type(self.items) is ToolParameter:
142+
# Promote base ToolParameter to correct subclass if possible
143+
self.items = ToolParameter.from_dict(self.items.model_dump())
132144
return self
133145

134146

@@ -140,6 +152,21 @@ class ObjectParameter(ToolParameter):
140152
required_properties: List[str] = Field(default_factory=list)
141153
additional_properties: bool = True
142154

155+
@model_validator(mode="before")
156+
@classmethod
157+
def _preprocess_properties(cls, values):
158+
# Convert properties from dicts to ToolParameter if needed
159+
props = values.get("properties")
160+
if isinstance(props, dict):
161+
new_props = {}
162+
for k, v in props.items():
163+
if isinstance(v, dict):
164+
new_props[k] = ToolParameter.from_dict(v)
165+
else:
166+
new_props[k] = v
167+
values["properties"] = new_props
168+
return values
169+
143170
def model_dump_tool(self) -> Dict[str, Any]:
144171
properties_dict: Dict[str, Any] = {}
145172
for name, param in self.properties.items():
@@ -167,6 +194,9 @@ def validate_properties(self) -> "ObjectParameter":
167194
raise ValueError(
168195
f"Property {name} must be a ToolParameter or dict, got {type(param)}"
169196
)
197+
elif type(param) is ToolParameter:
198+
# Promote base ToolParameter to correct subclass if possible
199+
validated_properties[name] = ToolParameter.from_dict(param.model_dump())
170200
else:
171201
validated_properties[name] = param
172202
self.properties = validated_properties

tests/unit/tool/test_tool.py

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import pytest
2+
from neo4j_graphrag.tool import (
3+
StringParameter,
4+
IntegerParameter,
5+
NumberParameter,
6+
BooleanParameter,
7+
ArrayParameter,
8+
ObjectParameter,
9+
Tool,
10+
ToolParameter,
11+
ParameterType,
12+
)
13+
14+
15+
def test_string_parameter():
16+
param = StringParameter(description="A string", required=True, enum=["a", "b"])
17+
assert param.description == "A string"
18+
assert param.required is True
19+
assert param.enum == ["a", "b"]
20+
d = param.model_dump_tool()
21+
assert d["type"] == ParameterType.STRING
22+
assert d["enum"] == ["a", "b"]
23+
24+
25+
def test_integer_parameter():
26+
param = IntegerParameter(description="An int", minimum=0, maximum=10)
27+
d = param.model_dump_tool()
28+
assert d["type"] == ParameterType.INTEGER
29+
assert d["minimum"] == 0
30+
assert d["maximum"] == 10
31+
32+
33+
def test_number_parameter():
34+
param = NumberParameter(description="A number", minimum=1.5, maximum=3.5)
35+
d = param.model_dump_tool()
36+
assert d["type"] == ParameterType.NUMBER
37+
assert d["minimum"] == 1.5
38+
assert d["maximum"] == 3.5
39+
40+
41+
def test_boolean_parameter():
42+
param = BooleanParameter(description="A bool")
43+
d = param.model_dump_tool()
44+
assert d["type"] == ParameterType.BOOLEAN
45+
assert d["description"] == "A bool"
46+
47+
48+
def test_array_parameter_and_validation():
49+
arr_param = ArrayParameter(
50+
description="An array",
51+
items=StringParameter(description="str"),
52+
min_items=1,
53+
max_items=5,
54+
)
55+
d = arr_param.model_dump_tool()
56+
assert d["type"] == ParameterType.ARRAY
57+
assert d["items"]["type"] == ParameterType.STRING
58+
assert d["minItems"] == 1
59+
assert d["maxItems"] == 5
60+
61+
# Test items as dict
62+
arr_param2 = ArrayParameter(
63+
description="Arr with dict",
64+
items={"type": "string", "description": "str"},
65+
)
66+
arr_param2 = arr_param2.validate_items()
67+
assert isinstance(arr_param2.items, StringParameter)
68+
69+
# Test error on invalid items
70+
with pytest.raises(ValueError):
71+
ArrayParameter(description="bad", items=123).validate_items()
72+
73+
74+
def test_object_parameter_and_validation():
75+
obj_param = ObjectParameter(
76+
description="Obj",
77+
properties={
78+
"foo": StringParameter(description="foo"),
79+
"bar": IntegerParameter(description="bar"),
80+
},
81+
required_properties=["foo"],
82+
additional_properties=False,
83+
)
84+
d = obj_param.model_dump_tool()
85+
assert d["type"] == ParameterType.OBJECT
86+
assert d["properties"]["foo"]["type"] == ParameterType.STRING
87+
assert d["required"] == ["foo"]
88+
assert d["additionalProperties"] is False
89+
90+
# Test properties as dicts
91+
obj_param2 = ObjectParameter(
92+
description="Obj2",
93+
properties={
94+
"foo": {"type": "string", "description": "foo"},
95+
},
96+
)
97+
obj_param2 = obj_param2.validate_properties()
98+
assert isinstance(obj_param2.properties["foo"], StringParameter)
99+
100+
# Test error on invalid property
101+
with pytest.raises(ValueError):
102+
ObjectParameter(
103+
description="bad", properties={"foo": 123}
104+
).validate_properties()
105+
106+
107+
def test_from_dict():
108+
d = {"type": ParameterType.STRING, "description": "desc"}
109+
param = ToolParameter.from_dict(d)
110+
assert isinstance(param, StringParameter)
111+
assert param.description == "desc"
112+
113+
obj_dict = {
114+
"type": "object",
115+
"description": "obj",
116+
"properties": {"foo": {"type": "string", "description": "foo"}},
117+
}
118+
obj_param = ToolParameter.from_dict(obj_dict)
119+
assert isinstance(obj_param, ObjectParameter)
120+
assert isinstance(obj_param.properties["foo"], StringParameter)
121+
122+
arr_dict = {
123+
"type": "array",
124+
"description": "arr",
125+
"items": {"type": "integer", "description": "int"},
126+
}
127+
arr_param = ToolParameter.from_dict(arr_dict)
128+
assert isinstance(arr_param, ArrayParameter)
129+
assert isinstance(arr_param.items, IntegerParameter)
130+
131+
# Test unknown type
132+
with pytest.raises(ValueError):
133+
ToolParameter.from_dict({"type": "unknown", "description": "bad"})
134+
135+
# Test missing type
136+
with pytest.raises(ValueError):
137+
ToolParameter.from_dict({"description": "no type"})
138+
139+
140+
def test_tool_class():
141+
def dummy_func(query, **kwargs):
142+
return kwargs
143+
144+
params = ObjectParameter(
145+
description="params",
146+
properties={"a": StringParameter(description="a")},
147+
)
148+
tool = Tool(
149+
name="mytool",
150+
description="desc",
151+
parameters=params,
152+
execute_func=dummy_func,
153+
)
154+
assert tool.get_name() == "mytool"
155+
assert tool.get_description() == "desc"
156+
assert tool.get_parameters()["type"] == ParameterType.OBJECT
157+
assert tool.execute("query", a="b") == {"a": "b"}
158+
159+
# Test parameters as dict
160+
params_dict = {
161+
"type": "object",
162+
"description": "params",
163+
"properties": {"a": {"type": "string", "description": "a"}},
164+
}
165+
tool2 = Tool(
166+
name="mytool2",
167+
description="desc2",
168+
parameters=params_dict,
169+
execute_func=dummy_func,
170+
)
171+
assert tool2.get_parameters()["type"] == ParameterType.OBJECT
172+
assert tool2.execute("query", a="b") == {"a": "b"}

0 commit comments

Comments
 (0)