diff --git a/architecture/mcp.md b/architecture/mcp.md index 86f14bed1..626351966 100644 --- a/architecture/mcp.md +++ b/architecture/mcp.md @@ -34,6 +34,11 @@ Scoped to one `ToolConfig`. Provides the interface that `ModelFacade` uses: - **`process_completion_response`** — extracts tool calls from a completion, executes them in parallel via `MCPIOService`, returns `ChatMessage` list with results - **`refuse_completion_response`** — handles tool-call turn limits (prevents infinite tool loops) +Tool result messages may contain either text or ordered multimodal content blocks. MCP image results, and generic +base64 payloads with `image/*` MIME metadata or image data URI prefixes, are preserved as canonical `image_url` data +URI blocks and translated by provider adapters at the API boundary. Models need VLM-capable provider support to +interpret those image results semantically. + ### MCPRegistry Maps `tool_alias` → `ToolConfig`. Lazy `MCPFacade` construction mirrors `ModelRegistry`. Provides health checks for configured tools. 
diff --git a/packages/data-designer-engine/src/data_designer/engine/mcp/io.py b/packages/data-designer-engine/src/data_designer/engine/mcp/io.py index 60a46b257..807e4b890 100644 --- a/packages/data-designer-engine/src/data_designer/engine/mcp/io.py +++ b/packages/data-designer-engine/src/data_designer/engine/mcp/io.py @@ -32,8 +32,9 @@ import atexit import json import logging +import re import threading -from collections.abc import Coroutine, Iterable +from collections.abc import Callable, Coroutine, Iterable from typing import Any from mcp import ClientSession, StdioServerParameters @@ -42,10 +43,16 @@ from mcp.client.streamable_http import streamablehttp_client from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, MCPProviderT +from data_designer.config.utils.image_helpers import ( + decode_base64_image, + detect_image_format, + extract_base64_from_data_uri, +) from data_designer.engine.mcp.errors import MCPToolError from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult logger = logging.getLogger(__name__) +_DATA_URI_MIME_TYPE_RE = re.compile(r"^data:(?P<mime_type>[^;]+);base64,") def _provider_cache_key(provider: MCPProviderT) -> str: @@ -289,7 +296,7 @@ async def _call_tool_async( session = await self._get_or_create_session(provider) result = await session.call_tool(tool_name, arguments) - content = _serialize_tool_result_content(result) + content = _coerce_tool_result_content(result) is_error = getattr(result, "isError", None) if is_error is None: is_error = getattr(result, "is_error", False) @@ -467,31 +474,195 @@ def _coerce_tool_definition(tool: Any, tool_definition_cls: type[MCPToolDefiniti return tool_definition_cls(name=name, description=description, input_schema=input_schema) -def _serialize_tool_result_content(result: Any) -> str: - """Serialize tool result content to a string.""" +def _coerce_tool_result_content(result: Any) -> str | list[dict[str, Any]]: + """Coerce MCP tool result content while preserving image 
blocks.""" content = getattr(result, "content", result) if content is None: return "" if isinstance(content, str): return content if isinstance(content, dict): + if _is_image_url_block(content): + return [_coerce_image_url_block(content)] + if _is_image_content(content) or _has_base64_image_payload(content): + return [_build_image_url_block(content)] + if _is_text_content(content): + return str(content.get("text", "")) return json.dumps(content) + if _is_image_content(content) or _has_base64_image_payload(content): + return [_build_image_url_block(content)] + if _is_text_content(content): + return str(_get_content_field(content, "text", default="")) if isinstance(content, list): - parts: list[str] = [] + blocks: list[dict[str, Any]] = [] + has_image = False for item in content: - if isinstance(item, str): - parts.append(item) - continue - if isinstance(item, dict): - if item.get("type") == "text": - parts.append(str(item.get("text", ""))) - else: - parts.append(json.dumps(item)) - continue - text_value = getattr(item, "text", None) - if text_value is not None: - parts.append(str(text_value)) - else: - parts.append(str(item)) - return "\n".join(parts) + block = _coerce_tool_result_content_item(item) + blocks.append(block) + has_image = has_image or block.get("type") == "image_url" + if has_image: + return blocks + return "\n".join(block.get("text", "") for block in blocks) return str(content) + + +def _coerce_tool_result_content_item(item: Any) -> dict[str, Any]: + """Coerce a single MCP content item to an internal ChatML-style block.""" + if isinstance(item, str): + return _build_text_block(item) + if _is_image_url_block(item): + return _coerce_image_url_block(item) + if _is_image_content(item) or _has_base64_image_payload(item): + return _build_image_url_block(item) + if _is_text_content(item): + return _build_text_block(_get_content_field(item, "text", default="")) + if isinstance(item, dict): + return _build_text_block(json.dumps(item)) + + text_value = 
getattr(item, "text", None) + if text_value is not None: + return _build_text_block(text_value) + return _build_text_block(item) + + +def _is_text_content(item: Any) -> bool: + return _get_content_field(item, "type") == "text" + + +def _is_image_content(item: Any) -> bool: + return _get_content_field(item, "type") == "image" + + +def _is_image_url_block(item: Any) -> bool: + return isinstance(item, dict) and item.get("type") == "image_url" + + +def _has_base64_image_payload(item: Any) -> bool: + data = _get_content_field(item, "data", "b64_json", "base64") + if not isinstance(data, str) or not data: + return False + + mime_type = _get_content_field(item, "mimeType", "mime_type", "media_type") + if isinstance(mime_type, str) and mime_type: + return _is_image_mime_type(mime_type) + + data_uri_mime_type = _extract_data_uri_mime_type(data) + return data_uri_mime_type is not None and _is_image_mime_type(data_uri_mime_type) + + +def _coerce_image_url_block(block: dict[str, Any]) -> dict[str, Any]: + image_url = block.get("image_url") + if isinstance(image_url, str): + image_url = {"url": image_url} + elif not isinstance(image_url, dict): + raise MCPToolError("MCP image_url block must contain an image_url dict or string.") + + url = image_url.get("url") + if not isinstance(url, str) or not url: + raise MCPToolError("MCP image_url block must contain a non-empty string URL.") + if url.startswith(("http://", "https://")): + return {"type": "image_url", "image_url": image_url} + if url.startswith("data:"): + _extract_mime_type_from_data_uri(url) + _coerce_base64_image_data(url) + return {"type": "image_url", "image_url": image_url} + + return _build_image_url_block({"base64": url}) + + +def _build_image_url_block(item: Any) -> dict[str, Any]: + data = _get_content_field(item, "data", "b64_json", "base64") + mime_type = _get_content_field(item, "mimeType", "mime_type", "media_type") + if not isinstance(data, str) or not data: + raise MCPToolError("MCP image content is missing 
base64 data.") + mime_type = _coerce_image_mime_type(data, mime_type) + base64_data = _coerce_base64_image_data(data) + + return { + "type": "image_url", + "image_url": {"url": f"data:{mime_type};base64,{base64_data}"}, + } + + +def _coerce_image_mime_type(data: str, mime_type: Any) -> str: + if isinstance(mime_type, str) and mime_type: + if not _is_image_mime_type(mime_type): + raise MCPToolError(f"MCP image content must use an image MIME type, got {mime_type!r}.") + return mime_type + + data_uri_mime_type = _extract_mime_type_from_data_uri(data) + if data_uri_mime_type is not None: + return data_uri_mime_type + + try: + return f"image/{detect_image_format(decode_base64_image(data)).value}" + except ValueError as exc: + raise MCPToolError("MCP image content is missing a MIME type.") from exc + + +def _coerce_base64_image_data(data: str) -> str: + try: + base64_data = extract_base64_from_data_uri(data) + decode_base64_image(base64_data) + return base64_data + except ValueError as exc: + raise MCPToolError("MCP image content has invalid base64 data.") from exc + + +def _extract_mime_type_from_data_uri(data: str) -> str | None: + mime_type = _extract_data_uri_mime_type(data) + if mime_type is None: + return None + if not _is_image_mime_type(mime_type): + raise MCPToolError(f"MCP image content data URI must use an image MIME type, got {mime_type!r}.") + return mime_type + + +def _extract_data_uri_mime_type(data: str) -> str | None: + match = _DATA_URI_MIME_TYPE_RE.match(data) + if match is None: + return None + return match.group("mime_type") + + +def _is_image_mime_type(mime_type: str) -> bool: + return mime_type.lower().startswith("image/") + + +def _get_content_field(item: Any, *names: str, default: Any = None) -> Any: + if isinstance(item, dict): + for name in names: + if name in item: + return item[name] + return default + + for name in names: + if hasattr(item, name): + return getattr(item, name) + + model_dump = getattr(item, "model_dump", None) + if 
callable(model_dump): + return _get_content_field_from_dump(model_dump, names, default) + + dict_dump = getattr(item, "dict", None) + if callable(dict_dump): + return _get_content_field_from_dump(dict_dump, names, default) + + return default + + +def _get_content_field_from_dump(dump_method: Callable[..., Any], names: tuple[str, ...], default: Any) -> Any: + for kwargs in ({"by_alias": True}, {}): + try: + dumped = dump_method(**kwargs) + except TypeError: + continue + if isinstance(dumped, dict): + for name in names: + if name in dumped: + return dumped[name] + return default + + +def _build_text_block(value: Any) -> dict[str, Any]: + return {"type": "text", "text": str(value)} diff --git a/packages/data-designer-engine/src/data_designer/engine/mcp/registry.py b/packages/data-designer-engine/src/data_designer/engine/mcp/registry.py index 7a8041fa2..aa7633d13 100644 --- a/packages/data-designer-engine/src/data_designer/engine/mcp/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/mcp/registry.py @@ -50,7 +50,7 @@ def to_openai_tool_schema(self) -> dict[str, Any]: class MCPToolResult: """Result from executing an MCP tool call.""" - content: str + content: str | list[dict[str, Any]] is_error: bool = False diff --git a/packages/data-designer-engine/src/data_designer/engine/models/utils.py b/packages/data-designer-engine/src/data_designer/engine/models/utils.py index f92d58da6..f7183e83d 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/utils.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/utils.py @@ -78,7 +78,7 @@ def as_system(cls, content: str) -> ChatMessage: return cls(role="system", content=content) @classmethod - def as_tool(cls, content: str, tool_call_id: str) -> ChatMessage: + def as_tool(cls, content: str | list[dict[str, Any]], tool_call_id: str) -> ChatMessage: """Create a tool response message.""" return cls(role="tool", content=content, tool_call_id=tool_call_id) diff --git 
a/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py b/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py index 983559f9b..19bbf143f 100644 --- a/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py +++ b/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py @@ -122,6 +122,40 @@ def mock_call_tools( assert messages[1].tool_call_id == "call-1" +def test_process_completion_preserves_multimodal_tool_result_content( + monkeypatch: pytest.MonkeyPatch, + stub_mcp_facade: MCPFacade, + mock_completion_response_single_tool: ChatCompletionResponse, +) -> None: + """Tool result messages can carry multimodal content blocks unchanged.""" + + multimodal_result = [ + {"type": "text", "text": "Screenshot:"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}}, + {"type": "text", "text": "Use the chart title."}, + ] + + def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: + return (MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}),) + + def mock_call_tools( + calls: list[tuple[Any, str, dict[str, Any]]], + *, + timeout_sec: float | None = None, + ) -> list[MCPToolResult]: + return [MCPToolResult(content=multimodal_result)] + + monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) + monkeypatch.setattr(mcp_io, "call_tools", mock_call_tools) + + messages = stub_mcp_facade.process_completion_response(mock_completion_response_single_tool) + + assert len(messages) == 2 + assert messages[1].role == "tool" + assert messages[1].tool_call_id == "call-1" + assert messages[1].content == multimodal_result + + def test_process_completion_preserves_content( stub_mcp_facade: MCPFacade, mock_completion_response_no_tools: ChatCompletionResponse, diff --git a/packages/data-designer-engine/tests/engine/mcp/test_mcp_io.py b/packages/data-designer-engine/tests/engine/mcp/test_mcp_io.py index 21a357090..adea98441 100644 --- 
a/packages/data-designer-engine/tests/engine/mcp/test_mcp_io.py +++ b/packages/data-designer-engine/tests/engine/mcp/test_mcp_io.py @@ -6,6 +6,7 @@ from typing import Any, Iterator import pytest +from mcp.types import ImageContent, TextContent from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider from data_designer.engine.mcp import io as mcp_io @@ -101,67 +102,76 @@ def test_coerce_tool_definition_missing_name() -> None: # ============================================================================= -# Tool result serialization tests +# Tool result content coercion tests # ============================================================================= -def test_serialize_content_none() -> None: - """Test serializing None content.""" +def test_coerce_content_none() -> None: + """Test coercing None content.""" class FakeResult: content = None - assert mcp_io._serialize_tool_result_content(FakeResult()) == "" + assert mcp_io._coerce_tool_result_content(FakeResult()) == "" -def test_serialize_content_string() -> None: - """Test serializing string content.""" +def test_coerce_content_string() -> None: + """Test coercing string content.""" class FakeResult: content = "Hello, world!" - assert mcp_io._serialize_tool_result_content(FakeResult()) == "Hello, world!" + assert mcp_io._coerce_tool_result_content(FakeResult()) == "Hello, world!" 
-def test_serialize_content_dict() -> None: - """Test serializing dict content.""" +def test_coerce_content_dict() -> None: + """Test coercing structured dict content.""" class FakeResult: content = {"key": "value"} - assert mcp_io._serialize_tool_result_content(FakeResult()) == '{"key": "value"}' + assert mcp_io._coerce_tool_result_content(FakeResult()) == '{"key": "value"}' -def test_serialize_content_list_of_strings() -> None: - """Test serializing list of strings content.""" +def test_coerce_content_list_of_strings() -> None: + """Test coercing list of strings content.""" class FakeResult: content = ["line1", "line2", "line3"] - assert mcp_io._serialize_tool_result_content(FakeResult()) == "line1\nline2\nline3" + assert mcp_io._coerce_tool_result_content(FakeResult()) == "line1\nline2\nline3" -def test_serialize_content_list_of_text_items() -> None: - """Test serializing list of text items.""" +def test_coerce_content_list_of_text_items() -> None: + """Test coercing list of text items.""" class FakeResult: content = [{"type": "text", "text": "First"}, {"type": "text", "text": "Second"}] - assert mcp_io._serialize_tool_result_content(FakeResult()) == "First\nSecond" + assert mcp_io._coerce_tool_result_content(FakeResult()) == "First\nSecond" -def test_serialize_content_list_of_dicts() -> None: - """Test serializing list of non-text dicts.""" +def test_coerce_content_list_of_dicts() -> None: + """Test coercing list of non-text dicts.""" class FakeResult: content = [{"type": "data", "value": 1}] - result = mcp_io._serialize_tool_result_content(FakeResult()) + result = mcp_io._coerce_tool_result_content(FakeResult()) assert '{"type": "data", "value": 1}' in result -def test_serialize_content_list_with_objects() -> None: - """Test serializing list with objects that have text attribute.""" +def test_coerce_content_list_with_none_preserves_existing_string_fallback() -> None: + """Test coercing list fallback content preserves old str() behavior.""" + + class 
FakeResult: + content = [None] + + assert mcp_io._coerce_tool_result_content(FakeResult()) == "None" + + +def test_coerce_content_list_with_objects() -> None: + """Test coercing list with objects that have text attribute.""" class TextItem: text = "Object text" @@ -169,16 +179,242 @@ class TextItem: class FakeResult: content = [TextItem()] - assert mcp_io._serialize_tool_result_content(FakeResult()) == "Object text" + assert mcp_io._coerce_tool_result_content(FakeResult()) == "Object text" + + +def test_coerce_content_bare_text_object() -> None: + """Test coercing a bare MCP-like text object.""" + + class TextItem: + type = "text" + text = "Object text" + + assert mcp_io._coerce_tool_result_content(TextItem()) == "Object text" -def test_serialize_content_fallback_to_str() -> None: - """Test serializing content falls back to str().""" +def test_coerce_content_fallback_to_str() -> None: + """Test coercing content falls back to str().""" class FakeResult: content = 12345 - assert mcp_io._serialize_tool_result_content(FakeResult()) == "12345" + assert mcp_io._coerce_tool_result_content(FakeResult()) == "12345" + + +def test_coerce_content_image_dict_to_image_url_data_uri() -> None: + """MCP image content is preserved as an OpenAI-style image_url block.""" + + class FakeResult: + content = [{"type": "image", "data": "iVBORw0KGgo=", "mimeType": "image/png"}] + + assert mcp_io._coerce_tool_result_content(FakeResult()) == [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}} + ] + + +def test_coerce_content_b64_json_dict_to_image_url_data_uri() -> None: + """Explicit base64 image payloads are normalized to image_url data URIs.""" + + class FakeResult: + content = [{"b64_json": "iVBORw0KGgo=", "mime_type": "image/png"}] + + assert mcp_io._coerce_tool_result_content(FakeResult()) == [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}} + ] + + +def test_coerce_content_base64_payload_dict_with_media_type() -> 
None: + """Explicit media_type payloads do not need image format detection.""" + + class FakeResult: + content = [{"base64": "YWJjMTIz", "media_type": "image/webp"}] + + assert mcp_io._coerce_tool_result_content(FakeResult()) == [ + {"type": "image_url", "image_url": {"url": "data:image/webp;base64,YWJjMTIz"}} + ] + + +def test_coerce_content_non_image_resource_payload_falls_back_to_json() -> None: + """Non-image structured payloads are not misrouted into image blocks.""" + + class FakeResult: + content = {"type": "resource", "data": "eyJrZXkiOiAidmFsdWUifQ==", "mimeType": "application/json"} + + assert ( + mcp_io._coerce_tool_result_content(FakeResult()) + == '{"type": "resource", "data": "eyJrZXkiOiAidmFsdWUifQ==", "mimeType": "application/json"}' + ) + + +def test_coerce_content_non_image_base64_payload_falls_back_to_json_text() -> None: + """Generic base64 payloads require image MIME metadata to become image blocks.""" + + class FakeResult: + content = [{"base64": "JVBERi0xLjQ=", "media_type": "application/pdf"}] + + assert mcp_io._coerce_tool_result_content(FakeResult()) == ( + '{"base64": "JVBERi0xLjQ=", "media_type": "application/pdf"}' + ) + + +def test_coerce_content_image_data_uri_strips_existing_prefix() -> None: + """Data URI payloads are not double-prefixed.""" + + class FakeResult: + content = [{"type": "image", "data": "data:image/png;base64,iVBORw0KGgo="}] + + assert mcp_io._coerce_tool_result_content(FakeResult()) == [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}} + ] + + +def test_coerce_content_preserves_canonical_image_url_block() -> None: + """Canonical image_url blocks are already model-ready and should not be stringified.""" + image_block = {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}} + + class FakeResult: + content = [{"type": "text", "text": "before"}, image_block] + + assert mcp_io._coerce_tool_result_content(FakeResult()) == [{"type": "text", "text": "before"}, 
image_block] + + +def test_coerce_content_normalizes_image_url_string_shorthand() -> None: + """Common shorthand image_url strings are normalized to canonical blocks.""" + + class FakeResult: + content = [{"type": "image_url", "image_url": "data:image/png;base64,iVBORw0KGgo="}] + + assert mcp_io._coerce_tool_result_content(FakeResult()) == [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}} + ] + + +def test_coerce_content_normalizes_image_url_raw_base64() -> None: + """Non-canonical image_url blocks with raw base64 are normalized to data URIs.""" + + class FakeResult: + content = [{"type": "image_url", "image_url": {"url": "iVBORw0KGgo="}}] + + assert mcp_io._coerce_tool_result_content(FakeResult()) == [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}} + ] + + +@pytest.mark.parametrize( + "image_url", + [ + pytest.param(None, id="missing"), + pytest.param(123, id="non-dict"), + pytest.param({}, id="missing-url"), + pytest.param({"url": 123}, id="non-string-url"), + ], +) +def test_coerce_content_rejects_malformed_image_url_blocks(image_url: object) -> None: + """MCP coercion enforces canonical image_url block shape.""" + + class FakeResult: + content = [{"type": "image_url", "image_url": image_url}] + + with pytest.raises(MCPToolError, match="image_url block"): + mcp_io._coerce_tool_result_content(FakeResult()) + + +def test_coerce_content_rejects_image_url_invalid_base64_data_uri() -> None: + """Canonical data URI image_url blocks still need valid base64 payloads.""" + + class FakeResult: + content = [{"type": "image_url", "image_url": {"url": "data:image/png;base64,not-base64!!!"}}] + + with pytest.raises(MCPToolError, match="invalid base64"): + mcp_io._coerce_tool_result_content(FakeResult()) + + +def test_coerce_content_bare_image_object() -> None: + """Test coercing a bare MCP-like image object.""" + + class ImageItem: + type = "image" + data = "iVBORw0KGgo=" + mimeType = "image/png" + + 
assert mcp_io._coerce_tool_result_content(ImageItem()) == [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}} + ] + + +def test_coerce_content_mixed_text_image_preserves_order() -> None: + """Mixed MCP text/image content returns an ordered block list.""" + + class ImageItem: + type = "image" + data = "YWJjMTIz" + mimeType = "image/jpeg" + + class TextItem: + type = "text" + text = "after" + + class FakeResult: + content = [ + {"type": "text", "text": "before"}, + ImageItem(), + TextItem(), + {"type": "image", "data": "ZGVmNDU2", "mime_type": "image/webp"}, + ] + + assert mcp_io._coerce_tool_result_content(FakeResult()) == [ + {"type": "text", "text": "before"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,YWJjMTIz"}}, + {"type": "text", "text": "after"}, + {"type": "image_url", "image_url": {"url": "data:image/webp;base64,ZGVmNDU2"}}, + ] + + +def test_coerce_content_real_mcp_text_and_image_objects() -> None: + """Real MCP content objects are preserved in model-visible order.""" + + class FakeResult: + content = [ + TextContent(type="text", text="before"), + ImageContent(type="image", data="YWJjMTIz", mimeType="image/png"), + TextContent(type="text", text="after"), + ] + + assert mcp_io._coerce_tool_result_content(FakeResult()) == [ + {"type": "text", "text": "before"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,YWJjMTIz"}}, + {"type": "text", "text": "after"}, + ] + + +def test_coerce_content_image_without_mime_type_fails_clearly() -> None: + """Malformed image content fails before provider adaptation.""" + + class FakeResult: + content = [{"type": "image", "data": "abc123"}] + + with pytest.raises(MCPToolError, match="missing a MIME type"): + mcp_io._coerce_tool_result_content(FakeResult()) + + +def test_coerce_content_image_with_invalid_base64_fails_clearly() -> None: + """Explicit MCP image content must contain valid base64 data.""" + + class FakeResult: + content = [{"type": 
"image", "data": "not-base64!!!", "mimeType": "image/png"}] + + with pytest.raises(MCPToolError, match="invalid base64"): + mcp_io._coerce_tool_result_content(FakeResult()) + + +def test_coerce_content_image_with_non_image_mime_type_fails_clearly() -> None: + """Explicit MCP image content must not produce non-image image_url data URIs.""" + + class FakeResult: + content = [{"type": "image", "data": "eyJrZXkiOiAidmFsdWUifQ==", "mimeType": "application/json"}] + + with pytest.raises(MCPToolError, match="image MIME type"): + mcp_io._coerce_tool_result_content(FakeResult()) # ============================================================================= diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py index 108d1cdb7..2aad03ced 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py @@ -496,6 +496,24 @@ def test_translate_tool_result_message_requires_tool_call_id(message: dict[str, ], id="mixed-blocks", ), + pytest.param( + [ + {"type": "text", "text": "Rendered chart:"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}}, + ], + [ + {"type": "text", "text": "Rendered chart:"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "iVBORw0KGgo=", + }, + }, + ], + id="mixed-blocks-with-data-uri", + ), ], ) def test_translate_tool_result_content_normalizes_supported_inputs( diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py index 185117978..3284d79b5 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py +++ 
b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py @@ -298,6 +298,25 @@ def test_completion_forwards_base64_image_url_dict_unchanged() -> None: assert content[0] == image_block +def test_completion_forwards_multimodal_tool_result_content_unchanged() -> None: + """OpenAI-compatible VLM backends receive canonical multimodal tool content.""" + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + content = [ + {"type": "text", "text": "Rendered page:"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}, + ] + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "tool", "tool_call_id": "call-1", "content": content}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + assert payload["messages"][0]["content"] == content + + # --- Auth headers --- diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 4587f0722..208010500 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -612,6 +612,71 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCo assert registry_calls == [("tools", "lookup", {"query": "foo"}, None)] +def test_generate_preserves_multimodal_mcp_tool_results_between_turns( + stub_model_configs: Any, + stub_model_client: MagicMock, + stub_model_provider_registry: Any, +) -> None: + tool_call = ToolCall(id="call-1", name="render_chart", arguments_json="{}") + responses = [ + _make_response(content=None, tool_calls=[tool_call]), + _make_response("final result"), + ] + multimodal_result = [ + {"type": "text", "text": "Rendered chart:"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}}, + ] + captured_calls: list[tuple[list[ChatMessage], dict[str, 
Any]]] = [] + + def process_with_multimodal_tool_result(completion_response: ChatCompletionResponse) -> list[ChatMessage]: + if not completion_response.message.tool_calls: + return [ChatMessage.as_assistant(content=completion_response.message.content or "")] + return [ + ChatMessage.as_assistant( + content="", + tool_calls=[ + { + "id": "call-1", + "type": "function", + "function": {"name": "render_chart", "arguments": "{}"}, + } + ], + ), + ChatMessage.as_tool(content=multimodal_result, tool_call_id="call-1"), + ] + + facade = StubMCPFacade( + tool_schemas=[ + { + "type": "function", + "function": {"name": "render_chart", "description": "Render", "parameters": {"type": "object"}}, + } + ], + process_fn=process_with_multimodal_tool_result, + ) + registry = StubMCPRegistry(facade) + + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: + captured_calls.append((messages, kwargs)) + return responses.pop(0) + + model = ModelFacade( + model_config=stub_model_configs[0], + model_provider_registry=stub_model_provider_registry, + client=stub_model_client, + mcp_registry=registry, + ) + + with patch.object(ModelFacade, "completion", new=_completion): + result, _ = model.generate(prompt="question", parser=lambda x: x, tool_alias="tools") + + assert result == "final result" + assert len(captured_calls) == 2 + tool_messages = [message for message in captured_calls[1][0] if message.role == "tool"] + assert len(tool_messages) == 1 + assert tool_messages[0].content == multimodal_result + + def test_generate_with_tools_missing_registry( stub_model_configs: Any, stub_model_client: MagicMock, stub_model_provider_registry: Any ) -> None: diff --git a/packages/data-designer-engine/tests/engine/models/test_model_utils.py b/packages/data-designer-engine/tests/engine/models/test_model_utils.py index bc3765b29..c2f07c068 100644 --- a/packages/data-designer-engine/tests/engine/models/test_model_utils.py +++ 
b/packages/data-designer-engine/tests/engine/models/test_model_utils.py @@ -21,3 +21,15 @@ def test_prompt_to_messages() -> None: ChatMessage.as_system(stub_system_prompt), ChatMessage.as_user([mult_modal_context, {"type": "text", "text": "hello"}]), ] + + +def test_chat_message_as_tool_accepts_multimodal_content() -> None: + content = [ + {"type": "text", "text": "Rendered chart:"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}}, + ] + + message = ChatMessage.as_tool(content=content, tool_call_id="call-1") + + assert message.content == content + assert message.to_dict()["content"] == content