Skip to content

chore: refactor transport trait and runtime architecture #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use axum::{
Extension, Router,
};
use futures::stream::{self};
use rust_mcp_schema::schema_utils::ClientMessage;
use rust_mcp_transport::{error::TransportError, SessionId, SseTransport};
use std::{convert::Infallible, sync::Arc, time::Duration};
use tokio::{
Expand Down Expand Up @@ -78,7 +79,7 @@ pub async fn handle_sse(
State(state): State<Arc<AppState>>,
) -> TransportServerResult<impl IntoResponse> {
let messages_endpoint =
SseTransport::message_endpoint(&state.sse_message_endpoint, &session_id);
SseTransport::<ClientMessage>::message_endpoint(&state.sse_message_endpoint, &session_id);

// readable stream of string to be used in transport
let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE);
Expand Down
18 changes: 7 additions & 11 deletions crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,10 @@ pub struct ClientRuntime {
client_details: InitializeRequestParams,
// Details about the connected server
server_details: Arc<RwLock<Option<InitializeResult>>>,
message_sender: tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>,
handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
}

impl ClientRuntime {
pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<ServerMessage>) {
let mut lock = self.message_sender.write().await;
*lock = Some(sender);
}

pub(crate) fn new(
client_details: InitializeRequestParams,
transport: impl Transport<ServerMessage, MessageFromClient>,
Expand All @@ -48,7 +42,6 @@ impl ClientRuntime {
handler,
client_details,
server_details: Arc::new(RwLock::new(None)),
message_sender: tokio::sync::RwLock::new(None),
handlers: Mutex::new(vec![]),
}
}
Expand Down Expand Up @@ -83,20 +76,22 @@ impl McpClient for ClientRuntime {
where
MessageDispatcher<ServerMessage>: McpDispatch<ServerMessage, MessageFromClient>,
{
(&self.message_sender) as _
(self.transport.message_sender().await) as _
}

async fn start(self: Arc<Self>) -> SdkResult<()> {
let (mut stream, sender, error_io) = self.transport.start().await?;
self.set_message_sender(sender).await;
let mut stream = self.transport.start().await?;

let mut error_io_stream = self.transport.error_stream().await.write().await;
let error_io_stream = error_io_stream.take();

let self_clone = Arc::clone(&self);
let self_clone_err = Arc::clone(&self);

let err_task = tokio::spawn(async move {
let self_ref = &*self_clone_err;

if let IoStream::Readable(error_input) = error_io {
if let Some(IoStream::Readable(error_input)) = error_io_stream {
let mut reader = BufReader::new(error_input).lines();
loop {
tokio::select! {
Expand Down Expand Up @@ -126,6 +121,7 @@ impl McpClient for ClientRuntime {
}
}
}

Ok::<(), McpSdkError>(())
});

Expand Down
41 changes: 5 additions & 36 deletions crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use async_trait::async_trait;
use futures::StreamExt;
use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport};
use schema_utils::ClientMessage;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use tokio::io::AsyncWriteExt;

Expand All @@ -27,9 +26,6 @@ pub struct ServerRuntime {
server_details: Arc<InitializeResult>,
// Details about the connected client
client_details: Arc<RwLock<Option<InitializeRequestParams>>>,

message_sender: tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>,
error_stream: tokio::sync::RwLock<Option<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>>,
#[cfg(feature = "hyper-server")]
session_id: Option<SessionId>,
}
Expand Down Expand Up @@ -70,24 +66,14 @@ impl McpServer for ServerRuntime {
where
MessageDispatcher<ClientMessage>: McpDispatch<ClientMessage, MessageFromServer>,
{
(&self.message_sender) as _
(self.transport.message_sender().await) as _
}

/// Main runtime loop, processes incoming messages and handles requests
async fn start(&self) -> SdkResult<()> {
// Start the transport layer to begin handling messages
// self.transport.start().await?;
// Open the transport stream
// let mut stream = self.transport.open();
let (mut stream, sender, error_io) = self.transport.start().await?;

self.set_message_sender(sender).await;

if let IoStream::Writable(error_stream) = error_io {
self.set_error_stream(error_stream).await;
}
let mut stream = self.transport.start().await?;

let sender = self.sender().await.read().await;
let sender = self.transport.message_sender().await.read().await;
let sender = sender
.as_ref()
.ok_or(schema_utils::SdkError::connection_closed())?;
Expand Down Expand Up @@ -138,8 +124,8 @@ impl McpServer for ServerRuntime {
}

async fn stderr_message(&self, message: String) -> SdkResult<()> {
let mut lock = self.error_stream.write().await;
if let Some(stderr) = lock.as_mut() {
let mut lock = self.transport.error_stream().await.write().await;
if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
stderr.write_all(message.as_bytes()).await?;
stderr.write_all(b"\n").await?;
stderr.flush().await?;
Expand All @@ -149,24 +135,11 @@ impl McpServer for ServerRuntime {
}

impl ServerRuntime {
pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<ClientMessage>) {
let mut lock = self.message_sender.write().await;
*lock = Some(sender);
}

#[cfg(feature = "hyper-server")]
pub(crate) async fn session_id(&self) -> Option<SessionId> {
self.session_id.to_owned()
}

pub(crate) async fn set_error_stream(
&self,
error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
) {
let mut lock = self.error_stream.write().await;
*lock = Some(error_stream);
}

#[cfg(feature = "hyper-server")]
pub(crate) fn new_instance(
server_details: Arc<InitializeResult>,
Expand All @@ -179,8 +152,6 @@ impl ServerRuntime {
client_details: Arc::new(RwLock::new(None)),
transport: Box::new(transport),
handler,
message_sender: tokio::sync::RwLock::new(None),
error_stream: tokio::sync::RwLock::new(None),
session_id: Some(session_id),
}
}
Expand All @@ -195,8 +166,6 @@ impl ServerRuntime {
client_details: Arc::new(RwLock::new(None)),
transport: Box::new(transport),
handler,
message_sender: tokio::sync::RwLock::new(None),
error_stream: tokio::sync::RwLock::new(None),
#[cfg(feature = "hyper-server")]
session_id: None,
}
Expand Down
53 changes: 42 additions & 11 deletions crates/rust-mcp-transport/src/client_sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ impl Default for ClientSseTransportOptions {
/// Client-side Server-Sent Events (SSE) transport implementation
///
/// Manages SSE connections, HTTP POST requests, and message streaming for client-server communication.
pub struct ClientSseTransport {
pub struct ClientSseTransport<R>
where
R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
{
/// Optional cancellation token source for shutting down the transport
shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
/// Flag indicating if the transport is shut down
Expand All @@ -73,9 +76,14 @@ pub struct ClientSseTransport {
custom_headers: Option<HeaderMap>,
sse_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
post_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
message_sender: tokio::sync::RwLock<Option<MessageDispatcher<R>>>,
error_stream: tokio::sync::RwLock<Option<IoStream>>,
}

impl ClientSseTransport {
impl<R> ClientSseTransport<R>
where
R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
{
/// Creates a new ClientSseTransport instance
///
/// Initializes the transport with the provided server URL and options.
Expand Down Expand Up @@ -111,6 +119,8 @@ impl ClientSseTransport {
custom_headers: headers,
sse_task: tokio::sync::RwLock::new(None),
post_task: tokio::sync::RwLock::new(None),
message_sender: tokio::sync::RwLock::new(None),
error_stream: tokio::sync::RwLock::new(None),
})
}

Expand Down Expand Up @@ -161,10 +171,23 @@ impl ClientSseTransport {
}
Ok(endpoint)
}

pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
let mut lock = self.message_sender.write().await;
*lock = Some(sender);
}

pub(crate) async fn set_error_stream(
&self,
error_stream: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>,
) {
let mut lock = self.error_stream.write().await;
*lock = Some(IoStream::Readable(error_stream));
}
}

#[async_trait]
impl<R, S> Transport<R, S> for ClientSseTransport
impl<R, S> Transport<R, S> for ClientSseTransport<R>
where
R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
Expand All @@ -176,13 +199,7 @@ where
/// # Returns
/// * `TransportResult<(Pin<Box<dyn Stream<Item = R> + Send>>, MessageDispatcher<R>, IoStream)>`
/// - The message stream, dispatcher, and error stream
async fn start(
&self,
) -> TransportResult<(
Pin<Box<dyn Stream<Item = R> + Send>>,
MessageDispatcher<R>,
IoStream,
)>
async fn start(&self) -> TransportResult<Pin<Box<dyn Stream<Item = R> + Send>>>
where
MessageDispatcher<R>: McpDispatch<R, S>,
{
Expand Down Expand Up @@ -290,7 +307,21 @@ where
cancellation_token,
);

Ok((stream, sender, error_stream))
self.set_message_sender(sender).await;

if let IoStream::Readable(error_stream) = error_stream {
self.set_error_stream(error_stream).await;
}

Ok(stream)
}

async fn message_sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<R>>> {
&self.message_sender as _
}

async fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
&self.error_stream as _
}

/// Checks if the transport has been shut down
Expand Down
54 changes: 43 additions & 11 deletions crates/rust-mcp-transport/src/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::schema::schema_utils::{McpMessage, RpcMessage};
use crate::schema::RequestId;
use async_trait::async_trait;
use futures::Stream;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
Expand All @@ -15,15 +16,23 @@ use crate::transport::Transport;
use crate::utils::{endpoint_with_session_id, CancellationTokenSource};
use crate::{IoStream, McpDispatch, SessionId, TransportOptions};

pub struct SseTransport {
pub struct SseTransport<R>
where
R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static,
{
shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
is_shut_down: Mutex<bool>,
read_write_streams: Mutex<Option<(DuplexStream, DuplexStream)>>,
options: Arc<TransportOptions>,
message_sender: tokio::sync::RwLock<Option<MessageDispatcher<R>>>,
error_stream: tokio::sync::RwLock<Option<IoStream>>,
}

/// Server-Sent Events (SSE) transport implementation
impl SseTransport {
impl<R> SseTransport<R>
where
R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static,
{
/// Creates a new SseTransport instance
///
/// Initializes the transport with provided read and write duplex streams and options.
Expand All @@ -45,16 +54,31 @@ impl SseTransport {
options,
shutdown_source: tokio::sync::RwLock::new(None),
is_shut_down: Mutex::new(false),
message_sender: tokio::sync::RwLock::new(None),
error_stream: tokio::sync::RwLock::new(None),
})
}

pub fn message_endpoint(endpoint: &str, session_id: &SessionId) -> String {
endpoint_with_session_id(endpoint, session_id)
}

pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
let mut lock = self.message_sender.write().await;
*lock = Some(sender);
}

pub(crate) async fn set_error_stream(
&self,
error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
) {
let mut lock = self.error_stream.write().await;
*lock = Some(IoStream::Writable(error_stream));
}
}

#[async_trait]
impl<R, S> Transport<R, S> for SseTransport
impl<R, S> Transport<R, S> for SseTransport<R>
where
R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
Expand All @@ -69,13 +93,7 @@ where
///
/// # Errors
/// * Returns `TransportError` if streams are already taken or not initialized
async fn start(
&self,
) -> TransportResult<(
Pin<Box<dyn Stream<Item = R> + Send>>,
MessageDispatcher<R>,
IoStream,
)>
async fn start(&self) -> TransportResult<Pin<Box<dyn Stream<Item = R> + Send>>>
where
MessageDispatcher<R>: McpDispatch<R, S>,
{
Expand Down Expand Up @@ -103,7 +121,13 @@ where
cancellation_token,
);

Ok((stream, sender, error_stream))
self.set_message_sender(sender).await;

if let IoStream::Writable(error_stream) = error_stream {
self.set_error_stream(error_stream).await;
}

Ok(stream)
}

/// Checks if the transport has been shut down
Expand All @@ -115,6 +139,14 @@ where
*result
}

async fn message_sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<R>>> {
&self.message_sender as _
}

async fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
&self.error_stream as _
}

/// Shuts down the transport, terminating tasks and signaling closure
///
/// Cancels any running tasks and clears the cancellation source.
Expand Down
Loading