diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index 8d1fb495..2ff22289 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -207,12 +207,22 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { // 2. make return type: `std::pin::Pin + Send + '_>>` // 3. make body: { Box::pin(async move { #body }) } let new_output = syn::parse2::({ + let mut lt = quote! { 'static }; + if let Some(receiver) = fn_item.sig.receiver() { + if let Some((_, receiver_lt)) = receiver.reference.as_ref() { + if let Some(receiver_lt) = receiver_lt { + lt = quote! { #receiver_lt }; + } else { + lt = quote! { '_ }; + } + } + } match &fn_item.sig.output { syn::ReturnType::Default => { - quote! { -> std::pin::Pin + Send + '_>> } + quote! { -> std::pin::Pin + Send + #lt>> } } syn::ReturnType::Type(_, ty) => { - quote! { -> std::pin::Pin + Send + '_>> } + quote! { -> std::pin::Pin + Send + #lt>> } } } })?; diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 999d4ee5..2b9126d1 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -71,7 +71,7 @@ chrono = { version = "0.4.38", default-features = false, features = ["serde", "c [features] default = ["base64", "macros", "server"] -client = [] +client = ["dep:tokio-stream"] server = ["transport-async-rw", "dep:schemars"] macros = ["dep:rmcp-macros", "dep:paste"] @@ -191,3 +191,8 @@ path = "tests/test_message_protocol.rs" name = "test_message_schema" required-features = ["server", "client", "schemars"] path = "tests/test_message_schema.rs" + +[[test]] +name = "test_progress_subscriber" +required-features = ["server", "client", "macros"] +path = "tests/test_progress_subscriber.rs" \ No newline at end of file diff --git a/crates/rmcp/src/handler/client.rs b/crates/rmcp/src/handler/client.rs index f6d4ef6e..37203a6c 100644 --- a/crates/rmcp/src/handler/client.rs +++ b/crates/rmcp/src/handler/client.rs @@ -1,3 +1,4 @@ +pub mod progress; use crate::{ error::Error as McpError, model::*, diff --git a/crates/rmcp/src/handler/client/progress.rs b/crates/rmcp/src/handler/client/progress.rs new file mode 100644 index 00000000..cb2ec314 --- /dev/null +++ b/crates/rmcp/src/handler/client/progress.rs @@ -0,0 +1,100 @@ +use std::{collections::HashMap, sync::Arc}; + +use futures::{Stream, StreamExt}; +use tokio::sync::RwLock; +use tokio_stream::wrappers::ReceiverStream; + +use crate::model::{ProgressNotificationParam, ProgressToken}; +type Dispatcher = + Arc>>>; + +/// A dispatcher for progress notifications. +#[derive(Debug, Clone, Default)] +pub struct ProgressDispatcher { + pub(crate) dispatcher: Dispatcher, +} + +impl ProgressDispatcher { + const CHANNEL_SIZE: usize = 16; + pub fn new() -> Self { + Self::default() + } + + /// Handle a progress notification by sending it to the appropriate subscriber + pub async fn handle_notification(&self, notification: ProgressNotificationParam) { + let token = ¬ification.progress_token; + if let Some(sender) = self.dispatcher.read().await.get(token).cloned() { + let send_result = sender.send(notification).await; + if let Err(e) = send_result { + tracing::warn!("Failed to send progress notification: {e}"); + } + } + } + + /// Subscribe to progress notifications for a specific token. + /// + /// If you drop the returned `ProgressSubscriber`, it will automatically unsubscribe from notifications for that token. + pub async fn subscribe(&self, progress_token: ProgressToken) -> ProgressSubscriber { + let (sender, receiver) = tokio::sync::mpsc::channel(Self::CHANNEL_SIZE); + self.dispatcher + .write() + .await + .insert(progress_token.clone(), sender); + let receiver = ReceiverStream::new(receiver); + ProgressSubscriber { + progress_token, + receiver, + dispacher: self.dispatcher.clone(), + } + } + + /// Unsubscribe from progress notifications for a specific token. + pub async fn unsubscribe(&self, token: &ProgressToken) { + self.dispatcher.write().await.remove(token); + } + + /// Clear all dispachter. + pub async fn clear(&self) { + let mut dispacher = self.dispatcher.write().await; + dispacher.clear(); + } +} + +pub struct ProgressSubscriber { + pub(crate) progress_token: ProgressToken, + pub(crate) receiver: ReceiverStream, + pub(crate) dispacher: Dispatcher, +} + +impl ProgressSubscriber { + pub fn progress_token(&self) -> &ProgressToken { + &self.progress_token + } +} + +impl Stream for ProgressSubscriber { + type Item = ProgressNotificationParam; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.receiver.poll_next_unpin(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.receiver.size_hint() + } +} + +impl Drop for ProgressSubscriber { + fn drop(&mut self) { + let token = self.progress_token.clone(); + self.receiver.close(); + let dispatcher = self.dispacher.clone(); + tokio::spawn(async move { + let mut dispacher = dispatcher.write_owned().await; + dispacher.remove(&token); + }); + } +} diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index 9fbe5b94..5f33a742 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -99,11 +99,6 @@ pub trait FromToolCallContextPart: Sized { pub trait IntoCallToolResult { fn into_call_tool_result(self) -> Result; } -impl IntoCallToolResult for () { - fn into_call_tool_result(self) -> Result { - Ok(CallToolResult::success(vec![])) - } -} impl IntoCallToolResult for T { fn into_call_tool_result(self) -> Result { @@ -120,6 +115,15 @@ impl IntoCallToolResult for Result { } } +impl IntoCallToolResult for Result { + fn into_call_tool_result(self) -> Result { + match self { + Ok(value) => value.into_call_tool_result(), + Err(error) => Err(error), + } + } +} + pin_project_lite::pin_project! { #[project = IntoCallToolResultFutProj] pub enum IntoCallToolResultFut { diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index addc0bf0..657cdb29 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -242,7 +242,7 @@ pub type RequestId = NumberOrString; #[serde(transparent)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct ProgressToken(pub NumberOrString); -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct Request { pub method: M, @@ -255,6 +255,16 @@ pub struct Request { pub extensions: Extensions, } +impl Request { + pub fn new(params: P) -> Self { + Self { + method: Default::default(), + params, + extensions: Extensions::default(), + } + } +} + impl GetExtensions for Request { fn extensions(&self) -> &Extensions { &self.extensions @@ -264,7 +274,7 @@ impl GetExtensions for Request { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct RequestOptionalParam { pub method: M, @@ -277,7 +287,17 @@ pub struct RequestOptionalParam { pub extensions: Extensions, } -#[derive(Debug, Clone)] +impl RequestOptionalParam { + pub fn with_param(params: P) -> Self { + Self { + method: Default::default(), + params: Some(params), + extensions: Extensions::default(), + } + } +} + +#[derive(Debug, Clone, Default)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct RequestNoParam { pub method: M, @@ -296,7 +316,7 @@ impl GetExtensions for RequestNoParam { &mut self.extensions } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct Notification { pub method: M, @@ -308,7 +328,17 @@ pub struct Notification { pub extensions: Extensions, } -#[derive(Debug, Clone)] +impl Notification { + pub fn new(params: P) -> Self { + Self { + method: Default::default(), + params, + extensions: Extensions::default(), + } + } +} + +#[derive(Debug, Clone, Default)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct NotificationNoParam { pub method: M, diff --git a/crates/rmcp/src/model/content.rs b/crates/rmcp/src/model/content.rs index f21463c5..eab5f6ee 100644 --- a/crates/rmcp/src/model/content.rs +++ b/crates/rmcp/src/model/content.rs @@ -165,3 +165,9 @@ impl IntoContents for String { vec![Content::text(self)] } } + +impl IntoContents for () { + fn into_contents(self) -> Vec { + vec![] + } +} diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 57d1089d..306c4c97 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -745,8 +745,9 @@ where let mut extensions = Extensions::new(); let mut meta = Meta::new(); // avoid clone - std::mem::swap(&mut extensions, request.extensions_mut()); + // swap meta firstly, otherwise progress token will be lost std::mem::swap(&mut meta, request.get_meta_mut()); + std::mem::swap(&mut extensions, request.extensions_mut()); let context = RequestContext { ct: context_ct, id: id.clone(), diff --git a/crates/rmcp/tests/test_progress_subscriber.rs b/crates/rmcp/tests/test_progress_subscriber.rs new file mode 100644 index 00000000..566667d7 --- /dev/null +++ b/crates/rmcp/tests/test_progress_subscriber.rs @@ -0,0 +1,126 @@ +use futures::StreamExt; +use rmcp::{ + ClientHandler, Peer, RoleServer, ServerHandler, ServiceExt, + handler::{client::progress::ProgressDispatcher, server::tool::ToolRouter}, + model::{CallToolRequestParam, ClientRequest, Meta, ProgressNotificationParam, Request}, + service::PeerRequestOptions, + tool, tool_handler, tool_router, +}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +pub struct MyClient { + progress_handler: ProgressDispatcher, +} + +impl MyClient { + pub fn new() -> Self { + Self { + progress_handler: ProgressDispatcher::new(), + } + } +} + +impl Default for MyClient { + fn default() -> Self { + Self::new() + } +} + +impl ClientHandler for MyClient { + async fn on_progress( + &self, + params: rmcp::model::ProgressNotificationParam, + _context: rmcp::service::NotificationContext, + ) { + tracing::info!("Received progress notification: {:?}", params); + self.progress_handler.handle_notification(params).await; + } +} + +impl Default for MyServer { + fn default() -> Self { + Self::new() + } +} + +pub struct MyServer { + tool_router: ToolRouter, +} + +#[tool_router] +impl MyServer { + pub fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + #[tool] + pub async fn some_progress(meta: Meta, client: Peer) -> Result<(), rmcp::Error> { + let progress_token = meta + .get_progress_token() + .ok_or(rmcp::Error::invalid_params( + "Progress token is required for this tool", + None, + ))?; + for step in 0..10 { + let _ = client + .notify_progress(ProgressNotificationParam { + progress_token: progress_token.clone(), + progress: step, + total: Some(10), + message: Some("Some message".into()), + }) + .await; + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + Ok(()) + } +} + +#[tool_handler] +impl ServerHandler for MyServer {} + +#[tokio::test] +async fn test_progress_subscriber() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + let client = MyClient::new(); + + let server = MyServer::new(); + let (transport_server, transport_client) = tokio::io::duplex(4096); + tokio::spawn(async move { + let service = server.serve(transport_server).await?; + service.waiting().await?; + anyhow::Ok(()) + }); + let client_service = client.serve(transport_client).await?; + let handle = client_service + .send_cancellable_request( + ClientRequest::CallToolRequest(Request::new(CallToolRequestParam { + name: "some_progress".into(), + arguments: None, + })), + PeerRequestOptions::no_options(), + ) + .await?; + let mut progress_subscriber = client_service + .service() + .progress_handler + .subscribe(handle.progress_token.clone()) + .await; + tokio::spawn(async move { + while let Some(notification) = progress_subscriber.next().await { + tracing::info!("Progress notification: {:?}", notification); + } + }); + let _response = handle.await_response().await?; + + // Simulate some delay to allow the async task to complete + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + Ok(()) +} diff --git a/crates/rmcp/tests/test_tool_routers.rs b/crates/rmcp/tests/test_tool_routers.rs index 846fe424..0a65cd47 100644 --- a/crates/rmcp/tests/test_tool_routers.rs +++ b/crates/rmcp/tests/test_tool_routers.rs @@ -43,10 +43,7 @@ impl TestHandler { } #[rmcp::tool] -async fn async_function( - _callee: &TestHandler, - Parameters(Request { fields }): Parameters, -) { +async fn async_function(Parameters(Request { fields }): Parameters) { drop(fields) }