diff --git a/aisuite/utils/tools.py b/aisuite/utils/tools.py index 11b4bd57..ff6e5d18 100644 --- a/aisuite/utils/tools.py +++ b/aisuite/utils/tools.py @@ -1,3 +1,5 @@ +import dataclasses +import datetime from typing import Callable, Dict, Any, Type, Optional from pydantic import BaseModel, create_model, Field, ValidationError import inspect @@ -275,7 +277,7 @@ def execute_tool(self, tool_calls) -> tuple[list, list]: { "role": "tool", "name": tool_name, - "content": json.dumps(result), + "content": json.dumps(result, cls=DefaultJSONEncoder), "tool_call_id": tool_call_id, } ) @@ -283,3 +285,30 @@ def execute_tool(self, tool_calls) -> tuple[list, list]: raise ValueError(f"Error in tool '{tool_name}' parameters: {e}") return results, messages + + +class DefaultJSONEncoder(json.JSONEncoder): + """ + Based on Flask https://flask.palletsprojects.com/en/stable/api/#flask.json.provider.DefaultJSONProvider + Provide JSON operations using Python’s built-in json library. Serializes the following additional data types: + + dataclasses: by calling `asdict` + datetime.dateime: by calling `datetime.datetime.isoformat()` + Markup and other objects: by checking for a __html__ method to return a str. + """ + + def default(self, obj: Any) -> Any: + # Handle dataclasses by converting them to dictionaries + if dataclasses.is_dataclass(obj): + return dataclasses.asdict(obj) + + # Handle datetime objects by using isoformat + if isinstance(obj, datetime.datetime): + return obj.isoformat() + + # Check for __html__ method which is used by Markupsafe a ton of other libraries and call it + if hasattr(obj, "__html__") and callable(getattr(obj, "__html__")): + return obj.__html__() + + # Let the base class handle anything else + return super().default(obj) diff --git a/tests/utils/test_tool_manager.py b/tests/utils/test_tool_manager.py index c78e6999..abb3a975 100644 --- a/tests/utils/test_tool_manager.py +++ b/tests/utils/test_tool_manager.py @@ -3,6 +3,9 @@ from typing import Dict from aisuite.utils.tools import Tools # Import your ToolManager class from enum import Enum +import dataclasses +from datetime import datetime +from markupsafe import Markup # Define a sample tool function and Pydantic model for testing @@ -21,6 +24,19 @@ class TemperatureParams(BaseModel): unit: str = "Celsius" +@dataclasses.dataclass +class WeatherReport: + temperature: int + temperature_unit: TemperatureUnit = TemperatureUnit.CELSIUS + last_forecast_datetime: datetime = datetime(2024, 1, 1, 11, 5, 5) + forecast: Markup = Markup( + """ +

Weather Report for Today

+
It's looking really good.
+ """ + ) + + def get_current_temperature(location: str, unit: str = "Celsius") -> Dict[str, str]: """Gets the current temperature for a specific location and unit.""" return {"location": location, "unit": unit, "temperature": "72"} @@ -38,6 +54,18 @@ def get_current_temperature_v2( return {"location": location, "unit": unit, "temperature": "72"} +class WeatherReportParams(BaseModel): + location: str + + +def get_weather_report_complex_type(location: str) -> WeatherReport: + """Gets a detailed weather report for a location""" + return WeatherReport( + 20, + TemperatureUnit.CELSIUS, + ) + + class TestToolManager(unittest.TestCase): def setUp(self): self.tool_manager = Tools() @@ -195,6 +223,35 @@ def test_add_tool_with_enum(self): tools == expected_tool_spec ), f"Expected {expected_tool_spec}, but got {tools}" + def test_execute_tool_with_complex_type(self): + """Test executing a registered tool with valid parameters.""" + self.tool_manager._add_tool( + get_weather_report_complex_type, WeatherReportParams + ) + tool_call = { + "id": "call_1", + "function": { + "name": "get_weather_report_complex_type", + "arguments": {"location": "San Francisco"}, + }, + } + result, result_message = self.tool_manager.execute_tool(tool_call) + + # Assuming result is returned as a list with a single dictionary + result_obj = result[0] if isinstance(result, list) else result + self.assertIsInstance(result_obj, WeatherReport) + + # Check that the result matches expected output + self.assertEqual( + result_obj.forecast, + Markup( + "\n

Weather Report for Today

\n
It's looking really good.
\n " + ), + ) + self.assertEqual( + result_obj.last_forecast_datetime, datetime(2024, 1, 1, 11, 5, 5) + ) + if __name__ == "__main__": unittest.main()