diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 1853ce7c1..e8adc1af9 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -112,6 +112,7 @@ def __init__( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], write_stream: MemoryObjectSendStream[SessionMessage], read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, @@ -127,6 +128,7 @@ def __init__( read_timeout_seconds=read_timeout_seconds, ) self._client_info = client_info or DEFAULT_CLIENT_INFO + self._progress_callback = progress_callback self._sampling_callback = sampling_callback or _default_sampling_callback self._elicitation_callback = elicitation_callback or _default_elicitation_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback @@ -302,7 +304,7 @@ async def call_tool( ), types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, - progress_callback=progress_callback, + progress_callback=progress_callback or self._progress_callback, ) if not result.isError: