Skip to content

Commit d8d5d2e

Browse files
fix: fix mcp types + add py.typed (#2248)
* fix: fix mcp types + add py.typed
* Star expression is from 3.11, use alternative
* Update CI to use test:types
* PR feedback
* PR feedback

---------

Co-authored-by: David S. Batista <[email protected]>
1 parent 594e4d9 commit d8d5d2e

File tree

6 files changed

+39
-50
lines changed

6 files changed

+39
-50
lines changed

.github/workflows/mcp.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ jobs:
6565

6666
- name: Lint
6767
if: matrix.python-version == '3.10' && runner.os == 'Linux'
68-
run: hatch run fmt-check && hatch run lint:typing
68+
run: hatch run fmt-check && hatch run test:types
6969

7070
- name: Generate docs
7171
if: matrix.python-version == '3.10' && runner.os == 'Linux'

integrations/mcp/pyproject.toml

Lines changed: 6 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -78,18 +78,13 @@ unit = 'pytest -m "not integration" {args:tests}'
7878
integration = 'pytest -m "integration" {args:tests}'
7979
all = 'pytest {args:tests}'
8080
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
81-
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
81+
types = """mypy -p haystack_integrations.tools.mcp {args}"""
8282

83-
# TODO: remove lint environment once this integration is properly typed
84-
# test environment should be used instead
85-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
86-
[tool.hatch.envs.lint]
87-
installer = "uv"
88-
detached = true
89-
dependencies = ["pip", "mypy>=1.0.0", "ruff>=0.0.243"]
90-
91-
[tool.hatch.envs.lint.scripts]
92-
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
83+
[tool.mypy]
84+
install_types = true
85+
non_interactive = true
86+
check_untyped_defs = true
87+
disallow_incomplete_defs = true
9388

9489

9590
[tool.ruff]
@@ -167,19 +162,6 @@ omit = ["*/tests/*", "*/__init__.py"]
167162
show_missing = true
168163
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
169164

170-
[[tool.mypy.overrides]]
171-
module = [
172-
"haystack.*",
173-
"haystack_integrations.*",
174-
"pytest.*",
175-
"pytest_asyncio",
176-
"anyio.*",
177-
"mcp.*",
178-
"mcp",
179-
"httpx",
180-
"exceptiongroup"
181-
]
182-
ignore_missing_imports = true
183165

184166
[tool.pytest.ini_options]
185167
addopts = "--strict-markers"

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py

Lines changed: 26 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
from mcp.client.sse import sse_client
3030
from mcp.client.stdio import stdio_client
3131
from mcp.client.streamable_http import streamablehttp_client
32+
from mcp.shared.message import SessionMessage
3233

3334
logger = logging.getLogger(__name__)
3435

@@ -121,7 +122,7 @@ async def _coroutine_with_stop_event():
121122
# use it to control the coroutine.
122123
return future, stop_event_promise.result(timeout)
123124

124-
def shutdown(self, timeout: float = 2):
125+
def shutdown(self, timeout: float = 2) -> None:
125126
"""
126127
Shut down the background event loop and thread.
127128
@@ -208,14 +209,14 @@ class MCPClient(ABC):
208209
def __init__(self, max_retries: int = 3, base_delay: float = 1.0, max_delay: float = 30.0) -> None:
209210
self.session: ClientSession | None = None
210211
self.exit_stack: AsyncExitStack = AsyncExitStack()
211-
self.stdio: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] | None = None
212-
self.write: MemoryObjectSendStream[types.JSONRPCMessage] | None = None
212+
self.stdio: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
213+
self.write: MemoryObjectSendStream[SessionMessage] | None = None
213214
self.max_retries = max_retries
214215
self.base_delay = base_delay
215216
self.max_delay = max_delay
216217

217218
@abstractmethod
218-
async def connect(self) -> list[Tool]:
219+
async def connect(self) -> list[types.Tool]:
219220
"""
220221
Connect to an MCP server.
221222
@@ -329,10 +330,16 @@ async def aclose(self) -> None:
329330
async def _initialize_session_with_transport(
330331
self,
331332
transport_tuple: tuple[
332-
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], MemoryObjectSendStream[types.JSONRPCMessage]
333+
MemoryObjectReceiveStream[SessionMessage | Exception],
334+
MemoryObjectSendStream[SessionMessage],
335+
]
336+
| tuple[
337+
MemoryObjectReceiveStream[SessionMessage | Exception],
338+
MemoryObjectSendStream[SessionMessage],
339+
Any,
333340
],
334341
connection_type: str,
335-
) -> list[Tool]:
342+
) -> list[types.Tool]:
336343
"""
337344
Common session initialization logic for all transports.
338345
@@ -390,11 +397,11 @@ def __init__(
390397
self.env: dict[str, str] | None = None
391398
if env:
392399
self.env = {
393-
key: value.resolve_value() if isinstance(value, Secret) else value for key, value in env.items()
400+
key: (value.resolve_value() if isinstance(value, Secret) else value) or "" for key, value in env.items()
394401
}
395402
logger.debug(f"PROCESS: Created StdioClient for command: {command} {' '.join(self.args or [])}")
396403

397-
async def connect(self) -> list[Tool]:
404+
async def connect(self) -> list[types.Tool]:
398405
"""
399406
Connect to an MCP server using stdio transport.
400407
@@ -433,7 +440,7 @@ def __init__(
433440
)
434441
self.timeout: int = server_info.timeout
435442

436-
async def connect(self) -> list[Tool]:
443+
async def connect(self) -> list[types.Tool]:
437444
"""
438445
Connect to an MCP server using SSE transport.
439446
@@ -481,7 +488,7 @@ def __init__(
481488
)
482489
self.timeout: int = server_info.timeout
483490

484-
async def connect(self) -> list[Tool]:
491+
async def connect(self) -> list[types.Tool]:
485492
"""
486493
Connect to an MCP server using streamable HTTP transport.
487494
@@ -526,7 +533,7 @@ def to_dict(self) -> dict[str, Any]:
526533
:returns: Dictionary representation of this server info
527534
"""
528535
# Store the fully qualified class name for deserialization
529-
result = {"type": generate_qualified_class_name(type(self))}
536+
result: dict[str, Any] = {"type": generate_qualified_class_name(type(self))}
530537

531538
# Add all fields from the dataclass
532539
for dataclass_field in fields(self):
@@ -629,7 +636,7 @@ def __post_init__(self):
629636
# from now on only use url for the lifetime of the SSEServerInfo instance, never base_url
630637
self.url = f"{self.base_url.rstrip('/')}/sse"
631638

632-
elif not is_valid_http_url(self.url):
639+
elif self.url and not is_valid_http_url(self.url):
633640
message = f"Invalid url: {self.url}"
634641
raise ValueError(message)
635642

@@ -834,7 +841,7 @@ def __init__(
834841
tool_dict = {t.name: t for t in tools}
835842
logger.debug(f"TOOL: Available tools: {list(tool_dict.keys())}")
836843

837-
tool_info = tool_dict.get(name)
844+
tool_info: types.Tool | None = tool_dict.get(name)
838845

839846
if not tool_info:
840847
available = list(tool_dict.keys())
@@ -846,7 +853,7 @@ def __init__(
846853
# Initialize the parent class
847854
super().__init__(
848855
name=name,
849-
description=description or tool_info.description,
856+
description=description or tool_info.description or "",
850857
parameters=tool_info.inputSchema,
851858
function=self._invoke_tool,
852859
)
@@ -971,7 +978,7 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
971978
# First get the appropriate class by name
972979
server_info_class = import_class_by_name(server_info_dict["type"])
973980
# Then deserialize using that class's from_dict method
974-
server_info = server_info_class.from_dict(server_info_dict)
981+
server_info = cast(MCPServerInfo, server_info_class).from_dict(server_info_dict)
975982

976983
# Handle backward compatibility for timeout parameters
977984
connection_timeout = inner_data.get("connection_timeout", 30)
@@ -1027,7 +1034,7 @@ def __init__(self, client: "MCPClient", *, timeout: float | None = None):
10271034
self.executor = AsyncExecutor.get_instance()
10281035

10291036
# Where the tool list (or an exception) will be delivered.
1030-
self._tools_promise: Future[list[Tool]] = Future()
1037+
self._tools_promise: Future[list[types.Tool]] = Future()
10311038

10321039
# Kick off the worker coroutine in the background loop
10331040
self._worker_future, self._stop_event = self.executor.run_background(self._run, timeout=None)
@@ -1040,15 +1047,15 @@ def __init__(self, client: "MCPClient", *, timeout: float | None = None):
10401047
self.stop()
10411048
raise
10421049

1043-
def tools(self) -> list[Tool]:
1050+
def tools(self) -> list[types.Tool]:
10441051
"""Return the tool list already collected during startup."""
10451052

10461053
return self._tools_promise.result()
10471054

10481055
def stop(self) -> None:
10491056
"""Request the worker to shut down and block until done."""
10501057

1051-
def _set(ev: asyncio.Event):
1058+
def _set(ev: asyncio.Event) -> None:
10521059
if not ev.is_set():
10531060
ev.set()
10541061

@@ -1067,7 +1074,7 @@ def _set(ev: asyncio.Event):
10671074
logger.debug(f"Error during worker future result: {e}")
10681075
pass
10691076

1070-
async def _run(self, stop_event: asyncio.Event):
1077+
async def _run(self, stop_event: asyncio.Event) -> None:
10711078
"""Background coroutine living in AsyncExecutor's loop."""
10721079

10731080
try:

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from collections.abc import Callable
6-
from typing import Any
6+
from typing import Any, cast
77
from urllib.parse import urlparse
88

99
import httpx
@@ -155,7 +155,7 @@ def create_invoke_tool(
155155
) -> Callable[..., Any]:
156156
"""Return a closure that keeps a strong reference to *owner_toolset* alive."""
157157

158-
def invoke_tool(**kwargs) -> Any:
158+
def invoke_tool(**kwargs: Any) -> Any:
159159
_ = owner_toolset # strong reference so GC can't collect the toolset too early
160160
return AsyncExecutor.get_instance().run(
161161
mcp_client.call_tool(tool_name, kwargs), timeout=tool_timeout
@@ -176,7 +176,7 @@ def invoke_tool(**kwargs) -> Any:
176176
# Use the helper function to create the invoke_tool function
177177
tool = Tool(
178178
name=tool_info.name,
179-
description=tool_info.description,
179+
description=tool_info.description or "",
180180
parameters=tool_info.inputSchema,
181181
function=create_invoke_tool(self, client, tool_info.name, self.invocation_timeout),
182182
)
@@ -273,7 +273,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MCPToolset":
273273
# Reconstruct the server_info object
274274
server_info_dict = inner_data.get("server_info", {})
275275
server_info_class = import_class_by_name(server_info_dict["type"])
276-
server_info = server_info_class.from_dict(server_info_dict)
276+
server_info = cast(MCPServerInfo, server_info_class).from_dict(server_info_dict)
277277

278278
# Create a new MCPToolset instance
279279
return cls(

integrations/mcp/src/haystack_integrations/tools/mcp/py.typed

Whitespace-only changes.

integrations/mcp/tests/mcp_memory_transport.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
22
from typing import Any
33

4-
from haystack.tools import Tool
4+
from mcp import types
55
from mcp.server import Server
66
from mcp.shared.memory import create_connected_server_and_client_session
77

@@ -17,7 +17,7 @@ def __init__(self, server: Server) -> None:
1717
super().__init__()
1818
self.server: Server = server
1919

20-
async def connect(self) -> list[Tool]:
20+
async def connect(self) -> list[types.Tool]:
2121
"""
2222
Connect to an MCP server using stdio transport.
2323

0 commit comments

Comments
 (0)