Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion aisuite/utils/tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import dataclasses
import datetime
from typing import Callable, Dict, Any, Type, Optional
from pydantic import BaseModel, create_model, Field, ValidationError
import inspect
Expand Down Expand Up @@ -275,11 +277,38 @@ def execute_tool(self, tool_calls) -> tuple[list, list]:
{
"role": "tool",
"name": tool_name,
"content": json.dumps(result),
"content": json.dumps(result, cls=DefaultJSONEncoder),
"tool_call_id": tool_call_id,
}
)
except ValidationError as e:
raise ValueError(f"Error in tool '{tool_name}' parameters: {e}")

return results, messages


class DefaultJSONEncoder(json.JSONEncoder):
"""
Based on Flask https://flask.palletsprojects.com/en/stable/api/#flask.json.provider.DefaultJSONProvider
Provide JSON operations using Python’s built-in json library. Serializes the following additional data types:

dataclasses: by calling `asdict`
datetime.dateime: by calling `datetime.datetime.isoformat()`
Markup and other objects: by checking for a __html__ method to return a str.
"""

def default(self, obj: Any) -> Any:
# Handle dataclasses by converting them to dictionaries
if dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj)

# Handle datetime objects by using isoformat
if isinstance(obj, datetime.datetime):
return obj.isoformat()

# Check for __html__ method which is used by Markupsafe a ton of other libraries and call it
if hasattr(obj, "__html__") and callable(getattr(obj, "__html__")):
return obj.__html__()

# Let the base class handle anything else
return super().default(obj)
57 changes: 57 additions & 0 deletions tests/utils/test_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from typing import Dict
from aisuite.utils.tools import Tools # Import your ToolManager class
from enum import Enum
import dataclasses
from datetime import datetime
from markupsafe import Markup


# Define a sample tool function and Pydantic model for testing
Expand All @@ -21,6 +24,19 @@ class TemperatureParams(BaseModel):
unit: str = "Celsius"


@dataclasses.dataclass
class WeatherReport:
temperature: int
temperature_unit: TemperatureUnit = TemperatureUnit.CELSIUS
last_forecast_datetime: datetime = datetime(2024, 1, 1, 11, 5, 5)
forecast: Markup = Markup(
"""
<h1>Weather Report for Today</h1>
<section>It's looking really good.</section>
"""
)


def get_current_temperature(location: str, unit: str = "Celsius") -> Dict[str, str]:
"""Gets the current temperature for a specific location and unit."""
return {"location": location, "unit": unit, "temperature": "72"}
Expand All @@ -38,6 +54,18 @@ def get_current_temperature_v2(
return {"location": location, "unit": unit, "temperature": "72"}


class WeatherReportParams(BaseModel):
location: str


def get_weather_report_complex_type(location: str) -> WeatherReport:
"""Gets a detailed weather report for a location"""
return WeatherReport(
20,
TemperatureUnit.CELSIUS,
)


class TestToolManager(unittest.TestCase):
def setUp(self):
self.tool_manager = Tools()
Expand Down Expand Up @@ -195,6 +223,35 @@ def test_add_tool_with_enum(self):
tools == expected_tool_spec
), f"Expected {expected_tool_spec}, but got {tools}"

def test_execute_tool_with_complex_type(self):
"""Test executing a registered tool with valid parameters."""
self.tool_manager._add_tool(
get_weather_report_complex_type, WeatherReportParams
)
tool_call = {
"id": "call_1",
"function": {
"name": "get_weather_report_complex_type",
"arguments": {"location": "San Francisco"},
},
}
result, result_message = self.tool_manager.execute_tool(tool_call)

# Assuming result is returned as a list with a single dictionary
result_obj = result[0] if isinstance(result, list) else result
self.assertIsInstance(result_obj, WeatherReport)

# Check that the result matches expected output
self.assertEqual(
result_obj.forecast,
Markup(
"\n <h1>Weather Report for Today</h1>\n <section>It's looking really good.</section>\n "
),
)
self.assertEqual(
result_obj.last_forecast_datetime, datetime(2024, 1, 1, 11, 5, 5)
)


if __name__ == "__main__":
unittest.main()