Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions architecture/mcp.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ Scoped to one `ToolConfig`. Provides the interface that `ModelFacade` uses:
- **`process_completion_response`** — extracts tool calls from a completion, executes them in parallel via `MCPIOService`, returns `ChatMessage` list with results
- **`refuse_completion_response`** — handles tool-call turn limits (prevents infinite tool loops)

Tool result messages may contain either text or ordered multimodal content blocks. MCP image results and explicit
base64 image payloads are preserved as canonical `image_url` data URI blocks and translated by provider adapters at
the API boundary. Models need VLM-capable provider support to interpret those image results semantically.

### MCPRegistry

Maps `tool_alias` → `ToolConfig`. Lazy `MCPFacade` construction mirrors `ModelRegistry`. Provides health checks for configured tools.
Expand Down
185 changes: 165 additions & 20 deletions packages/data-designer-engine/src/data_designer/engine/mcp/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@
import atexit
import json
import logging
import re
import threading
from collections.abc import Coroutine, Iterable
from collections.abc import Callable, Coroutine, Iterable
from typing import Any

from mcp import ClientSession, StdioServerParameters
Expand All @@ -42,10 +43,16 @@
from mcp.client.streamable_http import streamablehttp_client

from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, MCPProviderT
from data_designer.config.utils.image_helpers import (
decode_base64_image,
detect_image_format,
extract_base64_from_data_uri,
)
from data_designer.engine.mcp.errors import MCPToolError
from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult

logger = logging.getLogger(__name__)
_DATA_URI_MIME_TYPE_RE = re.compile(r"^data:(?P<mime_type>[^;]+);base64,")


def _provider_cache_key(provider: MCPProviderT) -> str:
Expand Down Expand Up @@ -289,7 +296,7 @@ async def _call_tool_async(
session = await self._get_or_create_session(provider)
result = await session.call_tool(tool_name, arguments)

content = _serialize_tool_result_content(result)
content = _coerce_tool_result_content(result)
is_error = getattr(result, "isError", None)
if is_error is None:
is_error = getattr(result, "is_error", False)
Expand Down Expand Up @@ -467,31 +474,169 @@ def _coerce_tool_definition(tool: Any, tool_definition_cls: type[MCPToolDefiniti
return tool_definition_cls(name=name, description=description, input_schema=input_schema)


def _serialize_tool_result_content(result: Any) -> str:
"""Serialize tool result content to a string."""
def _coerce_tool_result_content(result: Any) -> str | list[dict[str, Any]]:
"""Coerce MCP tool result content while preserving image blocks."""
content = getattr(result, "content", result)
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, dict):
if _is_image_url_block(content):
return [_coerce_image_url_block(content)]
if _is_image_content(content) or _has_base64_image_payload(content):
return [_build_image_url_block(content)]
if _is_text_content(content):
return str(content.get("text", ""))
return json.dumps(content)
if _is_image_content(content) or _has_base64_image_payload(content):
return [_build_image_url_block(content)]
if _is_text_content(content):
return str(_get_content_field(content, "text", default=""))
if isinstance(content, list):
parts: list[str] = []
blocks: list[dict[str, Any]] = []
has_image = False
for item in content:
if isinstance(item, str):
parts.append(item)
continue
if isinstance(item, dict):
if item.get("type") == "text":
parts.append(str(item.get("text", "")))
else:
parts.append(json.dumps(item))
continue
text_value = getattr(item, "text", None)
if text_value is not None:
parts.append(str(text_value))
else:
parts.append(str(item))
return "\n".join(parts)
block = _coerce_tool_result_content_item(item)
blocks.append(block)
has_image = has_image or block.get("type") == "image_url"
if has_image:
return blocks
return "\n".join(block.get("text", "") for block in blocks)
return str(content)


def _coerce_tool_result_content_item(item: Any) -> dict[str, Any]:
    """Normalize one MCP content item into a ChatML-style content block.

    Strings become text blocks. ``image_url`` dict blocks and image-like
    payloads become ``image_url`` blocks. Items typed ``"text"`` keep their
    ``text`` field, any other dict is JSON-encoded into a text block, and
    everything else is stringified, preferring a non-``None`` ``text``
    attribute when the item has one.
    """
    if isinstance(item, str):
        return _build_text_block(item)

    # Image checks run before the text checks so image payloads are never flattened.
    if _is_image_url_block(item):
        return _coerce_image_url_block(item)
    looks_like_image = _is_image_content(item) or _has_base64_image_payload(item)
    if looks_like_image:
        return _build_image_url_block(item)

    if _is_text_content(item):
        text = _get_content_field(item, "text", default="")
        return _build_text_block(text)
    if isinstance(item, dict):
        return _build_text_block(json.dumps(item))

    attr_text = getattr(item, "text", None)
    return _build_text_block(item if attr_text is None else attr_text)


def _is_text_content(item: Any) -> bool:
    """Return whether ``item`` (dict or object) declares ``type == "text"``."""
    declared_type = _get_content_field(item, "type")
    return declared_type == "text"


def _is_image_content(item: Any) -> bool:
    """Return whether ``item`` (dict or object) declares ``type == "image"``."""
    declared_type = _get_content_field(item, "type")
    return declared_type == "image"


def _is_image_url_block(item: Any) -> bool:
return isinstance(item, dict) and item.get("type") == "image_url"


def _has_base64_image_payload(item: Any) -> bool:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we tighten this image detection a bit? Right now _has_base64_image_payload() treats any item with data plus mimeType/mime_type/media_type as an image payload, and _coerce_image_mime_type() accepts any provided MIME string unchanged. That means a generic structured MCP/tool result such as {"type": "resource", "data": "...", "mimeType": "application/json"} or {"base64": "...", "media_type": "application/pdf"} can be turned into an image_url block with a non-image data URI.

Could we gate the generic base64/data detection on image/*, and reserve the clear MCPToolError path for explicit type == "image" content with a non-image MIME? Non-image structured payloads could then keep the existing JSON/text fallback instead of being sent to provider adapters as invalid image content.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 362f1394. Generic data/base64 payload detection now only promotes content to an image block when the MIME metadata or data URI is image/*; non-image resource payloads such as application/json/application/pdf fall back to JSON/text. Explicit type == "image" content now validates provided MIME types and raises MCPToolError for non-image MIME values. Added regression coverage for both fallback cases and the explicit-image error path. Verified with the MCP coercion tests (56 passed) and the focused PR suite (243 passed).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thanks for tightening this up. Small non-blocking note: this now means bare b64_json/base64 payloads without MIME metadata fall back to JSON/text instead of image auto-detection. That seems reasonable if MIME/data URI is the contract; if you want the shorthand to keep working, magic-byte detection could be added back for that case. Your call, not blocking from my side.

if _get_content_field(item, "b64_json", "base64") is not None:
return True
return (
_get_content_field(item, "data") is not None
and _get_content_field(item, "mimeType", "mime_type", "media_type") is not None
)


def _coerce_image_url_block(block: dict[str, Any]) -> dict[str, Any]:
image_url = block.get("image_url")
if not isinstance(image_url, dict):
return block

url = image_url.get("url")
if not isinstance(url, str) or url.startswith(("data:image/", "http://", "https://")):
return block

return _build_image_url_block({"base64": url})


def _build_image_url_block(item: Any) -> dict[str, Any]:
    """Build a canonical ``image_url`` block carrying a base64 data URI.

    The payload is read from the first present field among ``data``,
    ``b64_json`` and ``base64``; the declared MIME type from ``mimeType``,
    ``mime_type`` or ``media_type``.

    Raises:
        MCPToolError: If the payload is missing, empty, or not a string.
            The MIME and base64 helpers called here can raise it as well.
    """
    raw_payload = _get_content_field(item, "data", "b64_json", "base64")
    declared_mime = _get_content_field(item, "mimeType", "mime_type", "media_type")
    if not (isinstance(raw_payload, str) and raw_payload):
        raise MCPToolError("MCP image content is missing base64 data.")

    resolved_mime = _coerce_image_mime_type(raw_payload, declared_mime)
    encoded = _coerce_base64_image_data(raw_payload)
    data_uri = f"data:{resolved_mime};base64,{encoded}"
    return {"type": "image_url", "image_url": {"url": data_uri}}


def _coerce_image_mime_type(data: str, mime_type: Any) -> str:
if isinstance(mime_type, str) and mime_type:
return mime_type

data_uri_mime_type = _extract_mime_type_from_data_uri(data)
if data_uri_mime_type is not None:
return data_uri_mime_type

try:
return f"image/{detect_image_format(decode_base64_image(data)).value}"
except ValueError as exc:
raise MCPToolError("MCP image content is missing a MIME type.") from exc


def _coerce_base64_image_data(data: str) -> str:
    """Return the result of ``extract_base64_from_data_uri`` for ``data``.

    Raises:
        MCPToolError: If ``extract_base64_from_data_uri`` raises ``ValueError``.
    """
    try:
        payload = extract_base64_from_data_uri(data)
    except ValueError as exc:
        raise MCPToolError("MCP image content has invalid base64 data.") from exc
    return payload


def _extract_mime_type_from_data_uri(data: str) -> str | None:
    """Return the image MIME type declared by a base64 data URI prefix.

    Returns ``None`` when ``data`` does not match ``_DATA_URI_MIME_TYPE_RE``.

    Raises:
        MCPToolError: If the declared MIME type does not start with ``image/``.
    """
    if (prefix_match := _DATA_URI_MIME_TYPE_RE.match(data)) is None:
        return None

    declared = prefix_match["mime_type"]
    if declared.startswith("image/"):
        return declared
    raise MCPToolError(f"MCP image content data URI must use an image MIME type, got {declared!r}.")


def _get_content_field(item: Any, *names: str, default: Any = None) -> Any:
if isinstance(item, dict):
for name in names:
if name in item:
return item[name]
return default

for name in names:
if hasattr(item, name):
return getattr(item, name)

model_dump = getattr(item, "model_dump", None)
if callable(model_dump):
return _get_content_field_from_dump(model_dump, names, default)

dict_dump = getattr(item, "dict", None)
if callable(dict_dump):
return _get_content_field_from_dump(dict_dump, names, default)

return default


def _get_content_field_from_dump(dump_method: Callable[..., Any], names: tuple[str, ...], default: Any) -> Any:
for kwargs in ({"by_alias": True}, {}):
try:
dumped = dump_method(**kwargs)
except TypeError:
continue
if isinstance(dumped, dict):
for name in names:
if name in dumped:
return dumped[name]
return default


def _build_text_block(value: Any) -> dict[str, Any]:
return {"type": "text", "text": str(value)}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def to_openai_tool_schema(self) -> dict[str, Any]:
class MCPToolResult:
"""Result from executing an MCP tool call."""

content: str
content: str | list[dict[str, Any]]
is_error: bool = False


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def as_system(cls, content: str) -> ChatMessage:
return cls(role="system", content=content)

@classmethod
def as_tool(cls, content: str, tool_call_id: str) -> ChatMessage:
def as_tool(cls, content: str | list[dict[str, Any]], tool_call_id: str) -> ChatMessage:
"""Create a tool response message."""
return cls(role="tool", content=content, tool_call_id=tool_call_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,40 @@ def mock_call_tools(
assert messages[1].tool_call_id == "call-1"


def test_process_completion_preserves_multimodal_tool_result_content(
    monkeypatch: pytest.MonkeyPatch,
    stub_mcp_facade: MCPFacade,
    mock_completion_response_single_tool: ChatCompletionResponse,
) -> None:
    """A multimodal tool result reaches the tool message without being flattened."""
    expected_blocks = [
        {"type": "text", "text": "Screenshot:"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}},
        {"type": "text", "text": "Use the chart title."},
    ]

    def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]:
        lookup_tool = MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"})
        return (lookup_tool,)

    def mock_call_tools(
        calls: list[tuple[Any, str, dict[str, Any]]],
        *,
        timeout_sec: float | None = None,
    ) -> list[MCPToolResult]:
        # Every batch yields the same single multimodal result.
        return [MCPToolResult(content=expected_blocks)]

    # Replace the module-level I/O entry points with the local stubs.
    for attr_name, stub in (("list_tools", mock_list_tools), ("call_tools", mock_call_tools)):
        monkeypatch.setattr(mcp_io, attr_name, stub)

    messages = stub_mcp_facade.process_completion_response(mock_completion_response_single_tool)

    assert len(messages) == 2
    tool_message = messages[1]
    assert tool_message.role == "tool"
    assert tool_message.tool_call_id == "call-1"
    assert tool_message.content == expected_blocks


def test_process_completion_preserves_content(
stub_mcp_facade: MCPFacade,
mock_completion_response_no_tools: ChatCompletionResponse,
Expand Down
Loading
Loading