diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index b6e98f0..4e7cb3a 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -19,6 +19,7 @@ use axum::{ Extension, Router, }; use futures::stream::{self}; +use rust_mcp_schema::schema_utils::ClientMessage; use rust_mcp_transport::{error::TransportError, SessionId, SseTransport}; use std::{convert::Infallible, sync::Arc, time::Duration}; use tokio::{ @@ -78,7 +79,7 @@ pub async fn handle_sse( State(state): State>, ) -> TransportServerResult { let messages_endpoint = - SseTransport::message_endpoint(&state.sse_message_endpoint, &session_id); + SseTransport::::message_endpoint(&state.sse_message_endpoint, &session_id); // readable stream of string to be used in transport let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index fb81612..9fab939 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -28,16 +28,10 @@ pub struct ClientRuntime { client_details: InitializeRequestParams, // Details about the connected server server_details: Arc>>, - message_sender: tokio::sync::RwLock>>, handlers: Mutex>>>, } impl ClientRuntime { - pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher) { - let mut lock = self.message_sender.write().await; - *lock = Some(sender); - } - pub(crate) fn new( client_details: InitializeRequestParams, transport: impl Transport, @@ -48,7 +42,6 @@ impl ClientRuntime { handler, client_details, server_details: Arc::new(RwLock::new(None)), - message_sender: tokio::sync::RwLock::new(None), handlers: Mutex::new(vec![]), } } @@ -83,12 +76,14 @@ impl McpClient for ClientRuntime { where MessageDispatcher: McpDispatch, { - (&self.message_sender) as _ + (self.transport.message_sender().await) as _ } async fn start(self: Arc) -> SdkResult<()> { - let (mut stream, sender, error_io) = self.transport.start().await?; - self.set_message_sender(sender).await; + let mut stream = self.transport.start().await?; + + let mut error_io_stream = self.transport.error_stream().await.write().await; + let error_io_stream = error_io_stream.take(); let self_clone = Arc::clone(&self); let self_clone_err = Arc::clone(&self); @@ -96,7 +91,7 @@ impl McpClient for ClientRuntime { let err_task = tokio::spawn(async move { let self_ref = &*self_clone_err; - if let IoStream::Readable(error_input) = error_io { + if let Some(IoStream::Readable(error_input)) = error_io_stream { let mut reader = BufReader::new(error_input).lines(); loop { tokio::select! { @@ -126,6 +121,7 @@ impl McpClient for ClientRuntime { } } } + Ok::<(), McpSdkError>(()) }); diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index fff5acd..a1287b1 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -7,7 +7,6 @@ use async_trait::async_trait; use futures::StreamExt; use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; use schema_utils::ClientMessage; -use std::pin::Pin; use std::sync::{Arc, RwLock}; use tokio::io::AsyncWriteExt; @@ -27,9 +26,6 @@ pub struct ServerRuntime { server_details: Arc, // Details about the connected client client_details: Arc>>, - - message_sender: tokio::sync::RwLock>>, - error_stream: tokio::sync::RwLock>>>, #[cfg(feature = "hyper-server")] session_id: Option, } @@ -70,24 +66,14 @@ impl McpServer for ServerRuntime { where MessageDispatcher: McpDispatch, { - (&self.message_sender) as _ + (self.transport.message_sender().await) as _ } /// Main runtime loop, processes incoming messages and handles requests async fn start(&self) -> SdkResult<()> { - // Start the transport layer to begin handling messages - // self.transport.start().await?; - // Open the transport stream - // let mut stream = self.transport.open(); - let (mut stream, sender, error_io) = self.transport.start().await?; - - self.set_message_sender(sender).await; - - if let IoStream::Writable(error_stream) = error_io { - self.set_error_stream(error_stream).await; - } + let mut stream = self.transport.start().await?; - let sender = self.sender().await.read().await; + let sender = self.transport.message_sender().await.read().await; let sender = sender .as_ref() .ok_or(schema_utils::SdkError::connection_closed())?; @@ -138,8 +124,8 @@ impl McpServer for ServerRuntime { } async fn stderr_message(&self, message: String) -> SdkResult<()> { - let mut lock = self.error_stream.write().await; - if let Some(stderr) = lock.as_mut() { + let mut lock = self.transport.error_stream().await.write().await; + if let Some(IoStream::Writable(stderr)) = lock.as_mut() { stderr.write_all(message.as_bytes()).await?; stderr.write_all(b"\n").await?; stderr.flush().await?; @@ -149,24 +135,11 @@ impl McpServer for ServerRuntime { } impl ServerRuntime { - pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher) { - let mut lock = self.message_sender.write().await; - *lock = Some(sender); - } - #[cfg(feature = "hyper-server")] pub(crate) async fn session_id(&self) -> Option { self.session_id.to_owned() } - pub(crate) async fn set_error_stream( - &self, - error_stream: Pin>, - ) { - let mut lock = self.error_stream.write().await; - *lock = Some(error_stream); - } - #[cfg(feature = "hyper-server")] pub(crate) fn new_instance( server_details: Arc, @@ -179,8 +152,6 @@ impl ServerRuntime { client_details: Arc::new(RwLock::new(None)), transport: Box::new(transport), handler, - message_sender: tokio::sync::RwLock::new(None), - error_stream: tokio::sync::RwLock::new(None), session_id: Some(session_id), } } @@ -195,8 +166,6 @@ impl ServerRuntime { client_details: Arc::new(RwLock::new(None)), transport: Box::new(transport), handler, - message_sender: tokio::sync::RwLock::new(None), - error_stream: tokio::sync::RwLock::new(None), #[cfg(feature = "hyper-server")] session_id: None, } diff --git a/crates/rust-mcp-transport/src/client_sse.rs b/crates/rust-mcp-transport/src/client_sse.rs index 310c07d..0a8ef1c 100644 --- a/crates/rust-mcp-transport/src/client_sse.rs +++ b/crates/rust-mcp-transport/src/client_sse.rs @@ -52,7 +52,10 @@ impl Default for ClientSseTransportOptions { /// Client-side Server-Sent Events (SSE) transport implementation /// /// Manages SSE connections, HTTP POST requests, and message streaming for client-server communication. -pub struct ClientSseTransport { +pub struct ClientSseTransport +where + R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ /// Optional cancellation token source for shutting down the transport shutdown_source: tokio::sync::RwLock>, /// Flag indicating if the transport is shut down @@ -73,9 +76,14 @@ pub struct ClientSseTransport { custom_headers: Option, sse_task: tokio::sync::RwLock>>, post_task: tokio::sync::RwLock>>, + message_sender: tokio::sync::RwLock>>, + error_stream: tokio::sync::RwLock>, } -impl ClientSseTransport { +impl ClientSseTransport +where + R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ /// Creates a new ClientSseTransport instance /// /// Initializes the transport with the provided server URL and options. @@ -111,6 +119,8 @@ impl ClientSseTransport { custom_headers: headers, sse_task: tokio::sync::RwLock::new(None), post_task: tokio::sync::RwLock::new(None), + message_sender: tokio::sync::RwLock::new(None), + error_stream: tokio::sync::RwLock::new(None), }) } @@ -161,10 +171,23 @@ impl ClientSseTransport { } Ok(endpoint) } + + pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher) { + let mut lock = self.message_sender.write().await; + *lock = Some(sender); + } + + pub(crate) async fn set_error_stream( + &self, + error_stream: Pin>, + ) { + let mut lock = self.error_stream.write().await; + *lock = Some(IoStream::Readable(error_stream)); + } } #[async_trait] -impl Transport for ClientSseTransport +impl Transport for ClientSseTransport where R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static, @@ -176,13 +199,7 @@ where /// # Returns /// * `TransportResult<(Pin + Send>>, MessageDispatcher, IoStream)>` /// - The message stream, dispatcher, and error stream - async fn start( - &self, - ) -> TransportResult<( - Pin + Send>>, - MessageDispatcher, - IoStream, - )> + async fn start(&self) -> TransportResult + Send>>> where MessageDispatcher: McpDispatch, { @@ -290,7 +307,21 @@ where cancellation_token, ); - Ok((stream, sender, error_stream)) + self.set_message_sender(sender).await; + + if let IoStream::Readable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + + Ok(stream) + } + + async fn message_sender(&self) -> &tokio::sync::RwLock>> { + &self.message_sender as _ + } + + async fn error_stream(&self) -> &tokio::sync::RwLock> { + &self.error_stream as _ } /// Checks if the transport has been shut down diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index a8327a1..44c6bd3 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -2,6 +2,7 @@ use crate::schema::schema_utils::{McpMessage, RpcMessage}; use crate::schema::RequestId; use async_trait::async_trait; use futures::Stream; +use serde::de::DeserializeOwned; use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; @@ -15,15 +16,23 @@ use crate::transport::Transport; use crate::utils::{endpoint_with_session_id, CancellationTokenSource}; use crate::{IoStream, McpDispatch, SessionId, TransportOptions}; -pub struct SseTransport { +pub struct SseTransport +where + R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static, +{ shutdown_source: tokio::sync::RwLock>, is_shut_down: Mutex, read_write_streams: Mutex>, options: Arc, + message_sender: tokio::sync::RwLock>>, + error_stream: tokio::sync::RwLock>, } /// Server-Sent Events (SSE) transport implementation -impl SseTransport { +impl SseTransport +where + R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static, +{ /// Creates a new SseTransport instance /// /// Initializes the transport with provided read and write duplex streams and options. @@ -45,16 +54,31 @@ impl SseTransport { options, shutdown_source: tokio::sync::RwLock::new(None), is_shut_down: Mutex::new(false), + message_sender: tokio::sync::RwLock::new(None), + error_stream: tokio::sync::RwLock::new(None), }) } pub fn message_endpoint(endpoint: &str, session_id: &SessionId) -> String { endpoint_with_session_id(endpoint, session_id) } + + pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher) { + let mut lock = self.message_sender.write().await; + *lock = Some(sender); + } + + pub(crate) async fn set_error_stream( + &self, + error_stream: Pin>, + ) { + let mut lock = self.error_stream.write().await; + *lock = Some(IoStream::Writable(error_stream)); + } } #[async_trait] -impl Transport for SseTransport +impl Transport for SseTransport where R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static, @@ -69,13 +93,7 @@ where /// /// # Errors /// * Returns `TransportError` if streams are already taken or not initialized - async fn start( - &self, - ) -> TransportResult<( - Pin + Send>>, - MessageDispatcher, - IoStream, - )> + async fn start(&self) -> TransportResult + Send>>> where MessageDispatcher: McpDispatch, { @@ -103,7 +121,13 @@ where cancellation_token, ); - Ok((stream, sender, error_stream)) + self.set_message_sender(sender).await; + + if let IoStream::Writable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + + Ok(stream) } /// Checks if the transport has been shut down @@ -115,6 +139,14 @@ where *result } + async fn message_sender(&self) -> &tokio::sync::RwLock>> { + &self.message_sender as _ + } + + async fn error_stream(&self) -> &tokio::sync::RwLock> { + &self.error_stream as _ + } + /// Shuts down the transport, terminating tasks and signaling closure /// /// Cancels any running tasks and clears the cancellation source. diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index e9720f3..27aa802 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -2,6 +2,7 @@ use crate::schema::schema_utils::{McpMessage, RpcMessage}; use crate::schema::RequestId; use async_trait::async_trait; use futures::Stream; +use serde::de::DeserializeOwned; use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; @@ -22,16 +23,24 @@ use crate::{IoStream, McpDispatch, TransportOptions}; /// and server-side communication by optionally launching a subprocess or using the current /// process's stdio streams. The transport handles message streaming, dispatching, and shutdown /// operations, integrating with the MCP runtime ecosystem. -pub struct StdioTransport { +pub struct StdioTransport +where + R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static, +{ command: Option, args: Option>, env: Option>, options: TransportOptions, shutdown_source: tokio::sync::RwLock>, is_shut_down: Mutex, + message_sender: tokio::sync::RwLock>>, + error_stream: tokio::sync::RwLock>, } -impl StdioTransport { +impl StdioTransport +where + R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static, +{ /// Creates a new `StdioTransport` instance for MCP Server. /// /// This constructor configures the transport to use the current process's stdio streams, @@ -53,6 +62,8 @@ impl StdioTransport { options, shutdown_source: tokio::sync::RwLock::new(None), is_shut_down: Mutex::new(false), + message_sender: tokio::sync::RwLock::new(None), + error_stream: tokio::sync::RwLock::new(None), }) } @@ -84,6 +95,8 @@ impl StdioTransport { options, shutdown_source: tokio::sync::RwLock::new(None), is_shut_down: Mutex::new(false), + message_sender: tokio::sync::RwLock::new(None), + error_stream: tokio::sync::RwLock::new(None), }) } @@ -109,10 +122,23 @@ impl StdioTransport { (command, command_args) } } + + pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher) { + let mut lock = self.message_sender.write().await; + *lock = Some(sender); + } + + pub(crate) async fn set_error_stream( + &self, + error_stream: Pin>, + ) { + let mut lock = self.error_stream.write().await; + *lock = Some(IoStream::Writable(error_stream)); + } } #[async_trait] -impl Transport for StdioTransport +impl Transport for StdioTransport where R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static, @@ -130,13 +156,7 @@ where /// /// # Errors /// Returns a `TransportError` if the subprocess fails to spawn or stdio streams cannot be accessed. - async fn start( - &self, - ) -> TransportResult<( - Pin + Send>>, - MessageDispatcher, - IoStream, - )> + async fn start(&self) -> TransportResult + Send>>> where MessageDispatcher: McpDispatch, { @@ -200,7 +220,13 @@ where cancellation_token, ); - Ok((stream, sender, error_stream)) + self.set_message_sender(sender).await; + + if let IoStream::Writable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + + Ok(stream) } else { let pending_requests: Arc>>> = Arc::new(Mutex::new(HashMap::new())); @@ -213,7 +239,12 @@ where cancellation_token, ); - Ok((stream, sender, error_stream)) + self.set_message_sender(sender).await; + + if let IoStream::Writable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + Ok(stream) } } @@ -223,6 +254,14 @@ where *result } + async fn message_sender(&self) -> &tokio::sync::RwLock>> { + &self.message_sender as _ + } + + async fn error_stream(&self) -> &tokio::sync::RwLock> { + &self.error_stream as _ + } + // Shuts down the transport, terminating any subprocess and signaling closure. /// /// Sends a shutdown signal via the watch channel and kills the subprocess if present. diff --git a/crates/rust-mcp-transport/src/transport.rs b/crates/rust-mcp-transport/src/transport.rs index 4c013d1..fe1d61f 100644 --- a/crates/rust-mcp-transport/src/transport.rs +++ b/crates/rust-mcp-transport/src/transport.rs @@ -110,15 +110,11 @@ where R: McpMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, S: Clone + Send + Sync + serde::Serialize + 'static, { - async fn start( - &self, - ) -> TransportResult<( - Pin + Send>>, - MessageDispatcher, - IoStream, - )> + async fn start(&self) -> TransportResult + Send>>> where MessageDispatcher: McpDispatch; + async fn message_sender(&self) -> &tokio::sync::RwLock>>; + async fn error_stream(&self) -> &tokio::sync::RwLock>; async fn shut_down(&self) -> TransportResult<()>; async fn is_shut_down(&self) -> bool; }