diff --git a/aisuite/utils/tools.py b/aisuite/utils/tools.py index 11b4bd57..ff6e5d18 100644 --- a/aisuite/utils/tools.py +++ b/aisuite/utils/tools.py @@ -1,3 +1,5 @@ +import dataclasses +import datetime from typing import Callable, Dict, Any, Type, Optional from pydantic import BaseModel, create_model, Field, ValidationError import inspect @@ -275,7 +277,7 @@ def execute_tool(self, tool_calls) -> tuple[list, list]: { "role": "tool", "name": tool_name, - "content": json.dumps(result), + "content": json.dumps(result, cls=DefaultJSONEncoder), "tool_call_id": tool_call_id, } ) @@ -283,3 +285,30 @@ def execute_tool(self, tool_calls) -> tuple[list, list]: raise ValueError(f"Error in tool '{tool_name}' parameters: {e}") return results, messages + + +class DefaultJSONEncoder(json.JSONEncoder): + """ + Based on Flask https://flask.palletsprojects.com/en/stable/api/#flask.json.provider.DefaultJSONProvider + Provide JSON operations using Python’s built-in json library. Serializes the following additional data types: + + dataclasses: by calling `asdict` + datetime.dateime: by calling `datetime.datetime.isoformat()` + Markup and other objects: by checking for a __html__ method to return a str. + """ + + def default(self, obj: Any) -> Any: + # Handle dataclasses by converting them to dictionaries + if dataclasses.is_dataclass(obj): + return dataclasses.asdict(obj) + + # Handle datetime objects by using isoformat + if isinstance(obj, datetime.datetime): + return obj.isoformat() + + # Check for __html__ method which is used by Markupsafe a ton of other libraries and call it + if hasattr(obj, "__html__") and callable(getattr(obj, "__html__")): + return obj.__html__() + + # Let the base class handle anything else + return super().default(obj) diff --git a/tests/utils/test_tool_manager.py b/tests/utils/test_tool_manager.py index c78e6999..abb3a975 100644 --- a/tests/utils/test_tool_manager.py +++ b/tests/utils/test_tool_manager.py @@ -3,6 +3,9 @@ from typing import Dict from aisuite.utils.tools import Tools # Import your ToolManager class from enum import Enum +import dataclasses +from datetime import datetime +from markupsafe import Markup # Define a sample tool function and Pydantic model for testing @@ -21,6 +24,19 @@ class TemperatureParams(BaseModel): unit: str = "Celsius" +@dataclasses.dataclass +class WeatherReport: + temperature: int + temperature_unit: TemperatureUnit = TemperatureUnit.CELSIUS + last_forecast_datetime: datetime = datetime(2024, 1, 1, 11, 5, 5) + forecast: Markup = Markup( + """ +