29
29
from mcp .client .sse import sse_client
30
30
from mcp .client .stdio import stdio_client
31
31
from mcp .client .streamable_http import streamablehttp_client
32
+ from mcp .shared .message import SessionMessage
32
33
33
34
logger = logging .getLogger (__name__ )
34
35
@@ -121,7 +122,7 @@ async def _coroutine_with_stop_event():
121
122
# use it to control the coroutine.
122
123
return future , stop_event_promise .result (timeout )
123
124
124
- def shutdown (self , timeout : float = 2 ):
125
+ def shutdown (self , timeout : float = 2 ) -> None :
125
126
"""
126
127
Shut down the background event loop and thread.
127
128
@@ -208,14 +209,14 @@ class MCPClient(ABC):
208
209
def __init__ (self , max_retries : int = 3 , base_delay : float = 1.0 , max_delay : float = 30.0 ) -> None :
209
210
self .session : ClientSession | None = None
210
211
self .exit_stack : AsyncExitStack = AsyncExitStack ()
211
- self .stdio : MemoryObjectReceiveStream [types . JSONRPCMessage | Exception ] | None = None
212
- self .write : MemoryObjectSendStream [types . JSONRPCMessage ] | None = None
212
+ self .stdio : MemoryObjectReceiveStream [SessionMessage | Exception ] | None = None
213
+ self .write : MemoryObjectSendStream [SessionMessage ] | None = None
213
214
self .max_retries = max_retries
214
215
self .base_delay = base_delay
215
216
self .max_delay = max_delay
216
217
217
218
@abstractmethod
218
- async def connect (self ) -> list [Tool ]:
219
+ async def connect (self ) -> list [types . Tool ]:
219
220
"""
220
221
Connect to an MCP server.
221
222
@@ -329,10 +330,16 @@ async def aclose(self) -> None:
329
330
async def _initialize_session_with_transport (
330
331
self ,
331
332
transport_tuple : tuple [
332
- MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ], MemoryObjectSendStream [types .JSONRPCMessage ]
333
+ MemoryObjectReceiveStream [SessionMessage | Exception ],
334
+ MemoryObjectSendStream [SessionMessage ],
335
+ ]
336
+ | tuple [
337
+ MemoryObjectReceiveStream [SessionMessage | Exception ],
338
+ MemoryObjectSendStream [SessionMessage ],
339
+ Any ,
333
340
],
334
341
connection_type : str ,
335
- ) -> list [Tool ]:
342
+ ) -> list [types . Tool ]:
336
343
"""
337
344
Common session initialization logic for all transports.
338
345
@@ -390,11 +397,11 @@ def __init__(
390
397
self .env : dict [str , str ] | None = None
391
398
if env :
392
399
self .env = {
393
- key : value .resolve_value () if isinstance (value , Secret ) else value for key , value in env .items ()
400
+ key : ( value .resolve_value () if isinstance (value , Secret ) else value ) or "" for key , value in env .items ()
394
401
}
395
402
logger .debug (f"PROCESS: Created StdioClient for command: { command } { ' ' .join (self .args or [])} " )
396
403
397
- async def connect (self ) -> list [Tool ]:
404
+ async def connect (self ) -> list [types . Tool ]:
398
405
"""
399
406
Connect to an MCP server using stdio transport.
400
407
@@ -433,7 +440,7 @@ def __init__(
433
440
)
434
441
self .timeout : int = server_info .timeout
435
442
436
- async def connect (self ) -> list [Tool ]:
443
+ async def connect (self ) -> list [types . Tool ]:
437
444
"""
438
445
Connect to an MCP server using SSE transport.
439
446
@@ -481,7 +488,7 @@ def __init__(
481
488
)
482
489
self .timeout : int = server_info .timeout
483
490
484
- async def connect (self ) -> list [Tool ]:
491
+ async def connect (self ) -> list [types . Tool ]:
485
492
"""
486
493
Connect to an MCP server using streamable HTTP transport.
487
494
@@ -526,7 +533,7 @@ def to_dict(self) -> dict[str, Any]:
526
533
:returns: Dictionary representation of this server info
527
534
"""
528
535
# Store the fully qualified class name for deserialization
529
- result = {"type" : generate_qualified_class_name (type (self ))}
536
+ result : dict [ str , Any ] = {"type" : generate_qualified_class_name (type (self ))}
530
537
531
538
# Add all fields from the dataclass
532
539
for dataclass_field in fields (self ):
@@ -629,7 +636,7 @@ def __post_init__(self):
629
636
# from now on only use url for the lifetime of the SSEServerInfo instance, never base_url
630
637
self .url = f"{ self .base_url .rstrip ('/' )} /sse"
631
638
632
- elif not is_valid_http_url (self .url ):
639
+ elif self . url and not is_valid_http_url (self .url ):
633
640
message = f"Invalid url: { self .url } "
634
641
raise ValueError (message )
635
642
@@ -834,7 +841,7 @@ def __init__(
834
841
tool_dict = {t .name : t for t in tools }
835
842
logger .debug (f"TOOL: Available tools: { list (tool_dict .keys ())} " )
836
843
837
- tool_info = tool_dict .get (name )
844
+ tool_info : types . Tool | None = tool_dict .get (name )
838
845
839
846
if not tool_info :
840
847
available = list (tool_dict .keys ())
@@ -846,7 +853,7 @@ def __init__(
846
853
# Initialize the parent class
847
854
super ().__init__ (
848
855
name = name ,
849
- description = description or tool_info .description ,
856
+ description = description or tool_info .description or "" ,
850
857
parameters = tool_info .inputSchema ,
851
858
function = self ._invoke_tool ,
852
859
)
@@ -971,7 +978,7 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
971
978
# First get the appropriate class by name
972
979
server_info_class = import_class_by_name (server_info_dict ["type" ])
973
980
# Then deserialize using that class's from_dict method
974
- server_info = server_info_class .from_dict (server_info_dict )
981
+ server_info = cast ( MCPServerInfo , server_info_class ) .from_dict (server_info_dict )
975
982
976
983
# Handle backward compatibility for timeout parameters
977
984
connection_timeout = inner_data .get ("connection_timeout" , 30 )
@@ -1027,7 +1034,7 @@ def __init__(self, client: "MCPClient", *, timeout: float | None = None):
1027
1034
self .executor = AsyncExecutor .get_instance ()
1028
1035
1029
1036
# Where the tool list (or an exception) will be delivered.
1030
- self ._tools_promise : Future [list [Tool ]] = Future ()
1037
+ self ._tools_promise : Future [list [types . Tool ]] = Future ()
1031
1038
1032
1039
# Kick off the worker coroutine in the background loop
1033
1040
self ._worker_future , self ._stop_event = self .executor .run_background (self ._run , timeout = None )
@@ -1040,15 +1047,15 @@ def __init__(self, client: "MCPClient", *, timeout: float | None = None):
1040
1047
self .stop ()
1041
1048
raise
1042
1049
1043
- def tools (self ) -> list [Tool ]:
1050
+ def tools (self ) -> list [types . Tool ]:
1044
1051
"""Return the tool list already collected during startup."""
1045
1052
1046
1053
return self ._tools_promise .result ()
1047
1054
1048
1055
def stop (self ) -> None :
1049
1056
"""Request the worker to shut down and block until done."""
1050
1057
1051
- def _set (ev : asyncio .Event ):
1058
+ def _set (ev : asyncio .Event ) -> None :
1052
1059
if not ev .is_set ():
1053
1060
ev .set ()
1054
1061
@@ -1067,7 +1074,7 @@ def _set(ev: asyncio.Event):
1067
1074
logger .debug (f"Error during worker future result: { e } " )
1068
1075
pass
1069
1076
1070
- async def _run (self , stop_event : asyncio .Event ):
1077
+ async def _run (self , stop_event : asyncio .Event ) -> None :
1071
1078
"""Background coroutine living in AsyncExecutor's loop."""
1072
1079
1073
1080
try :
0 commit comments