From 4dbdad79b946f7da89fa728c4951fd6aa667b018 Mon Sep 17 00:00:00 2001 From: Wouter Doppenberg Date: Fri, 15 Aug 2025 10:43:35 +0200 Subject: [PATCH 1/6] Added generics for state & forwarded props; poetry -> uv; linting Changelog: changed --- python-sdk/.gitignore | 1 + python-sdk/.pre-commit-config.yaml | 41 +++++++++++++++++++++++++++++ python-sdk/ag_ui/__init__.py | 0 python-sdk/ag_ui/core/__init__.py | 2 -- python-sdk/ag_ui/core/events.py | 18 ++++++------- python-sdk/ag_ui/core/types.py | 20 +++++++------- python-sdk/ag_ui/encoder/encoder.py | 2 +- python-sdk/pyproject.toml | 33 ++++++++++++++++------- python-sdk/tests/test_events.py | 5 ++-- 9 files changed, 87 insertions(+), 35 deletions(-) create mode 100644 python-sdk/.pre-commit-config.yaml create mode 100644 python-sdk/ag_ui/__init__.py diff --git a/python-sdk/.gitignore b/python-sdk/.gitignore index 5d9e5de22..69af64f9c 100644 --- a/python-sdk/.gitignore +++ b/python-sdk/.gitignore @@ -63,3 +63,4 @@ venv.bak/ # Project specific .DS_Store +/uv.lock diff --git a/python-sdk/.pre-commit-config.yaml b/python-sdk/.pre-commit-config.yaml new file mode 100644 index 000000000..5b1742428 --- /dev/null +++ b/python-sdk/.pre-commit-config.yaml @@ -0,0 +1,41 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + exclude_types: [ jupyter ] + - id: end-of-file-fixer + exclude_types: [ jupyter ] + - id: check-docstring-first + - id: debug-statements + - id: check-ast + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.11.8 + hooks: + - id: ruff + args: [ + --fix + ] + - id: ruff-format + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.15.0 + hooks: + - id: mypy + args: [ + --python-version=3.12, + --disallow-untyped-calls, + --disallow-untyped-defs, + --disallow-incomplete-defs, + --check-untyped-defs, + --no-implicit-optional, + --warn-redundant-casts, + --ignore-missing-imports, + ] + additional_dependencies: + - "types-pytz" + exclude_types: [ jupyter ] + exclude: "tests" + - repo: https://github.com/kynan/nbstripout + rev: 0.8.1 + hooks: + - id: nbstripout diff --git a/python-sdk/ag_ui/__init__.py b/python-sdk/ag_ui/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python-sdk/ag_ui/core/__init__.py b/python-sdk/ag_ui/core/__init__.py index 7e909ad5b..59410aec0 100644 --- a/python-sdk/ag_ui/core/__init__.py +++ b/python-sdk/ag_ui/core/__init__.py @@ -46,7 +46,6 @@ Context, Tool, RunAgentInput, - State ) __all__ = [ @@ -92,5 +91,4 @@ "Context", "Tool", "RunAgentInput", - "State" ] diff --git a/python-sdk/ag_ui/core/events.py b/python-sdk/ag_ui/core/events.py index 16dfdccca..a42c2888d 100644 --- a/python-sdk/ag_ui/core/events.py +++ b/python-sdk/ag_ui/core/events.py @@ -3,11 +3,11 @@ """ from enum import Enum -from typing import Annotated, Any, List, Literal, Optional, Union +from typing import Annotated, List, Literal, Optional, Union, Generic from pydantic import Field -from .types import ConfiguredBaseModel, Message, State +from .types import ConfiguredBaseModel, Message, AgentStateT, JSONValue class EventType(str, Enum): @@ -46,7 +46,7 @@ class BaseEvent(ConfiguredBaseModel): """ type: EventType timestamp: Optional[int] = None - raw_event: Optional[Any] = None + raw_event: Optional[JSONValue] = None class TextMessageStartEvent(BaseEvent): @@ -161,12 +161,12 @@ class ThinkingEndEvent(BaseEvent): """ type: Literal[EventType.THINKING_END] = EventType.THINKING_END # pyright: ignore[reportIncompatibleVariableOverride] -class StateSnapshotEvent(BaseEvent): +class StateSnapshotEvent(BaseEvent, Generic[AgentStateT]): """ Event containing a snapshot of the state. """ type: Literal[EventType.STATE_SNAPSHOT] = EventType.STATE_SNAPSHOT # pyright: ignore[reportIncompatibleVariableOverride] - snapshot: State + snapshot: AgentStateT class StateDeltaEvent(BaseEvent): @@ -174,7 +174,7 @@ class StateDeltaEvent(BaseEvent): Event containing a delta of the state. """ type: Literal[EventType.STATE_DELTA] = EventType.STATE_DELTA # pyright: ignore[reportIncompatibleVariableOverride] - delta: List[Any] # JSON Patch (RFC 6902) + delta: JSONValue # JSON Patch (RFC 6902) class MessagesSnapshotEvent(BaseEvent): @@ -190,7 +190,7 @@ class RawEvent(BaseEvent): Event containing a raw event. """ type: Literal[EventType.RAW] = EventType.RAW # pyright: ignore[reportIncompatibleVariableOverride] - event: Any + event: JSONValue source: Optional[str] = None @@ -200,7 +200,7 @@ class CustomEvent(BaseEvent): """ type: Literal[EventType.CUSTOM] = EventType.CUSTOM # pyright: ignore[reportIncompatibleVariableOverride] name: str - value: Any + value: JSONValue class RunStartedEvent(BaseEvent): @@ -219,7 +219,7 @@ class RunFinishedEvent(BaseEvent): type: Literal[EventType.RUN_FINISHED] = EventType.RUN_FINISHED # pyright: ignore[reportIncompatibleVariableOverride] thread_id: str run_id: str - result: Optional[Any] = None + result: JSONValue = None class RunErrorEvent(BaseEvent): diff --git a/python-sdk/ag_ui/core/types.py b/python-sdk/ag_ui/core/types.py index 47b7ae182..efa1c03df 100644 --- a/python-sdk/ag_ui/core/types.py +++ b/python-sdk/ag_ui/core/types.py @@ -2,11 +2,16 @@ This module contains the types for the Agent User Interaction Protocol Python SDK. """ -from typing import Annotated, Any, List, Literal, Optional, Union +from typing import Annotated, Any, List, Literal, Optional, Union, Generic +from typing_extensions import TypeVar from pydantic import BaseModel, ConfigDict, Field from pydantic.alias_generators import to_camel +JSONValue = Union[str, int, float, bool, None, dict[str, Any], list[Any]] +AgentStateT = TypeVar('AgentStateT', default=JSONValue, contravariant=True) +FwdPropsT = TypeVar('FwdPropsT', default=JSONValue, contravariant=True) + class ConfiguredBaseModel(BaseModel): """ @@ -51,7 +56,6 @@ class DeveloperMessage(BaseMessage): A developer message. """ role: Literal["developer"] = "developer" # pyright: ignore[reportIncompatibleVariableOverride] - content: str class SystemMessage(BaseMessage): @@ -59,7 +63,6 @@ class SystemMessage(BaseMessage): A system message. """ role: Literal["system"] = "system" # pyright: ignore[reportIncompatibleVariableOverride] - content: str class AssistantMessage(BaseMessage): @@ -75,7 +78,6 @@ class UserMessage(BaseMessage): A user message. """ role: Literal["user"] = "user" # pyright: ignore[reportIncompatibleVariableOverride] - content: str class ToolMessage(ConfiguredBaseModel): @@ -114,18 +116,14 @@ class Tool(ConfiguredBaseModel): parameters: Any # JSON Schema for the tool parameters -class RunAgentInput(ConfiguredBaseModel): +class RunAgentInput(ConfiguredBaseModel, Generic[AgentStateT, FwdPropsT]): """ Input for running an agent. """ thread_id: str run_id: str - state: Any + state: AgentStateT messages: List[Message] tools: List[Tool] context: List[Context] - forwarded_props: Any - - -# State can be any type -State = Any + forwarded_props: FwdPropsT diff --git a/python-sdk/ag_ui/encoder/encoder.py b/python-sdk/ag_ui/encoder/encoder.py index f840e3bb8..2cfe88392 100644 --- a/python-sdk/ag_ui/encoder/encoder.py +++ b/python-sdk/ag_ui/encoder/encoder.py @@ -10,7 +10,7 @@ class EventEncoder: """ Encodes Agent User Interaction events. """ - def __init__(self, accept: str = None): + def __init__(self, accept: str | None = None): pass def get_content_type(self) -> str: diff --git a/python-sdk/pyproject.toml b/python-sdk/pyproject.toml index 42c02bf45..15b8b9649 100644 --- a/python-sdk/pyproject.toml +++ b/python-sdk/pyproject.toml @@ -1,15 +1,30 @@ -[tool.poetry] +[project] name = "ag-ui-protocol" -version = "0.1.8" +version = "0.1.9" description = "" -authors = ["Markus Ecker "] +authors = [ + { name = "Markus Ecker", email = "markus.ecker@gmail.com" }, +] readme = "README.md" -packages = [{include = "ag_ui", from = "."}] -[tool.poetry.dependencies] -python = "^3.9" -pydantic = "^2.11.2" +requires-python = ">=3.9,<4.0" +dependencies = [ + "pydantic>=2.11.2,<3.0.0", +] +packages = [ + { include = "ag_ui", from = "ag_ui" } +] + +[tool.hatch.build.targets.wheel] +packages = ["ag_ui"] [build-system] -requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" +requires = ["hatchling"] +build-backend = "hatchling.build" + +[dependency-groups] +dev = [ + "mypy>=1.17.1", + "pyright>=1.1.403", + "ruff>=0.12.9", +] diff --git a/python-sdk/tests/test_events.py b/python-sdk/tests/test_events.py index c73a2537c..e413275c7 100644 --- a/python-sdk/tests/test_events.py +++ b/python-sdk/tests/test_events.py @@ -1,9 +1,8 @@ import unittest -import json from datetime import datetime -from pydantic import ValidationError, TypeAdapter +from pydantic import TypeAdapter -from ag_ui.core.types import Message, UserMessage, AssistantMessage, FunctionCall, ToolCall +from ag_ui.core.types import UserMessage, AssistantMessage, FunctionCall, ToolCall from ag_ui.core.events import ( EventType, BaseEvent, From df5f0d34aafd142c38b1ad3fca3d118330d4b0ff Mon Sep 17 00:00:00 2001 From: Wouter Doppenberg Date: Fri, 15 Aug 2025 10:45:39 +0200 Subject: [PATCH 2/6] pre-commit fixes --- python-sdk/.pre-commit-config.yaml | 19 +- python-sdk/ag_ui/core/__init__.py | 2 +- python-sdk/ag_ui/core/events.py | 50 +++++- python-sdk/ag_ui/core/types.py | 20 ++- python-sdk/ag_ui/encoder/encoder.py | 2 + python-sdk/tests/test_encoder.py | 100 ++++++----- python-sdk/tests/test_events.py | 263 ++++++++++++---------------- python-sdk/tests/test_types.py | 227 ++++++++++-------------- 8 files changed, 331 insertions(+), 352 deletions(-) diff --git a/python-sdk/.pre-commit-config.yaml b/python-sdk/.pre-commit-config.yaml index 5b1742428..f52033e00 100644 --- a/python-sdk/.pre-commit-config.yaml +++ b/python-sdk/.pre-commit-config.yaml @@ -1,14 +1,4 @@ repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 - hooks: - - id: trailing-whitespace - exclude_types: [ jupyter ] - - id: end-of-file-fixer - exclude_types: [ jupyter ] - - id: check-docstring-first - - id: debug-statements - - id: check-ast - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.11.8 hooks: @@ -16,11 +6,14 @@ repos: args: [ --fix ] + files: ^python-sdk/ - id: ruff-format + files: ^python-sdk/ - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.15.0 hooks: - id: mypy + files: ^python-sdk/ args: [ --python-version=3.12, --disallow-untyped-calls, @@ -34,8 +27,4 @@ repos: additional_dependencies: - "types-pytz" exclude_types: [ jupyter ] - exclude: "tests" - - repo: https://github.com/kynan/nbstripout - rev: 0.8.1 - hooks: - - id: nbstripout + exclude: "tests" \ No newline at end of file diff --git a/python-sdk/ag_ui/core/__init__.py b/python-sdk/ag_ui/core/__init__.py index 59410aec0..a545ee13d 100644 --- a/python-sdk/ag_ui/core/__init__.py +++ b/python-sdk/ag_ui/core/__init__.py @@ -29,7 +29,7 @@ RunErrorEvent, StepStartedEvent, StepFinishedEvent, - Event + Event, ) from ag_ui.core.types import ( diff --git a/python-sdk/ag_ui/core/events.py b/python-sdk/ag_ui/core/events.py index a42c2888d..b0b2bab40 100644 --- a/python-sdk/ag_ui/core/events.py +++ b/python-sdk/ag_ui/core/events.py @@ -14,6 +14,7 @@ class EventType(str, Enum): """ The type of event. """ + TEXT_MESSAGE_START = "TEXT_MESSAGE_START" TEXT_MESSAGE_CONTENT = "TEXT_MESSAGE_CONTENT" TEXT_MESSAGE_END = "TEXT_MESSAGE_END" @@ -44,6 +45,7 @@ class BaseEvent(ConfiguredBaseModel): """ Base event for all events in the Agent User Interaction Protocol. """ + type: EventType timestamp: Optional[int] = None raw_event: Optional[JSONValue] = None @@ -53,6 +55,7 @@ class TextMessageStartEvent(BaseEvent): """ Event indicating the start of a text message. """ + type: Literal[EventType.TEXT_MESSAGE_START] = EventType.TEXT_MESSAGE_START # pyright: ignore[reportIncompatibleVariableOverride] message_id: str role: Literal["assistant"] = "assistant" @@ -62,6 +65,7 @@ class TextMessageContentEvent(BaseEvent): """ Event containing a piece of text message content. """ + type: Literal[EventType.TEXT_MESSAGE_CONTENT] = EventType.TEXT_MESSAGE_CONTENT # pyright: ignore[reportIncompatibleVariableOverride] message_id: str delta: str = Field(min_length=1) @@ -71,41 +75,58 @@ class TextMessageEndEvent(BaseEvent): """ Event indicating the end of a text message. """ + type: Literal[EventType.TEXT_MESSAGE_END] = EventType.TEXT_MESSAGE_END # pyright: ignore[reportIncompatibleVariableOverride] message_id: str + class TextMessageChunkEvent(BaseEvent): """ Event containing a chunk of text message content. """ + type: Literal[EventType.TEXT_MESSAGE_CHUNK] = EventType.TEXT_MESSAGE_CHUNK # pyright: ignore[reportIncompatibleVariableOverride] message_id: Optional[str] = None role: Optional[Literal["assistant"]] = None delta: Optional[str] = None + class ThinkingTextMessageStartEvent(BaseEvent): """ Event indicating the start of a thinking text message. """ - type: Literal[EventType.THINKING_TEXT_MESSAGE_START] = EventType.THINKING_TEXT_MESSAGE_START # pyright: ignore[reportIncompatibleVariableOverride] + + type: Literal[EventType.THINKING_TEXT_MESSAGE_START] = ( + EventType.THINKING_TEXT_MESSAGE_START + ) # pyright: ignore[reportIncompatibleVariableOverride] + class ThinkingTextMessageContentEvent(BaseEvent): """ Event indicating a piece of a thinking text message. """ - type: Literal[EventType.THINKING_TEXT_MESSAGE_CONTENT] = EventType.THINKING_TEXT_MESSAGE_CONTENT # pyright: ignore[reportIncompatibleVariableOverride] + + type: Literal[EventType.THINKING_TEXT_MESSAGE_CONTENT] = ( + EventType.THINKING_TEXT_MESSAGE_CONTENT + ) # pyright: ignore[reportIncompatibleVariableOverride] delta: str = Field(min_length=1) + class ThinkingTextMessageEndEvent(BaseEvent): """ Event indicating the end of a thinking text message. """ - type: Literal[EventType.THINKING_TEXT_MESSAGE_END] = EventType.THINKING_TEXT_MESSAGE_END # pyright: ignore[reportIncompatibleVariableOverride] + + type: Literal[EventType.THINKING_TEXT_MESSAGE_END] = ( + EventType.THINKING_TEXT_MESSAGE_END + ) # pyright: ignore[reportIncompatibleVariableOverride] + class ToolCallStartEvent(BaseEvent): """ Event indicating the start of a tool call. """ + type: Literal[EventType.TOOL_CALL_START] = EventType.TOOL_CALL_START # pyright: ignore[reportIncompatibleVariableOverride] tool_call_id: str tool_call_name: str @@ -116,6 +137,7 @@ class ToolCallArgsEvent(BaseEvent): """ Event containing tool call arguments. """ + type: Literal[EventType.TOOL_CALL_ARGS] = EventType.TOOL_CALL_ARGS # pyright: ignore[reportIncompatibleVariableOverride] tool_call_id: str delta: str @@ -125,46 +147,57 @@ class ToolCallEndEvent(BaseEvent): """ Event indicating the end of a tool call. """ + type: Literal[EventType.TOOL_CALL_END] = EventType.TOOL_CALL_END # pyright: ignore[reportIncompatibleVariableOverride] tool_call_id: str + class ToolCallChunkEvent(BaseEvent): """ Event containing a chunk of tool call content. """ + type: Literal[EventType.TOOL_CALL_CHUNK] = EventType.TOOL_CALL_CHUNK # pyright: ignore[reportIncompatibleVariableOverride] tool_call_id: Optional[str] = None tool_call_name: Optional[str] = None parent_message_id: Optional[str] = None delta: Optional[str] = None + class ToolCallResultEvent(BaseEvent): """ Event containing the result of a tool call. """ + message_id: str type: Literal[EventType.TOOL_CALL_RESULT] = EventType.TOOL_CALL_RESULT # pyright: ignore[reportIncompatibleVariableOverride] tool_call_id: str content: str role: Optional[Literal["tool"]] = None + class ThinkingStartEvent(BaseEvent): """ Event indicating the start of a thinking step event. """ + type: Literal[EventType.THINKING_START] = EventType.THINKING_START # pyright: ignore[reportIncompatibleVariableOverride] title: Optional[str] = None + class ThinkingEndEvent(BaseEvent): """ Event indicating the end of a thinking step event. """ + type: Literal[EventType.THINKING_END] = EventType.THINKING_END # pyright: ignore[reportIncompatibleVariableOverride] + class StateSnapshotEvent(BaseEvent, Generic[AgentStateT]): """ Event containing a snapshot of the state. """ + type: Literal[EventType.STATE_SNAPSHOT] = EventType.STATE_SNAPSHOT # pyright: ignore[reportIncompatibleVariableOverride] snapshot: AgentStateT @@ -173,6 +206,7 @@ class StateDeltaEvent(BaseEvent): """ Event containing a delta of the state. """ + type: Literal[EventType.STATE_DELTA] = EventType.STATE_DELTA # pyright: ignore[reportIncompatibleVariableOverride] delta: JSONValue # JSON Patch (RFC 6902) @@ -181,6 +215,7 @@ class MessagesSnapshotEvent(BaseEvent): """ Event containing a snapshot of the messages. """ + type: Literal[EventType.MESSAGES_SNAPSHOT] = EventType.MESSAGES_SNAPSHOT # pyright: ignore[reportIncompatibleVariableOverride] messages: List[Message] @@ -189,6 +224,7 @@ class RawEvent(BaseEvent): """ Event containing a raw event. """ + type: Literal[EventType.RAW] = EventType.RAW # pyright: ignore[reportIncompatibleVariableOverride] event: JSONValue source: Optional[str] = None @@ -198,6 +234,7 @@ class CustomEvent(BaseEvent): """ Event containing a custom event. """ + type: Literal[EventType.CUSTOM] = EventType.CUSTOM # pyright: ignore[reportIncompatibleVariableOverride] name: str value: JSONValue @@ -207,6 +244,7 @@ class RunStartedEvent(BaseEvent): """ Event indicating that a run has started. """ + type: Literal[EventType.RUN_STARTED] = EventType.RUN_STARTED # pyright: ignore[reportIncompatibleVariableOverride] thread_id: str run_id: str @@ -216,6 +254,7 @@ class RunFinishedEvent(BaseEvent): """ Event indicating that a run has finished. """ + type: Literal[EventType.RUN_FINISHED] = EventType.RUN_FINISHED # pyright: ignore[reportIncompatibleVariableOverride] thread_id: str run_id: str @@ -226,6 +265,7 @@ class RunErrorEvent(BaseEvent): """ Event indicating that a run has encountered an error. """ + type: Literal[EventType.RUN_ERROR] = EventType.RUN_ERROR # pyright: ignore[reportIncompatibleVariableOverride] message: str code: Optional[str] = None @@ -235,6 +275,7 @@ class StepStartedEvent(BaseEvent): """ Event indicating that a step has started. """ + type: Literal[EventType.STEP_STARTED] = EventType.STEP_STARTED # pyright: ignore[reportIncompatibleVariableOverride] step_name: str @@ -243,6 +284,7 @@ class StepFinishedEvent(BaseEvent): """ Event indicating that a step has finished. """ + type: Literal[EventType.STEP_FINISHED] = EventType.STEP_FINISHED # pyright: ignore[reportIncompatibleVariableOverride] step_name: str @@ -269,5 +311,5 @@ class StepFinishedEvent(BaseEvent): StepStartedEvent, StepFinishedEvent, ], - Field(discriminator="type") + Field(discriminator="type"), ] diff --git a/python-sdk/ag_ui/core/types.py b/python-sdk/ag_ui/core/types.py index efa1c03df..824f0bd70 100644 --- a/python-sdk/ag_ui/core/types.py +++ b/python-sdk/ag_ui/core/types.py @@ -9,14 +9,15 @@ from pydantic.alias_generators import to_camel JSONValue = Union[str, int, float, bool, None, dict[str, Any], list[Any]] -AgentStateT = TypeVar('AgentStateT', default=JSONValue, contravariant=True) -FwdPropsT = TypeVar('FwdPropsT', default=JSONValue, contravariant=True) +AgentStateT = TypeVar("AgentStateT", default=JSONValue, contravariant=True) +FwdPropsT = TypeVar("FwdPropsT", default=JSONValue, contravariant=True) class ConfiguredBaseModel(BaseModel): """ A configurable base model. """ + model_config = ConfigDict( extra="forbid", alias_generator=to_camel, @@ -28,6 +29,7 @@ class FunctionCall(ConfiguredBaseModel): """ Name and arguments of a function call. """ + name: str arguments: str @@ -36,6 +38,7 @@ class ToolCall(ConfiguredBaseModel): """ A tool call, modelled after OpenAI tool calls. """ + id: str type: Literal["function"] = "function" # pyright: ignore[reportIncompatibleVariableOverride] function: FunctionCall @@ -45,6 +48,7 @@ class BaseMessage(ConfiguredBaseModel): """ A base message, modelled after OpenAI messages. """ + id: str role: str content: Optional[str] = None @@ -55,6 +59,7 @@ class DeveloperMessage(BaseMessage): """ A developer message. """ + role: Literal["developer"] = "developer" # pyright: ignore[reportIncompatibleVariableOverride] @@ -62,6 +67,7 @@ class SystemMessage(BaseMessage): """ A system message. """ + role: Literal["system"] = "system" # pyright: ignore[reportIncompatibleVariableOverride] @@ -69,6 +75,7 @@ class AssistantMessage(BaseMessage): """ An assistant message. """ + role: Literal["assistant"] = "assistant" # pyright: ignore[reportIncompatibleVariableOverride] tool_calls: Optional[List[ToolCall]] = None @@ -77,13 +84,15 @@ class UserMessage(BaseMessage): """ A user message. """ - role: Literal["user"] = "user" # pyright: ignore[reportIncompatibleVariableOverride] + + role: Literal["user"] = "user" # pyright: ignore[reportIncompatibleVariableOverride] class ToolMessage(ConfiguredBaseModel): """ A tool result message. """ + id: str role: Literal["tool"] = "tool" content: str @@ -93,7 +102,7 @@ class ToolMessage(ConfiguredBaseModel): Message = Annotated[ Union[DeveloperMessage, SystemMessage, AssistantMessage, UserMessage, ToolMessage], - Field(discriminator="role") + Field(discriminator="role"), ] Role = Literal["developer", "system", "assistant", "user", "tool"] @@ -103,6 +112,7 @@ class Context(ConfiguredBaseModel): """ Additional context for the agent. """ + description: str value: str @@ -111,6 +121,7 @@ class Tool(ConfiguredBaseModel): """ A tool definition. """ + name: str description: str parameters: Any # JSON Schema for the tool parameters @@ -120,6 +131,7 @@ class RunAgentInput(ConfiguredBaseModel, Generic[AgentStateT, FwdPropsT]): """ Input for running an agent. """ + thread_id: str run_id: str state: AgentStateT diff --git a/python-sdk/ag_ui/encoder/encoder.py b/python-sdk/ag_ui/encoder/encoder.py index 2cfe88392..a30957568 100644 --- a/python-sdk/ag_ui/encoder/encoder.py +++ b/python-sdk/ag_ui/encoder/encoder.py @@ -6,10 +6,12 @@ AGUI_MEDIA_TYPE = "application/vnd.ag-ui.event+proto" + class EventEncoder: """ Encodes Agent User Interaction events. """ + def __init__(self, accept: str | None = None): pass diff --git a/python-sdk/tests/test_encoder.py b/python-sdk/tests/test_encoder.py index 2d466c5a4..4c2766888 100644 --- a/python-sdk/tests/test_encoder.py +++ b/python-sdk/tests/test_encoder.py @@ -3,7 +3,12 @@ from datetime import datetime from ag_ui.encoder.encoder import EventEncoder, AGUI_MEDIA_TYPE -from ag_ui.core.events import BaseEvent, EventType, TextMessageContentEvent, ToolCallStartEvent +from ag_ui.core.events import ( + BaseEvent, + EventType, + TextMessageContentEvent, + ToolCallStartEvent, +) class TestEventEncoder(unittest.TestCase): @@ -23,15 +28,17 @@ def test_encode_method(self): # Create a test event timestamp = int(datetime.now().timestamp() * 1000) event = BaseEvent(type=EventType.RAW, timestamp=timestamp) - + # Create encoder and encode event encoder = EventEncoder() encoded = encoder.encode(event) - + # The encode method calls encode_sse, so the result should be in SSE format - expected = f"data: {event.model_dump_json(by_alias=True, exclude_none=True)}\n\n" + expected = ( + f"data: {event.model_dump_json(by_alias=True, exclude_none=True)}\n\n" + ) self.assertEqual(encoded, expected) - + # Verify that camelCase is used in the encoded output self.assertIn('"type":', encoded) self.assertIn('"timestamp":', encoded) @@ -43,29 +50,29 @@ def test_encode_sse_method(self): """Test the encode_sse method""" # Create a test event with specific data event = TextMessageContentEvent( - message_id="msg_123", - delta="Hello, world!", - timestamp=1648214400000 + message_id="msg_123", delta="Hello, world!", timestamp=1648214400000 ) - + # Create encoder and encode event to SSE encoder = EventEncoder() encoded_sse = encoder._encode_sse(event) - + # Verify the format is correct for SSE (data: [json]\n\n) self.assertTrue(encoded_sse.startswith("data: ")) self.assertTrue(encoded_sse.endswith("\n\n")) - + # Extract and verify the JSON content json_content = encoded_sse[6:-2] # Remove "data: " prefix and "\n\n" suffix decoded = json.loads(json_content) - + # Check that all fields were properly encoded self.assertEqual(decoded["type"], "TEXT_MESSAGE_CONTENT") - self.assertEqual(decoded["messageId"], "msg_123") # Check snake_case converted to camelCase + self.assertEqual( + decoded["messageId"], "msg_123" + ) # Check snake_case converted to camelCase self.assertEqual(decoded["delta"], "Hello, world!") self.assertEqual(decoded["timestamp"], 1648214400000) - + # Verify that snake_case has been converted to camelCase self.assertIn("messageId", decoded) # camelCase key exists self.assertNotIn("message_id", decoded) # snake_case key doesn't exist @@ -74,77 +81,79 @@ def test_encode_with_different_event_types(self): """Test encoding different types of events""" # Create encoder encoder = EventEncoder() - + # Test with a basic BaseEvent base_event = BaseEvent(type=EventType.RAW, timestamp=1648214400000) encoded_base = encoder.encode(base_event) self.assertIn('"type":"RAW"', encoded_base) - + # Test with a more complex event content_event = TextMessageContentEvent( message_id="msg_456", delta="Testing different events", - timestamp=1648214400000 + timestamp=1648214400000, ) encoded_content = encoder.encode(content_event) - + # Verify correct encoding and camelCase conversion self.assertIn('"type":"TEXT_MESSAGE_CONTENT"', encoded_content) - self.assertIn('"messageId":"msg_456"', encoded_content) # Check snake_case converted to camelCase + self.assertIn( + '"messageId":"msg_456"', encoded_content + ) # Check snake_case converted to camelCase self.assertIn('"delta":"Testing different events"', encoded_content) - + # Extract JSON and verify camelCase conversion json_content = encoded_content.split("data: ")[1].rstrip("\n\n") decoded = json.loads(json_content) - + # Verify messageId is camelCase (not message_id) self.assertIn("messageId", decoded) self.assertNotIn("message_id", decoded) - + def test_null_value_exclusion(self): """Test that fields with None values are excluded from the JSON output""" # Create an event with some fields set to None event = BaseEvent( type=EventType.RAW, timestamp=1648214400000, - raw_event=None # Explicitly set to None + raw_event=None, # Explicitly set to None ) - + # Create encoder and encode event encoder = EventEncoder() encoded = encoder.encode(event) - + # Extract JSON json_content = encoded.split("data: ")[1].rstrip("\n\n") decoded = json.loads(json_content) - + # Verify fields that are present self.assertIn("type", decoded) self.assertIn("timestamp", decoded) - + # Verify null fields are excluded self.assertNotIn("rawEvent", decoded) - + # Test with another event that has optional fields # Create event with some optional fields set to None event_with_optional = ToolCallStartEvent( tool_call_id="call_123", tool_call_name="test_tool", parent_message_id=None, # Optional field explicitly set to None - timestamp=1648214400000 + timestamp=1648214400000, ) - + encoded_optional = encoder.encode(event_with_optional) json_content_optional = encoded_optional.split("data: ")[1].rstrip("\n\n") decoded_optional = json.loads(json_content_optional) - + # Required fields should be present self.assertIn("toolCallId", decoded_optional) self.assertIn("toolCallName", decoded_optional) - + # Optional field with None value should be excluded self.assertNotIn("parentMessageId", decoded_optional) - + def test_round_trip_serialization(self): """Test that events can be serialized to JSON with camelCase and deserialized back correctly""" # Create a complex event with multiple fields @@ -152,12 +161,12 @@ def test_round_trip_serialization(self): tool_call_id="call_abc123", tool_call_name="search_tool", parent_message_id="msg_parent_456", - timestamp=1648214400000 + timestamp=1648214400000, ) - + # Serialize to JSON with camelCase fields json_str = original_event.model_dump_json(by_alias=True) - + # Verify JSON uses camelCase json_data = json.loads(json_str) self.assertIn("toolCallId", json_data) @@ -166,19 +175,20 @@ def test_round_trip_serialization(self): self.assertNotIn("tool_call_id", json_data) self.assertNotIn("tool_call_name", json_data) self.assertNotIn("parent_message_id", json_data) - + # Deserialize back to an event deserialized_event = ToolCallStartEvent.model_validate_json(json_str) - + # Verify the deserialized event is equivalent to the original self.assertEqual(deserialized_event.type, original_event.type) self.assertEqual(deserialized_event.tool_call_id, original_event.tool_call_id) - self.assertEqual(deserialized_event.tool_call_name, original_event.tool_call_name) - self.assertEqual(deserialized_event.parent_message_id, original_event.parent_message_id) - self.assertEqual(deserialized_event.timestamp, original_event.timestamp) - - # Verify complete equality using model_dump self.assertEqual( - original_event.model_dump(), - deserialized_event.model_dump() + deserialized_event.tool_call_name, original_event.tool_call_name ) + self.assertEqual( + deserialized_event.parent_message_id, original_event.parent_message_id + ) + self.assertEqual(deserialized_event.timestamp, original_event.timestamp) + + # Verify complete equality using model_dump + self.assertEqual(original_event.model_dump(), deserialized_event.model_dump()) diff --git a/python-sdk/tests/test_events.py b/python-sdk/tests/test_events.py index e413275c7..1745f61f9 100644 --- a/python-sdk/tests/test_events.py +++ b/python-sdk/tests/test_events.py @@ -22,7 +22,7 @@ RunErrorEvent, StepStartedEvent, StepFinishedEvent, - Event + Event, ) @@ -47,13 +47,10 @@ def test_base_event_creation(self): def test_text_message_start(self): """Test creating and serializing a TextMessageStartEvent event""" - event = TextMessageStartEvent( - message_id="msg_123", - timestamp=1648214400000 - ) + event = TextMessageStartEvent(message_id="msg_123", timestamp=1648214400000) self.assertEqual(event.message_id, "msg_123") self.assertEqual(event.role, "assistant") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "TEXT_MESSAGE_START") @@ -63,13 +60,11 @@ def test_text_message_start(self): def test_text_message_content(self): """Test creating and serializing a TextMessageContentEvent event""" event = TextMessageContentEvent( - message_id="msg_123", - delta="Hello, world!", - timestamp=1648214400000 + message_id="msg_123", delta="Hello, world!", timestamp=1648214400000 ) self.assertEqual(event.message_id, "msg_123") self.assertEqual(event.delta, "Hello, world!") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "TEXT_MESSAGE_CONTENT") @@ -78,12 +73,9 @@ def test_text_message_content(self): def test_text_message_end(self): """Test creating and serializing a TextMessageEndEvent event""" - event = TextMessageEndEvent( - message_id="msg_123", - timestamp=1648214400000 - ) + event = TextMessageEndEvent(message_id="msg_123", timestamp=1648214400000) self.assertEqual(event.message_id, "msg_123") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "TEXT_MESSAGE_END") @@ -95,12 +87,12 @@ def test_tool_call_start(self): tool_call_id="call_123", tool_call_name="get_weather", parent_message_id="msg_456", - timestamp=1648214400000 + timestamp=1648214400000, ) self.assertEqual(event.tool_call_id, "call_123") self.assertEqual(event.tool_call_name, "get_weather") self.assertEqual(event.parent_message_id, "msg_456") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "TOOL_CALL_START") @@ -113,11 +105,11 @@ def test_tool_call_args(self): event = ToolCallArgsEvent( tool_call_id="call_123", delta='{"location": "New York"}', - timestamp=1648214400000 + timestamp=1648214400000, ) self.assertEqual(event.tool_call_id, "call_123") self.assertEqual(event.delta, '{"location": "New York"}') - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "TOOL_CALL_ARGS") @@ -126,12 +118,9 @@ def test_tool_call_args(self): def test_tool_call_end(self): """Test creating and serializing a ToolCallEndEvent event""" - event = ToolCallEndEvent( - tool_call_id="call_123", - timestamp=1648214400000 - ) + event = ToolCallEndEvent(tool_call_id="call_123", timestamp=1648214400000) self.assertEqual(event.tool_call_id, "call_123") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "TOOL_CALL_END") @@ -140,12 +129,9 @@ def test_tool_call_end(self): def test_state_snapshot(self): """Test creating and serializing a StateSnapshotEvent event""" state = {"conversation_state": "active", "user_info": {"name": "John"}} - event = StateSnapshotEvent( - snapshot=state, - timestamp=1648214400000 - ) + event = StateSnapshotEvent(snapshot=state, timestamp=1648214400000) self.assertEqual(event.snapshot, state) - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "STATE_SNAPSHOT") @@ -157,14 +143,11 @@ def test_state_delta(self): # JSON Patch format delta = [ {"op": "replace", "path": "/conversation_state", "value": "paused"}, - {"op": "add", "path": "/user_info/age", "value": 30} + {"op": "add", "path": "/user_info/age", "value": 30}, ] - event = StateDeltaEvent( - delta=delta, - timestamp=1648214400000 - ) + event = StateDeltaEvent(delta=delta, timestamp=1648214400000) self.assertEqual(event.delta, delta) - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "STATE_DELTA") @@ -176,42 +159,40 @@ def test_messages_snapshot(self): """Test creating and serializing a MessagesSnapshotEvent event""" messages = [ UserMessage(id="user_1", content="Hello"), - AssistantMessage(id="asst_1", content="Hi there", tool_calls=[ - ToolCall( - id="call_1", - function=FunctionCall( - name="get_weather", - arguments='{"location": "New York"}' + AssistantMessage( + id="asst_1", + content="Hi there", + tool_calls=[ + ToolCall( + id="call_1", + function=FunctionCall( + name="get_weather", arguments='{"location": "New York"}' + ), ) - ) - ]) + ], + ), ] - event = MessagesSnapshotEvent( - messages=messages, - timestamp=1648214400000 - ) + event = MessagesSnapshotEvent(messages=messages, timestamp=1648214400000) self.assertEqual(len(event.messages), 2) self.assertEqual(event.messages[0].id, "user_1") self.assertEqual(event.messages[1].tool_calls[0].function.name, "get_weather") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "MESSAGES_SNAPSHOT") self.assertEqual(len(serialized["messages"]), 2) self.assertEqual(serialized["messages"][0]["role"], "user") - self.assertEqual(serialized["messages"][1]["toolCalls"][0]["function"]["name"], "get_weather") + self.assertEqual( + serialized["messages"][1]["toolCalls"][0]["function"]["name"], "get_weather" + ) def test_raw_event(self): """Test creating and serializing a RawEvent""" raw_data = {"origin": "server", "data": {"key": "value"}} - event = RawEvent( - event=raw_data, - source="api", - timestamp=1648214400000 - ) + event = RawEvent(event=raw_data, source="api", timestamp=1648214400000) self.assertEqual(event.event, raw_data) self.assertEqual(event.source, "api") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "RAW") @@ -223,11 +204,11 @@ def test_custom_event(self): event = CustomEvent( name="user_action", value={"action": "click", "element": "button"}, - timestamp=1648214400000 + timestamp=1648214400000, ) self.assertEqual(event.name, "user_action") self.assertEqual(event.value["action"], "click") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "CUSTOM") @@ -237,13 +218,11 @@ def test_custom_event(self): def test_run_started(self): """Test creating and serializing a RunStartedEvent event""" event = RunStartedEvent( - thread_id="thread_123", - run_id="run_456", - timestamp=1648214400000 + thread_id="thread_123", run_id="run_456", timestamp=1648214400000 ) self.assertEqual(event.thread_id, "thread_123") self.assertEqual(event.run_id, "run_456") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "RUN_STARTED") @@ -253,13 +232,11 @@ def test_run_started(self): def test_run_finished(self): """Test creating and serializing a RunFinishedEvent event""" event = RunFinishedEvent( - thread_id="thread_123", - run_id="run_456", - timestamp=1648214400000 + thread_id="thread_123", run_id="run_456", timestamp=1648214400000 ) self.assertEqual(event.thread_id, "thread_123") self.assertEqual(event.run_id, "run_456") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "RUN_FINISHED") @@ -271,11 +248,11 @@ def test_run_error(self): event = RunErrorEvent( message="An error occurred during execution", code="ERROR_001", - timestamp=1648214400000 + timestamp=1648214400000, ) self.assertEqual(event.message, "An error occurred during execution") self.assertEqual(event.code, "ERROR_001") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "RUN_ERROR") @@ -284,12 +261,9 @@ def test_run_error(self): def test_step_started(self): """Test creating and serializing a StepStartedEvent event""" - event = StepStartedEvent( - step_name="process_data", - timestamp=1648214400000 - ) + event = StepStartedEvent(step_name="process_data", timestamp=1648214400000) self.assertEqual(event.step_name, "process_data") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "STEP_STARTED") @@ -297,12 +271,9 @@ def test_step_started(self): def test_step_finished(self): """Test creating and serializing a StepFinishedEvent event""" - event = StepFinishedEvent( - step_name="process_data", - timestamp=1648214400000 - ) + event = StepFinishedEvent(step_name="process_data", timestamp=1648214400000) self.assertEqual(event.step_name, "process_data") - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "STEP_FINISHED") @@ -311,48 +282,48 @@ def test_step_finished(self): def test_event_union_deserialization(self): """Test the Event union type correctly deserializes different event types""" event_adapter = TypeAdapter(Event) - + # Test different event types event_data = [ { "type": "TEXT_MESSAGE_START", "messageId": "msg_start", "role": "assistant", - "timestamp": 1648214400000 + "timestamp": 1648214400000, }, { "type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_content", "delta": "Hello!", - "timestamp": 1648214400000 + "timestamp": 1648214400000, }, { "type": "TOOL_CALL_START", "toolCallId": "call_start", "toolCallName": "get_info", - "timestamp": 1648214400000 + "timestamp": 1648214400000, }, { "type": "STATE_SNAPSHOT", "snapshot": {"status": "active"}, - "timestamp": 1648214400000 + "timestamp": 1648214400000, }, { "type": "RUN_ERROR", "message": "Error occurred", "code": "ERR_001", - "timestamp": 1648214400000 - } + "timestamp": 1648214400000, + }, ] - + expected_types = [ TextMessageStartEvent, TextMessageContentEvent, ToolCallStartEvent, StateSnapshotEvent, - RunErrorEvent + RunErrorEvent, ] - + for data, expected_type in zip(event_data, expected_types): event = event_adapter.validate_python(data) self.assertIsInstance(event, expected_type) @@ -365,7 +336,7 @@ def test_validation_constraints(self): with self.assertRaises(ValueError): TextMessageContentEvent( message_id="msg_123", - delta="" # Empty delta, should fail + delta="", # Empty delta, should fail ) def test_serialization_round_trip(self): @@ -375,57 +346,54 @@ def test_serialization_round_trip(self): TextMessageStartEvent( message_id="msg_123", ), - TextMessageContentEvent( - message_id="msg_123", - delta="Hello, world!" - ), - ToolCallStartEvent( - tool_call_id="call_123", - tool_call_name="get_weather" - ), - StateSnapshotEvent( - snapshot={"status": "active"} - ), - MessagesSnapshotEvent( - messages=[ - UserMessage(id="user_1", content="Hello") - ] - ), - RunStartedEvent( - thread_id="thread_123", - run_id="run_456" - ) + TextMessageContentEvent(message_id="msg_123", delta="Hello, world!"), + ToolCallStartEvent(tool_call_id="call_123", tool_call_name="get_weather"), + StateSnapshotEvent(snapshot={"status": "active"}), + MessagesSnapshotEvent(messages=[UserMessage(id="user_1", content="Hello")]), + RunStartedEvent(thread_id="thread_123", run_id="run_456"), ] - + event_adapter = TypeAdapter(Event) - + # Test round trip for each event for original_event in events: # Serialize to JSON json_str = original_event.model_dump_json(by_alias=True) - + # Deserialize back to object deserialized_event = event_adapter.validate_json(json_str) - + # Verify the types match self.assertIsInstance(deserialized_event, type(original_event)) self.assertEqual(deserialized_event.type, original_event.type) - + # Verify event-specific fields if isinstance(original_event, TextMessageStartEvent): - self.assertEqual(deserialized_event.message_id, original_event.message_id) + self.assertEqual( + deserialized_event.message_id, original_event.message_id + ) self.assertEqual(deserialized_event.role, original_event.role) elif isinstance(original_event, TextMessageContentEvent): - self.assertEqual(deserialized_event.message_id, original_event.message_id) + self.assertEqual( + deserialized_event.message_id, original_event.message_id + ) self.assertEqual(deserialized_event.delta, original_event.delta) elif isinstance(original_event, ToolCallStartEvent): - self.assertEqual(deserialized_event.tool_call_id, original_event.tool_call_id) - self.assertEqual(deserialized_event.tool_call_name, original_event.tool_call_name) + self.assertEqual( + deserialized_event.tool_call_id, original_event.tool_call_id + ) + self.assertEqual( + deserialized_event.tool_call_name, original_event.tool_call_name + ) elif isinstance(original_event, StateSnapshotEvent): self.assertEqual(deserialized_event.snapshot, original_event.snapshot) elif isinstance(original_event, MessagesSnapshotEvent): - self.assertEqual(len(deserialized_event.messages), len(original_event.messages)) - self.assertEqual(deserialized_event.messages[0].id, original_event.messages[0].id) + self.assertEqual( + len(deserialized_event.messages), len(original_event.messages) + ) + self.assertEqual( + deserialized_event.messages[0].id, original_event.messages[0].id + ) elif isinstance(original_event, RunStartedEvent): self.assertEqual(deserialized_event.thread_id, original_event.thread_id) self.assertEqual(deserialized_event.run_id, original_event.run_id) @@ -434,16 +402,16 @@ def test_raw_event_with_null_source(self): """Test RawEvent with null source""" event = RawEvent( event={"data": "test"}, - source=None # Explicit None + source=None, # Explicit None ) self.assertIsNone(event.source) - + # Test serialization serialized = event.model_dump(by_alias=True) self.assertEqual(serialized["type"], "RAW") self.assertEqual(serialized["event"]["data"], "test") self.assertIsNone(serialized["source"]) - + # Test round-trip event_adapter = TypeAdapter(Event) json_str = event.model_dump_json(by_alias=True) @@ -460,44 +428,39 @@ def test_complex_nested_event_structures(self): "preferences": { "theme": "dark", "notifications": True, - "filters": ["news", "social", "tech"] - } + "filters": ["news", "social", "tech"], + }, }, "stats": { "messages": 42, - "interactions": { - "clicks": 18, - "searches": 7 - } - } + "interactions": {"clicks": 18, "searches": 7}, + }, }, "active_tools": ["search", "calculator", "weather"], - "settings": { - "language": "en", - "timezone": "UTC-5" - } + "settings": {"language": "en", "timezone": "UTC-5"}, } - - event = StateSnapshotEvent( - snapshot=complex_state, - timestamp=1648214400000 - ) - + + event = StateSnapshotEvent(snapshot=complex_state, timestamp=1648214400000) + # Verify complex state structure self.assertEqual(event.snapshot["session"]["user"]["id"], "user_123") - self.assertEqual(event.snapshot["session"]["user"]["preferences"]["theme"], "dark") - self.assertEqual(event.snapshot["session"]["stats"]["interactions"]["searches"], 7) + self.assertEqual( + event.snapshot["session"]["user"]["preferences"]["theme"], "dark" + ) + self.assertEqual( + event.snapshot["session"]["stats"]["interactions"]["searches"], 7 + ) self.assertEqual(event.snapshot["active_tools"][1], "calculator") - + # Test serialization and deserialization event_adapter = TypeAdapter(Event) json_str = event.model_dump_json(by_alias=True) deserialized = event_adapter.validate_json(json_str) - + # Verify structure is preserved self.assertEqual( deserialized.snapshot["session"]["user"]["preferences"]["filters"], - ["news", "social", "tech"] + ["news", "social", "tech"], ) self.assertEqual(deserialized.snapshot["settings"]["timezone"], "UTC-5") @@ -505,21 +468,19 @@ def test_event_with_unicode_and_special_chars(self): """Test events with Unicode and special characters""" # Text with Unicode and special characters text = "Hello 你好 こんにちは 안녕하세요 👋 🌍 \n\t\"'\\/<>{}[]" - + event = TextMessageContentEvent( - message_id="msg_unicode", - delta=text, - timestamp=1648214400000 + message_id="msg_unicode", delta=text, timestamp=1648214400000 ) - + # Verify text is stored correctly self.assertEqual(event.delta, text) - + # Test serialization and deserialization event_adapter = TypeAdapter(Event) json_str = event.model_dump_json(by_alias=True) deserialized = event_adapter.validate_json(json_str) - + # Verify Unicode and special characters are preserved self.assertEqual(deserialized.delta, text) diff --git a/python-sdk/tests/test_types.py b/python-sdk/tests/test_types.py index e534aa5ab..6ad77a7e1 100644 --- a/python-sdk/tests/test_types.py +++ b/python-sdk/tests/test_types.py @@ -11,7 +11,7 @@ UserMessage, ToolMessage, Message, - RunAgentInput + RunAgentInput, ) @@ -26,10 +26,7 @@ def test_function_call_creation(self): def test_message_serialization(self): """Test serialization of a basic message""" - user_msg = UserMessage( - id="msg_123", - content="Hello, world!" - ) + user_msg = UserMessage(id="msg_123", content="Hello, world!") serialized = user_msg.model_dump(by_alias=True) self.assertEqual(serialized["id"], "msg_123") self.assertEqual(serialized["role"], "user") @@ -38,8 +35,7 @@ def test_message_serialization(self): def test_tool_call_serialization(self): """Test camel case serialization for ConfiguredBaseModel subclasses""" tool_call = ToolCall( - id="call_123", - function=FunctionCall(name="test_function", arguments="{}") + id="call_123", function=FunctionCall(name="test_function", arguments="{}") ) serialized = tool_call.model_dump(by_alias=True) # Should convert function to camelCase @@ -48,9 +44,7 @@ def test_tool_call_serialization(self): def test_tool_message_camel_case(self): """Test camel case serialization for ToolMessage""" tool_msg = ToolMessage( - id="tool_123", - content="Tool result", - tool_call_id="call_456" + id="tool_123", content="Tool result", tool_call_id="call_456" ) serialized = tool_msg.model_dump(by_alias=True) self.assertIn("toolCallId", serialized) @@ -63,7 +57,7 @@ def test_parse_camel_case_json_tool_message(self): "id": "tool_789", "role": "tool", "content": "Result from tool", - "toolCallId": "call_123" # camelCase field name + "toolCallId": "call_123", # camelCase field name } # Parse the JSON data into a ToolMessage instance @@ -81,10 +75,7 @@ def test_parse_camel_case_json_function_call(self): json_data = { "id": "call_abc", "type": "function", - "function": { - "name": "get_weather", - "arguments": '{"location":"New York"}' - } + "function": {"name": "get_weather", "arguments": '{"location":"New York"}'}, } # Parse JSON into a ToolCall instance @@ -98,20 +89,14 @@ def test_parse_camel_case_json_function_call(self): def test_developer_message(self): """Test creating and serializing a developer message""" - msg = DeveloperMessage( - id="dev_123", - content="Developer note" - ) + msg = DeveloperMessage(id="dev_123", content="Developer note") serialized = msg.model_dump(by_alias=True) self.assertEqual(serialized["role"], "developer") self.assertEqual(serialized["content"], "Developer note") def test_system_message(self): """Test creating and serializing a system message""" - msg = SystemMessage( - id="sys_123", - content="System instruction" - ) + msg = SystemMessage(id="sys_123", content="System instruction") serialized = msg.model_dump(by_alias=True) self.assertEqual(serialized["role"], "system") self.assertEqual(serialized["content"], "System instruction") @@ -120,12 +105,10 @@ def test_assistant_message(self): """Test creating and serializing an assistant message with tool calls""" tool_call = ToolCall( id="call_456", - function=FunctionCall(name="get_data", arguments='{"param": "value"}') + function=FunctionCall(name="get_data", arguments='{"param": "value"}'), ) msg = AssistantMessage( - id="asst_123", - content="Assistant response", - tool_calls=[tool_call] + id="asst_123", content="Assistant response", tool_calls=[tool_call] ) serialized = msg.model_dump(by_alias=True) self.assertEqual(serialized["role"], "assistant") @@ -135,10 +118,7 @@ def test_assistant_message(self): def test_user_message(self): """Test creating and serializing a user message""" - msg = UserMessage( - id="user_123", - content="User query" - ) + msg = UserMessage(id="user_123", content="User query") serialized = msg.model_dump(by_alias=True) self.assertEqual(serialized["role"], "user") self.assertEqual(serialized["content"], "User query") @@ -155,11 +135,11 @@ def test_message_union_deserialization(self): {"id": "asst_789", "role": "assistant", "content": "Assistant response"}, {"id": "user_101", "role": "user", "content": "User query"}, { - "id": "tool_202", - "role": "tool", - "content": "Tool result", - "toolCallId": "call_303" - } + "id": "tool_202", + "role": "tool", + "content": "Tool result", + "toolCallId": "call_303", + }, ] expected_types = [ @@ -167,7 +147,7 @@ def test_message_union_deserialization(self): SystemMessage, AssistantMessage, UserMessage, - ToolMessage + ToolMessage, ] for data, expected_type in zip(message_data, expected_types): @@ -192,10 +172,10 @@ def test_message_union_with_tool_calls(self): "type": "function", "function": { "name": "search_data", - "arguments": '{"query": "python"}' - } + "arguments": '{"query": "python"}', + }, } - ] + ], } msg = message_adapter.validate_python(data) @@ -215,19 +195,19 @@ def test_run_agent_input_deserialization(self): { "id": "sys_001", "role": "system", - "content": "You are a helpful assistant." + "content": "You are a helpful assistant.", }, # User message { "id": "user_001", "role": "user", - "content": "Can you help me analyze this data?" + "content": "Can you help me analyze this data?", }, # Developer message { "id": "dev_001", "role": "developer", - "content": "The assistant should provide a detailed analysis." + "content": "The assistant should provide a detailed analysis.", }, # Assistant message with tool calls { @@ -240,24 +220,24 @@ def test_run_agent_input_deserialization(self): "type": "function", "function": { "name": "analyze_data", - "arguments": '{"dataset": "sales_2023", "metrics": ["mean", "median"]}' # pylint: disable=line-too-long - } + "arguments": '{"dataset": "sales_2023", "metrics": ["mean", "median"]}', # pylint: disable=line-too-long + }, } - ] + ], }, # Tool message responding to tool call { "id": "tool_001", "role": "tool", "content": '{"mean": 42.5, "median": 38.0}', - "toolCallId": "call_001" + "toolCallId": "call_001", }, # Another user message { "id": "user_002", "role": "user", - "content": "Can you explain these results?" - } + "content": "Can you explain these results?", + }, ], "tools": [ { @@ -267,10 +247,10 @@ def test_run_agent_input_deserialization(self): "type": "object", "properties": { "dataset": {"type": "string"}, - "metrics": {"type": "array", "items": {"type": "string"}} + "metrics": {"type": "array", "items": {"type": "string"}}, }, - "required": ["dataset"] - } + "required": ["dataset"], + }, }, { "name": "fetch_data", @@ -279,26 +259,23 @@ def test_run_agent_input_deserialization(self): "type": "object", "properties": { "source": {"type": "string"}, - "query": {"type": "string"} + "query": {"type": "string"}, }, - "required": ["source", "query"] - } - } + "required": ["source", "query"], + }, + }, ], "context": [ { "description": "User preferences", - "value": '{"theme": "dark", "language": "English"}' + "value": '{"theme": "dark", "language": "English"}', }, - { - "description": "Environment", - "value": "production" - } + {"description": "Environment", "value": "production"}, ], "forwardedProps": { "api_version": "v1", - "custom_settings": {"max_tokens": 500} - } + "custom_settings": {"max_tokens": 500}, + }, } # Deserialize using TypeAdapter @@ -319,8 +296,12 @@ def test_run_agent_input_deserialization(self): self.assertIsInstance(run_agent_input.messages[5], UserMessage) # Verify specific message content - self.assertEqual(run_agent_input.messages[0].content, "You are a helpful assistant.") - self.assertEqual(run_agent_input.messages[1].content, "Can you help me analyze this data?") + self.assertEqual( + run_agent_input.messages[0].content, "You are a helpful assistant." + ) + self.assertEqual( + run_agent_input.messages[1].content, "Can you help me analyze this data?" + ) # Verify assistant message with tool call assistant_msg = run_agent_input.messages[3] @@ -344,7 +325,9 @@ def test_run_agent_input_deserialization(self): # Verify forwarded props self.assertEqual(run_agent_input.forwarded_props["api_version"], "v1") - self.assertEqual(run_agent_input.forwarded_props["custom_settings"]["max_tokens"], 500) + self.assertEqual( + run_agent_input.forwarded_props["custom_settings"]["max_tokens"], 500 + ) def test_validation_errors(self): """Test validation errors for various message types""" @@ -354,7 +337,7 @@ def test_validation_errors(self): invalid_role_data = { "id": "msg_123", "role": "invalid_role", # Invalid role - "content": "Hello" + "content": "Hello", } with self.assertRaises(ValidationError): message_adapter.validate_python(invalid_role_data) @@ -363,7 +346,7 @@ def test_validation_errors(self): missing_id_data = { # Missing "id" field "role": "user", - "content": "Hello" + "content": "Hello", } with self.assertRaises(ValidationError): UserMessage.model_validate(missing_id_data) @@ -373,7 +356,7 @@ def test_validation_errors(self): "id": "msg_456", "role": "user", "content": "Hello", - "extra_field": "This shouldn't be here" # Extra field + "extra_field": "This shouldn't be here", # Extra field } with self.assertRaises(ValidationError): UserMessage.model_validate(extra_field_data) @@ -396,9 +379,9 @@ def test_empty_collections(self): "runId": "run_empty", "state": {}, "messages": [], # Empty messages - "tools": [], # Empty tools - "context": [], # Empty context - "forwardedProps": {} + "tools": [], # Empty tools + "context": [], # Empty context + "forwardedProps": {}, } # Deserialize and verify @@ -423,26 +406,26 @@ def test_multiple_tool_calls(self): "type": "function", "function": { "name": "get_weather", - "arguments": '{"location": "New York"}' - } + "arguments": '{"location": "New York"}', + }, }, { "id": "call_2", "type": "function", "function": { "name": "search_database", - "arguments": '{"query": "recent sales"}' - } + "arguments": '{"query": "recent sales"}', + }, }, { "id": "call_3", "type": "function", "function": { "name": "calculate", - "arguments": '{"operation": "sum", "values": [1, 2, 3, 4, 5]}' - } - } - ] + "arguments": '{"operation": "sum", "values": [1, 2, 3, 4, 5]}', + }, + }, + ], } # Deserialize and verify @@ -471,13 +454,9 @@ def test_serialization_round_trip(self): { "id": "sys_rt", "role": "system", - "content": "You are a helpful assistant." - }, - { - "id": "user_rt", - "role": "user", - "content": "Help me with my task." + "content": "You are a helpful assistant.", }, + {"id": "user_rt", "role": "user", "content": "Help me with my task."}, { "id": "asst_rt", "role": "assistant", @@ -486,33 +465,20 @@ def test_serialization_round_trip(self): { "id": "call_rt", "type": "function", - "function": { - "name": "get_task_info", - "arguments": "{}" - } + "function": {"name": "get_task_info", "arguments": "{}"}, } - ] - } + ], + }, ], "tools": [ { "name": "get_task_info", "description": "Get task information", - "parameters": { - "type": "object", - "properties": {} - } - } - ], - "context": [ - { - "description": "Session", - "value": "123456" + "parameters": {"type": "object", "properties": {}}, } ], - "forwardedProps": { - "timestamp": 1648214400 - } + "context": [{"description": "Session", "value": "123456"}], + "forwardedProps": {"timestamp": 1648214400}, } # Deserialize @@ -538,7 +504,7 @@ def test_serialization_round_trip(self): self.assertEqual(len(deserialized_obj.messages[2].tool_calls), 1) self.assertEqual( deserialized_obj.messages[2].tool_calls[0].function.name, - original_obj.messages[2].tool_calls[0].function.name + original_obj.messages[2].tool_calls[0].function.name, ) def test_content_edge_cases(self): @@ -548,7 +514,7 @@ def test_content_edge_cases(self): empty_content_data = { "id": "msg_empty", "role": "user", - "content": "" # Empty string + "content": "", # Empty string } empty_msg = UserMessage.model_validate(empty_content_data) self.assertEqual(empty_msg.content, "") @@ -562,12 +528,9 @@ def test_content_edge_cases(self): { "id": "call_null", "type": "function", - "function": { - "name": "get_data", - "arguments": "{}" - } + "function": {"name": "get_data", "arguments": "{}"}, } - ] + ], } null_msg = AssistantMessage.model_validate(null_content_data) self.assertIsNone(null_msg.content) @@ -577,17 +540,19 @@ def test_content_edge_cases(self): large_content_data = { "id": "msg_large", "role": "user", - "content": large_content + "content": large_content, } large_msg = UserMessage.model_validate(large_content_data) self.assertEqual(len(large_msg.content), 10000) # Test content with special characters - special_chars = "Special chars: 你好 こんにちは 안녕하세요 👋 🌍 \n\t\"'\\/<>{}[]" + special_chars = ( + "Special chars: 你好 こんにちは 안녕하세요 👋 🌍 \n\t\"'\\/<>{}[]" + ) special_chars_data = { "id": "msg_special", "role": "user", - "content": special_chars + "content": special_chars, } special_msg = UserMessage.model_validate(special_chars_data) self.assertEqual(special_msg.content, special_chars) @@ -599,7 +564,7 @@ def test_name_field_handling(self): "id": "user_named", "role": "user", "content": "Hello", - "name": "John" + "name": "John", } user_msg = UserMessage.model_validate(user_with_name_data) self.assertEqual(user_msg.name, "John") @@ -609,7 +574,7 @@ def test_name_field_handling(self): "id": "asst_named", "role": "assistant", "content": "Hello", - "name": "AI Assistant" + "name": "AI Assistant", } assistant_msg = AssistantMessage.model_validate(assistant_with_name_data) self.assertEqual(assistant_msg.name, "AI Assistant") @@ -633,7 +598,7 @@ def test_state_variations(self): "messages": [], "tools": [], "context": [], - "forwardedProps": {} + "forwardedProps": {}, } scalar_input = RunAgentInput.model_validate(scalar_state_data) self.assertEqual(scalar_input.state, "ACTIVE") @@ -647,19 +612,13 @@ def test_state_variations(self): "preferences": { "theme": "dark", "notifications": True, - "filters": ["important", "urgent"] - } + "filters": ["important", "urgent"], + }, }, - "metrics": { - "requests": 42, - "tokens": { - "input": 1024, - "output": 2048 - } - } + "metrics": {"requests": 42, "tokens": {"input": 1024, "output": 2048}}, }, "timestamp": 1648214400, - "version": "1.0.0" + "version": "1.0.0", } complex_state_data = { @@ -669,15 +628,19 @@ def test_state_variations(self): "messages": [], "tools": [], "context": [], - "forwardedProps": {} + "forwardedProps": {}, } complex_input = RunAgentInput.model_validate(complex_state_data) # Verify nested state structure is preserved self.assertEqual(complex_input.state["session"]["id"], "sess_123") self.assertEqual(complex_input.state["session"]["user"]["id"], "user_456") - self.assertEqual(complex_input.state["session"]["user"]["preferences"]["theme"], "dark") - self.assertEqual(complex_input.state["session"]["metrics"]["tokens"]["output"], 2048) + self.assertEqual( + complex_input.state["session"]["user"]["preferences"]["theme"], "dark" + ) + self.assertEqual( + complex_input.state["session"]["metrics"]["tokens"]["output"], 2048 + ) self.assertEqual(complex_input.state["version"], "1.0.0") # Verify serialization round-trip works with complex state @@ -685,7 +648,7 @@ def test_state_variations(self): deserialized = RunAgentInput.model_validate(serialized) self.assertEqual( deserialized.state["session"]["user"]["preferences"]["filters"], - ["important", "urgent"] + ["important", "urgent"], ) From 42f0d19bba55b3bcff286a7b232e090fbf392416 Mon Sep 17 00:00:00 2001 From: Wouter Doppenberg Date: Fri, 15 Aug 2025 10:47:24 +0200 Subject: [PATCH 3/6] CI schema error fix --- .github/workflows/test.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ddfd0bfd9..c1bbdcbee 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,9 +2,13 @@ name: test on: push: - branches: main + branches: [ + "main" + ] pull_request: - branches: main + branches: [ + "main" + ] jobs: python: From f147fad877d94a93568a86e7d194ca0a483ab497 Mon Sep 17 00:00:00 2001 From: Wouter Doppenberg Date: Fri, 15 Aug 2025 10:50:02 +0200 Subject: [PATCH 4/6] Python 3.9 compatibility changes; pre-commit updates --- python-sdk/.pre-commit-config.yaml | 6 +++--- python-sdk/ag_ui/encoder/encoder.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python-sdk/.pre-commit-config.yaml b/python-sdk/.pre-commit-config.yaml index f52033e00..a0e329a04 100644 --- a/python-sdk/.pre-commit-config.yaml +++ b/python-sdk/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.11.8 + rev: v0.12.9 hooks: - id: ruff args: [ @@ -10,12 +10,12 @@ repos: - id: ruff-format files: ^python-sdk/ - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 + rev: v1.17.1 hooks: - id: mypy files: ^python-sdk/ args: [ - --python-version=3.12, + --python-version=3.9, --disallow-untyped-calls, --disallow-untyped-defs, --disallow-incomplete-defs, diff --git a/python-sdk/ag_ui/encoder/encoder.py b/python-sdk/ag_ui/encoder/encoder.py index a30957568..227a552d6 100644 --- a/python-sdk/ag_ui/encoder/encoder.py +++ b/python-sdk/ag_ui/encoder/encoder.py @@ -2,6 +2,8 @@ This module contains the EventEncoder class """ +from typing import Union + from ag_ui.core.events import BaseEvent AGUI_MEDIA_TYPE = "application/vnd.ag-ui.event+proto" @@ -12,7 +14,7 @@ class EventEncoder: Encodes Agent User Interaction events. """ - def __init__(self, accept: str | None = None): + def __init__(self, accept: Union[str, None] = None): pass def get_content_type(self) -> str: From 2a240a15febc15ef93e1fdab65d1d2b1f88a5c9d Mon Sep 17 00:00:00 2001 From: Wouter Doppenberg Date: Fri, 15 Aug 2025 10:55:29 +0200 Subject: [PATCH 5/6] AgentStateT -> StateT --- python-sdk/ag_ui/core/events.py | 6 +++--- python-sdk/ag_ui/core/types.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python-sdk/ag_ui/core/events.py b/python-sdk/ag_ui/core/events.py index b0b2bab40..256df6778 100644 --- a/python-sdk/ag_ui/core/events.py +++ b/python-sdk/ag_ui/core/events.py @@ -7,7 +7,7 @@ from pydantic import Field -from .types import ConfiguredBaseModel, Message, AgentStateT, JSONValue +from .types import ConfiguredBaseModel, Message, StateT, JSONValue class EventType(str, Enum): @@ -193,13 +193,13 @@ class ThinkingEndEvent(BaseEvent): type: Literal[EventType.THINKING_END] = EventType.THINKING_END # pyright: ignore[reportIncompatibleVariableOverride] -class StateSnapshotEvent(BaseEvent, Generic[AgentStateT]): +class StateSnapshotEvent(BaseEvent, Generic[StateT]): """ Event containing a snapshot of the state. """ type: Literal[EventType.STATE_SNAPSHOT] = EventType.STATE_SNAPSHOT # pyright: ignore[reportIncompatibleVariableOverride] - snapshot: AgentStateT + snapshot: StateT class StateDeltaEvent(BaseEvent): diff --git a/python-sdk/ag_ui/core/types.py b/python-sdk/ag_ui/core/types.py index 824f0bd70..6aaa13729 100644 --- a/python-sdk/ag_ui/core/types.py +++ b/python-sdk/ag_ui/core/types.py @@ -9,7 +9,7 @@ from pydantic.alias_generators import to_camel JSONValue = Union[str, int, float, bool, None, dict[str, Any], list[Any]] -AgentStateT = TypeVar("AgentStateT", default=JSONValue, contravariant=True) +StateT = TypeVar("StateT", default=JSONValue, contravariant=True) FwdPropsT = TypeVar("FwdPropsT", default=JSONValue, contravariant=True) @@ -127,14 +127,14 @@ class Tool(ConfiguredBaseModel): parameters: Any # JSON Schema for the tool parameters -class RunAgentInput(ConfiguredBaseModel, Generic[AgentStateT, FwdPropsT]): +class RunAgentInput(ConfiguredBaseModel, Generic[StateT, FwdPropsT]): """ Input for running an agent. """ thread_id: str run_id: str - state: AgentStateT + state: StateT messages: List[Message] tools: List[Tool] context: List[Context] From 0d04958260b2dbc6c161d9b0aa2fafe446827ae7 Mon Sep 17 00:00:00 2001 From: Wouter Doppenberg Date: Fri, 15 Aug 2025 11:05:57 +0200 Subject: [PATCH 6/6] `ag-ui-protocol` -> `ag-ui` --- python-sdk/pyproject.toml | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/python-sdk/pyproject.toml b/python-sdk/pyproject.toml index 15b8b9649..dbc7fe56c 100644 --- a/python-sdk/pyproject.toml +++ b/python-sdk/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "ag-ui-protocol" +name = "ag-ui" version = "0.1.9" description = "" authors = [ @@ -10,13 +10,6 @@ requires-python = ">=3.9,<4.0" dependencies = [ "pydantic>=2.11.2,<3.0.0", ] -packages = [ - { include = "ag_ui", from = "ag_ui" } -] - -[tool.hatch.build.targets.wheel] -packages = ["ag_ui"] - [build-system] requires = ["hatchling"]