From d37a6fb775268cd9a6863537b4019e734c162ca3 Mon Sep 17 00:00:00 2001 From: snowmead Date: Fri, 11 Jul 2025 21:25:56 -0400 Subject: [PATCH 1/3] feat: Add prompt support with typed argument handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Implement #[prompt], #[prompt_router], and #[prompt_handler] macros - Add automatic JSON schema generation from Rust types for arguments - Support flexible async handler signatures with automatic adaptation - Create PromptRouter for efficient prompt dispatch - Include comprehensive tests and example implementation This enables MCP servers to provide reusable prompt templates that LLMs can discover and invoke with strongly-typed parameters, similar to the existing tool system but optimized for prompt use cases. 🤖 Generated with Claude Code Co-Authored-By: Claude --- crates/rmcp-macros/src/common.rs | 41 +++ crates/rmcp-macros/src/lib.rs | 103 ++++++ crates/rmcp-macros/src/prompt.rs | 204 +++++++++++ crates/rmcp-macros/src/prompt_handler.rs | 142 +++++++ crates/rmcp-macros/src/prompt_router.rs | 97 +++++ crates/rmcp-macros/src/tool.rs | 39 +- crates/rmcp/Cargo.toml | 2 +- crates/rmcp/src/handler/server.rs | 1 + crates/rmcp/src/handler/server/prompt.rs | 346 ++++++++++++++++++ crates/rmcp/src/handler/server/router.rs | 44 ++- .../rmcp/src/handler/server/router/prompt.rs | 212 +++++++++++ crates/rmcp/src/service.rs | 8 +- crates/rmcp/tests/test_prompts.rs | 210 +++++++++++ examples/servers/Cargo.toml | 4 + examples/servers/README.md | 13 + examples/servers/src/prompt_stdio.rs | 145 ++++++++ 16 files changed, 1569 insertions(+), 42 deletions(-) create mode 100644 crates/rmcp-macros/src/common.rs create mode 100644 crates/rmcp-macros/src/prompt.rs create mode 100644 crates/rmcp-macros/src/prompt_handler.rs create mode 100644 crates/rmcp-macros/src/prompt_router.rs create mode 100644 crates/rmcp/src/handler/server/prompt.rs create mode 100644 crates/rmcp/tests/test_prompts.rs create mode 100644 examples/servers/src/prompt_stdio.rs diff --git a/crates/rmcp-macros/src/common.rs b/crates/rmcp-macros/src/common.rs new file mode 100644 index 00000000..d009471f --- /dev/null +++ b/crates/rmcp-macros/src/common.rs @@ -0,0 +1,41 @@ +//! Common utilities shared between different macro implementations + +use quote::quote; +use syn::{Attribute, Expr}; + +/// Parse a None expression +pub fn none_expr() -> syn::Result { + syn::parse2::(quote! { None }) +} + +/// Extract documentation from doc attributes +pub fn extract_doc_line(existing_docs: Option, attr: &Attribute) -> Option { + if !attr.path().is_ident("doc") { + return None; + } + + let syn::Meta::NameValue(name_value) = &attr.meta else { + return None; + }; + + let syn::Expr::Lit(expr_lit) = &name_value.value else { + return None; + }; + + let syn::Lit::Str(lit_str) = &expr_lit.lit else { + return None; + }; + + let content = lit_str.value().trim().to_string(); + match (existing_docs, content) { + (Some(mut existing_docs), content) if !content.is_empty() => { + existing_docs.push('\n'); + existing_docs.push_str(&content); + Some(existing_docs) + } + (Some(existing_docs), _) => Some(existing_docs), + (None, content) if !content.is_empty() => Some(content), + _ => None, + } +} + diff --git a/crates/rmcp-macros/src/lib.rs b/crates/rmcp-macros/src/lib.rs index 57e175bf..efc002a1 100644 --- a/crates/rmcp-macros/src/lib.rs +++ b/crates/rmcp-macros/src/lib.rs @@ -1,6 +1,10 @@ #[allow(unused_imports)] use proc_macro::TokenStream; +mod common; +mod prompt; +mod prompt_handler; +mod prompt_router; mod tool; mod tool_handler; mod tool_router; @@ -160,3 +164,102 @@ pub fn tool_handler(attr: TokenStream, input: TokenStream) -> TokenStream { .unwrap_or_else(|err| err.to_compile_error()) .into() } + +/// # prompt +/// +/// This macro is used to mark a function as a prompt handler. +/// +/// This will generate a function that returns the attribute of this prompt, with type `rmcp::model::Prompt`. +/// +/// ## Usage +/// +/// | field | type | usage | +/// | :- | :- | :- | +/// | `name` | `String` | The name of the prompt. If not provided, it defaults to the function name. | +/// | `description` | `String` | A description of the prompt. The document of this function will be used if not provided. | +/// | `arguments` | `Expr` | Arguments that can be passed to the prompt. If not provided, it will use arguments from `Arguments` or `PromptArguments` parameter type. | +/// +/// ## Example +/// +/// ```rust,ignore +/// #[prompt(name = "code_review", description = "Reviews code for best practices")] +/// pub async fn code_review_prompt(&self, Arguments(args): Arguments) -> Result> { +/// // Generate prompt messages based on arguments +/// } +/// ``` +#[proc_macro_attribute] +pub fn prompt(attr: TokenStream, input: TokenStream) -> TokenStream { + prompt::prompt(attr.into(), input.into()) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} + +/// # prompt_router +/// +/// This macro generates a prompt router based on functions marked with `#[rmcp::prompt]` in an implementation block. +/// +/// It creates a function that returns a `PromptRouter` instance. +/// +/// ## Usage +/// +/// | field | type | usage | +/// | :- | :- | :- | +/// | `router` | `Ident` | The name of the router function to be generated. Defaults to `prompt_router`. | +/// | `vis` | `Visibility` | The visibility of the generated router function. Defaults to empty. | +/// +/// ## Example +/// +/// ```rust,ignore +/// #[prompt_router] +/// impl MyPromptHandler { +/// #[prompt] +/// pub async fn greeting_prompt(&self) -> Result, Error> { +/// // Generate greeting prompt +/// } +/// +/// pub fn new() -> Self { +/// Self { +/// // the default name of prompt router will be `prompt_router` +/// prompt_router: Self::prompt_router(), +/// } +/// } +/// } +/// ``` +#[proc_macro_attribute] +pub fn prompt_router(attr: TokenStream, input: TokenStream) -> TokenStream { + prompt_router::prompt_router(attr.into(), input.into()) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} + +/// # prompt_handler +/// +/// This macro generates handler methods for `get_prompt` and `list_prompts` in the implementation block, using an existing `PromptRouter` instance. +/// +/// ## Usage +/// +/// | field | type | usage | +/// | :- | :- | :- | +/// | `router` | `Expr` | The expression to access the `PromptRouter` instance. Defaults to `self.prompt_router`. | +/// +/// ## Example +/// ```rust,ignore +/// #[prompt_handler] +/// impl ServerHandler for MyPromptHandler { +/// // ...implement other handler methods +/// } +/// ``` +/// +/// or using a custom router expression: +/// ```rust,ignore +/// #[prompt_handler(router = self.get_prompt_router())] +/// impl ServerHandler for MyPromptHandler { +/// // ...implement other handler methods +/// } +/// ``` +#[proc_macro_attribute] +pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> TokenStream { + prompt_handler::prompt_handler(attr.into(), input.into()) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} diff --git a/crates/rmcp-macros/src/prompt.rs b/crates/rmcp-macros/src/prompt.rs new file mode 100644 index 00000000..c1a6b8c9 --- /dev/null +++ b/crates/rmcp-macros/src/prompt.rs @@ -0,0 +1,204 @@ +use darling::{FromMeta, ast::NestedMeta}; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{Expr, Ident, ImplItemFn, ReturnType}; + +use crate::common::{extract_doc_line, none_expr}; + +#[derive(FromMeta, Default, Debug)] +#[darling(default)] +pub struct PromptAttribute { + /// The name of the prompt + pub name: Option, + /// Optional description of what the prompt does + pub description: Option, + /// Arguments that can be passed to customize the prompt + pub arguments: Option, +} + +pub struct ResolvedPromptAttribute { + pub name: String, + pub description: Option, + pub arguments: Expr, +} + +impl ResolvedPromptAttribute { + pub fn into_fn(self, fn_ident: Ident) -> syn::Result { + let Self { + name, + description, + arguments, + } = self; + let description = if let Some(description) = description { + quote! { Some(#description.into()) } + } else { + quote! { None } + }; + let tokens = quote! { + pub fn #fn_ident() -> rmcp::model::Prompt { + rmcp::model::Prompt { + name: #name.into(), + description: #description, + arguments: #arguments, + } + } + }; + syn::parse2::(tokens) + } +} + +pub fn prompt(attr: TokenStream, input: TokenStream) -> syn::Result { + let attribute = if attr.is_empty() { + Default::default() + } else { + let attr_args = NestedMeta::parse_meta_list(attr)?; + PromptAttribute::from_list(&attr_args)? + }; + let mut fn_item = syn::parse2::(input.clone())?; + let fn_ident = &fn_item.sig.ident; + + let prompt_attr_fn_ident = format_ident!("{}_prompt_attr", fn_ident); + + // Try to find prompt arguments from function parameters + let arguments_expr = if let Some(arguments) = attribute.arguments { + arguments + } else { + // Look for a type named Arguments or PromptArguments in the function signature + let args_ty = fn_item.sig.inputs.iter().find_map(|input| { + if let syn::FnArg::Typed(pat_type) = input { + if let syn::Type::Path(type_path) = &*pat_type.ty { + if let Some(last_segment) = type_path.path.segments.last() { + if last_segment.ident == "Arguments" + || last_segment.ident == "PromptArguments" + { + // Extract the inner type from Arguments + if let syn::PathArguments::AngleBracketed(args) = + &last_segment.arguments + { + if let Some(syn::GenericArgument::Type(inner_ty)) = + args.args.first() + { + return Some(inner_ty.clone()); + } + } + } + } + } + } + None + }); + + if let Some(args_ty) = args_ty { + // Generate arguments from the type's schema with caching + syn::parse2::(quote! { + rmcp::handler::server::prompt::cached_arguments_from_schema::<#args_ty>() + })? + } else { + // No arguments + none_expr()? + } + }; + + let name = attribute.name.unwrap_or_else(|| fn_ident.to_string()); + let description = attribute + .description + .or_else(|| fn_item.attrs.iter().fold(None, extract_doc_line)); + let arguments = arguments_expr; + + let resolved_prompt_attr = ResolvedPromptAttribute { + name: name.clone(), + description: description.clone(), + arguments: arguments.clone(), + }; + let prompt_attr_fn = resolved_prompt_attr.into_fn(prompt_attr_fn_ident.clone())?; + + // Modify the input function for async support + if fn_item.sig.asyncness.is_some() { + // 1. remove asyncness from sig + // 2. make return type: `std::pin::Pin + Send + '_>>` + // 3. make body: { Box::pin(async move { #body }) } + let new_output = syn::parse2::({ + let mut lt = quote! { 'static }; + if let Some(receiver) = fn_item.sig.receiver() { + if let Some((_, receiver_lt)) = receiver.reference.as_ref() { + if let Some(receiver_lt) = receiver_lt { + lt = quote! { #receiver_lt }; + } else { + lt = quote! { '_ }; + } + } + } + match &fn_item.sig.output { + syn::ReturnType::Default => { + quote! { -> std::pin::Pin + Send + #lt>> } + } + syn::ReturnType::Type(_, ty) => { + quote! { -> std::pin::Pin + Send + #lt>> } + } + } + })?; + let prev_block = &fn_item.block; + let new_block = syn::parse2::(quote! { + { Box::pin(async move #prev_block ) } + })?; + fn_item.sig.asyncness = None; + fn_item.sig.output = new_output; + fn_item.block = new_block; + } + + Ok(quote! { + #prompt_attr_fn + #fn_item + }) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_prompt_macro() -> syn::Result<()> { + let attr = quote! { + name = "example-prompt", + description = "An example prompt" + }; + let input = quote! { + async fn example_prompt(&self, Arguments(args): Arguments) -> Result { + Ok("Example prompt response".to_string()) + } + }; + let result = prompt(attr, input)?; + + // Verify the output contains both the attribute function and the modified function + let result_str = result.to_string(); + assert!(result_str.contains("example_prompt_prompt_attr")); + assert!( + result_str.contains("rmcp") + && result_str.contains("model") + && result_str.contains("Prompt") + ); + + Ok(()) + } + + #[test] + fn test_doc_comment_description() -> syn::Result<()> { + let attr = quote! {}; // No explicit description + let input = quote! { + /// This is a test prompt description + /// with multiple lines + fn test_prompt(&self) -> Result { + Ok("Test".to_string()) + } + }; + let result = prompt(attr, input)?; + + // The output should contain the description from doc comments + let result_str = result.to_string(); + assert!(result_str.contains("This is a test prompt description")); + assert!(result_str.contains("with multiple lines")); + + Ok(()) + } +} + diff --git a/crates/rmcp-macros/src/prompt_handler.rs b/crates/rmcp-macros/src/prompt_handler.rs new file mode 100644 index 00000000..9d8fcd70 --- /dev/null +++ b/crates/rmcp-macros/src/prompt_handler.rs @@ -0,0 +1,142 @@ +use darling::FromMeta; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{Expr, ImplItem, ItemImpl, parse_quote}; + +#[derive(FromMeta, Debug, Default)] +#[darling(default)] +pub struct PromptHandlerAttribute { + pub router: Option, +} + +pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> syn::Result { + let attribute = if attr.is_empty() { + Default::default() + } else { + let attr_args = darling::ast::NestedMeta::parse_meta_list(attr)?; + PromptHandlerAttribute::from_list(&attr_args)? + }; + + let mut impl_block = syn::parse2::(input)?; + + let router_expr = attribute + .router + .unwrap_or_else(|| syn::parse2(quote! { self.prompt_router }).unwrap()); + + // Add get_prompt implementation + let get_prompt_impl: ImplItem = parse_quote! { + async fn get_prompt( + &self, + request: GetPromptRequestParam, + context: RequestContext, + ) -> Result { + let prompt_context = rmcp::handler::server::prompt::PromptContext::new( + self, + request.name, + request.arguments, + context, + ); + #router_expr.get_prompt(prompt_context).await + } + }; + + // Add list_prompts implementation + let list_prompts_impl: ImplItem = parse_quote! { + async fn list_prompts( + &self, + _request: Option, + _context: RequestContext, + ) -> Result { + let prompts = #router_expr.list_all(); + Ok(ListPromptsResult { + prompts, + next_cursor: None, + }) + } + }; + + // Check if methods already exist and replace them if they do + let mut has_get_prompt = false; + let mut has_list_prompts = false; + + for item in &mut impl_block.items { + if let ImplItem::Fn(fn_item) = item { + match fn_item.sig.ident.to_string().as_str() { + "get_prompt" => { + *item = get_prompt_impl.clone(); + has_get_prompt = true; + } + "list_prompts" => { + *item = list_prompts_impl.clone(); + has_list_prompts = true; + } + _ => {} + } + } + } + + // Add methods if they don't exist + if !has_get_prompt { + impl_block.items.push(get_prompt_impl); + } + if !has_list_prompts { + impl_block.items.push(list_prompts_impl); + } + + Ok(quote! { + #impl_block + }) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_prompt_handler_macro() -> syn::Result<()> { + let input = quote! { + impl ServerHandler for MyPromptHandler { + // Other handler methods... + } + }; + + let result = prompt_handler(TokenStream::new(), input)?; + let result_str = result.to_string(); + + // Check that the required methods were generated + assert!(result_str.contains("async fn get_prompt")); + assert!(result_str.contains("PromptContext") && result_str.contains("new")); + assert!(result_str.contains("async fn list_prompts")); + assert!(result_str.contains("ListPromptsResult")); + + Ok(()) + } + + #[test] + fn test_prompt_handler_with_custom_router() -> syn::Result<()> { + let attr = quote! { router = self.get_prompt_router() }; + let input = quote! { + impl ServerHandler for MyPromptHandler { + // Other handler methods... + } + }; + + let result = prompt_handler(attr, input)?; + let result_str = result.to_string(); + + // Check that the custom router expression is used + assert!( + result_str.contains("self") + && result_str.contains("get_prompt_router") + && result_str.contains("get_prompt") + ); + assert!( + result_str.contains("self") + && result_str.contains("get_prompt_router") + && result_str.contains("list_all") + ); + + Ok(()) + } +} + diff --git a/crates/rmcp-macros/src/prompt_router.rs b/crates/rmcp-macros/src/prompt_router.rs new file mode 100644 index 00000000..791a9912 --- /dev/null +++ b/crates/rmcp-macros/src/prompt_router.rs @@ -0,0 +1,97 @@ +use darling::FromMeta; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{ImplItem, ItemImpl, Visibility, parse_quote}; + +#[derive(FromMeta, Debug, Default)] +#[darling(default)] +pub struct PromptRouterAttribute { + pub router: Option, + pub vis: Option, +} + +pub fn prompt_router(attr: TokenStream, input: TokenStream) -> syn::Result { + let attribute = if attr.is_empty() { + Default::default() + } else { + let attr_args = darling::ast::NestedMeta::parse_meta_list(attr)?; + PromptRouterAttribute::from_list(&attr_args)? + }; + + let mut impl_block = syn::parse2::(input)?; + let self_ty = &impl_block.self_ty; + + let router_fn_ident = attribute + .router + .map(|s| format_ident!("{}", s)) + .unwrap_or_else(|| format_ident!("prompt_router")); + let vis = attribute.vis.unwrap_or(Visibility::Inherited); + + let mut prompt_route_fn_calls = Vec::new(); + + for item in &mut impl_block.items { + if let ImplItem::Fn(fn_item) = item { + let has_prompt_attr = fn_item.attrs.iter().any(|attr| { + attr.path() + .segments + .last() + .map(|seg| seg.ident == "prompt") + .unwrap_or(false) + }); + + if has_prompt_attr { + let fn_ident = &fn_item.sig.ident; + let attr_fn_ident = format_ident!("{}_prompt_attr", fn_ident); + prompt_route_fn_calls + .push(quote! { .with_route((Self::#attr_fn_ident(), Self::#fn_ident)) }); + } + } + } + + let router_fn: ImplItem = parse_quote! { + #vis fn #router_fn_ident() -> rmcp::handler::server::router::prompt::PromptRouter<#self_ty> { + rmcp::handler::server::router::prompt::PromptRouter::new() + #(#prompt_route_fn_calls)* + } + }; + + impl_block.items.push(router_fn); + + Ok(quote! { + #impl_block + }) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_prompt_router_macro() -> syn::Result<()> { + let input = quote! { + impl MyPromptHandler { + #[prompt] + async fn greeting_prompt(&self) -> Result, Error> { + Ok(vec![]) + } + + #[prompt] + async fn code_review_prompt(&self, args: Arguments) -> Result, Error> { + Ok(vec![]) + } + } + }; + + let result = prompt_router(TokenStream::new(), input)?; + let result_str = result.to_string(); + + // Check that the prompt_router function was generated + assert!(result_str.contains("fn prompt_router")); + assert!(result_str.contains("PromptRouter :: new")); + assert!(result_str.contains("greeting_prompt_prompt_attr")); + assert!(result_str.contains("code_review_prompt_prompt_attr")); + + Ok(()) + } +} + diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index 2ff22289..f083ce31 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -2,6 +2,8 @@ use darling::{FromMeta, ast::NestedMeta}; use proc_macro2::TokenStream; use quote::{ToTokens, format_ident, quote}; use syn::{Expr, Ident, ImplItemFn, ReturnType}; + +use crate::common::{extract_doc_line, none_expr}; #[derive(FromMeta, Default, Debug)] #[darling(default)] pub struct ToolAttribute { @@ -85,41 +87,6 @@ pub struct ToolAnnotationsAttribute { pub open_world_hint: Option, } -fn none_expr() -> Expr { - syn::parse2::(quote! { None }).unwrap() -} - -// extract doc line from attribute -fn extract_doc_line(existing_docs: Option, attr: &syn::Attribute) -> Option { - if !attr.path().is_ident("doc") { - return None; - } - - let syn::Meta::NameValue(name_value) = &attr.meta else { - return None; - }; - - let syn::Expr::Lit(expr_lit) = &name_value.value else { - return None; - }; - - let syn::Lit::Str(lit_str) = &expr_lit.lit else { - return None; - }; - - let content = lit_str.value().trim().to_string(); - match (existing_docs, content) { - (Some(mut existing_docs), content) if !content.is_empty() => { - existing_docs.push('\n'); - existing_docs.push_str(&content); - Some(existing_docs) - } - (Some(existing_docs), _) => Some(existing_docs), - (None, content) if !content.is_empty() => Some(content), - _ => None, - } -} - pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { let attribute = if attr.is_empty() { Default::default() @@ -190,7 +157,7 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { }; syn::parse2::(token_stream)? } else { - none_expr() + none_expr()? }; let resolved_tool_attr = ResolvedToolAttribute { name: attribute.name.unwrap_or_else(|| fn_ident.to_string()), diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 8e9b48e9..a5687d5a 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -62,7 +62,7 @@ http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } bytes = { version = "1", optional = true } # macro -rmcp-macros = { version = "0.2.1", workspace = true, optional = true } +rmcp-macros = { workspace = true, optional = true } [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] chrono = { version = "0.4.38", features = ["serde"] } diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 13bb69f5..b182461e 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -4,6 +4,7 @@ use crate::{ service::{NotificationContext, RequestContext, RoleServer, Service, ServiceRole}, }; +pub mod prompt; mod resource; pub mod router; pub mod tool; diff --git a/crates/rmcp/src/handler/server/prompt.rs b/crates/rmcp/src/handler/server/prompt.rs new file mode 100644 index 00000000..181c6537 --- /dev/null +++ b/crates/rmcp/src/handler/server/prompt.rs @@ -0,0 +1,346 @@ +//! Prompt handling infrastructure for MCP servers +//! +//! This module provides the core types and traits for implementing prompt handlers +//! in MCP servers. Prompts allow servers to provide reusable templates for LLM +//! interactions with customizable arguments. + +use std::{any::TypeId, collections::HashMap, future::Future, marker::PhantomData, pin::Pin}; + +use futures::future::BoxFuture; +use schemars::{JsonSchema, schema_for}; +use serde::de::DeserializeOwned; + +use crate::{ + RoleServer, + model::{GetPromptResult, PromptArgument, PromptMessage}, + service::RequestContext, +}; + +/// Context for prompt retrieval operations +pub struct PromptContext<'a, S> { + pub server: &'a S, + pub name: String, + pub arguments: Option>, + pub context: RequestContext, +} + +impl<'a, S> PromptContext<'a, S> { + pub fn new( + server: &'a S, + name: String, + arguments: Option>, + context: RequestContext, + ) -> Self { + Self { + server, + name, + arguments, + context, + } + } + + /// Invoke a prompt handler with parsed arguments + pub async fn invoke(self, handler: H) -> Result + where + H: GetPromptHandler, + S: 'a, + { + handler.handle(self).await + } +} + +/// Trait for handling prompt retrieval +pub trait GetPromptHandler { + fn handle<'a>( + self, + context: PromptContext<'a, S>, + ) -> BoxFuture<'a, Result> + where + S: 'a; +} + +/// Type alias for dynamic prompt handlers +pub type DynGetPromptHandler = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result> + + Send + + Sync; + +/// Adapter type for async methods that return Vec +pub struct AsyncMethodAdapter(PhantomData); + +/// Adapter type for async methods with arguments that return Vec +pub struct AsyncMethodWithArgsAdapter(PhantomData); + +/// Wrapper for parsing prompt arguments +pub struct Arguments(pub T); + +/// Type alias for prompt arguments - matches tool's Parameters pattern +pub type PromptArguments = Arguments; + +impl Arguments { + pub fn into_inner(self) -> T { + self.0 + } +} + +impl JsonSchema for Arguments { + fn schema_name() -> String { + T::schema_name() + } + + fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema { + T::json_schema(generator) + } +} + +/// Convert a JSON schema into prompt arguments +pub fn arguments_from_schema() -> Option> { + let schema = schema_for!(T); + let schema_value = serde_json::to_value(schema).ok()?; + + // Extract properties from the schema + let properties = schema_value.get("properties")?.as_object()?; + + let required = schema_value + .get("required") + .and_then(|r| r.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str()) + .collect::>() + }) + .unwrap_or_default(); + + let mut arguments = Vec::new(); + for (name, prop_schema) in properties { + let description = prop_schema + .get("description") + .and_then(|d| d.as_str()) + .map(String::from); + + arguments.push(PromptArgument { + name: name.clone(), + description, + required: Some(required.contains(name.as_str())), + }); + } + + if arguments.is_empty() { + None + } else { + Some(arguments) + } +} + +/// Call [`arguments_from_schema`] with a cache +pub fn cached_arguments_from_schema() -> Option> +{ + thread_local! { + static CACHE_FOR_TYPE: std::sync::RwLock>>> = Default::default(); + }; + CACHE_FOR_TYPE.with(|cache| { + // Try to read from cache first + if let Ok(cache_read) = cache.read() { + if let Some(x) = cache_read.get(&TypeId::of::()) { + return x.clone(); + } + } + + // Compute the value + let args = arguments_from_schema::(); + + // Try to update cache, but don't fail if we can't + if let Ok(mut cache_write) = cache.write() { + cache_write.insert(TypeId::of::(), args.clone()); + } + + args + }) +} + +// Implement GetPromptHandler for async functions returning GetPromptResult +impl GetPromptHandler for F +where + S: Sync, + F: FnOnce(&S, RequestContext) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, +{ + fn handle<'a>( + self, + context: PromptContext<'a, S>, + ) -> BoxFuture<'a, Result> + where + S: 'a, + { + Box::pin(async move { (self)(context.server, context.context).await }) + } +} + +// Implement GetPromptHandler for async functions with parsed arguments +impl GetPromptHandler> for F +where + S: Sync, + F: FnOnce(&S, Arguments, RequestContext) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + T: DeserializeOwned + 'static, +{ + fn handle<'a>( + self, + context: PromptContext<'a, S>, + ) -> BoxFuture<'a, Result> + where + S: 'a, + { + Box::pin(async move { + // Parse arguments if provided + let args = if let Some(args_map) = context.arguments { + let args_value = serde_json::Value::Object(args_map); + serde_json::from_value::(args_value).map_err(|e| { + crate::Error::invalid_params(format!("Failed to parse arguments: {}", e), None) + })? + } else { + // Try to deserialize from empty object for optional fields + serde_json::from_value::(serde_json::json!({})).map_err(|e| { + crate::Error::invalid_params(format!("Missing required arguments: {}", e), None) + })? + }; + + (self)(context.server, Arguments(args), context.context).await + }) + } +} + +// Implement GetPromptHandler for async methods that return Pin> +impl GetPromptHandler>> for F +where + S: Sync + 'static, + F: for<'a> FnOnce( + &'a S, + RequestContext, + ) -> Pin< + Box, crate::Error>> + Send + 'a>, + > + Send + + 'static, +{ + fn handle<'a>( + self, + context: PromptContext<'a, S>, + ) -> BoxFuture<'a, Result> + where + S: 'a, + { + Box::pin(async move { + let messages = (self)(context.server, context.context).await?; + Ok(GetPromptResult { + description: None, + messages, + }) + }) + } +} + +// Implement GetPromptHandler for async methods with arguments that return Pin> +impl GetPromptHandler, Vec)>> + for F +where + S: Sync + 'static, + T: DeserializeOwned + 'static, + F: for<'a> FnOnce( + &'a S, + Arguments, + RequestContext, + ) -> Pin< + Box, crate::Error>> + Send + 'a>, + > + Send + + 'static, +{ + fn handle<'a>( + self, + context: PromptContext<'a, S>, + ) -> BoxFuture<'a, Result> + where + S: 'a, + { + Box::pin(async move { + // Parse arguments if provided + let args = if let Some(args_map) = context.arguments { + let args_value = serde_json::Value::Object(args_map); + serde_json::from_value::(args_value).map_err(|e| { + crate::Error::invalid_params(format!("Failed to parse arguments: {}", e), None) + })? + } else { + // Try to deserialize from empty object for optional fields + serde_json::from_value::(serde_json::json!({})).map_err(|e| { + crate::Error::invalid_params(format!("Missing required arguments: {}", e), None) + })? + }; + + let messages = (self)(context.server, Arguments(args), context.context).await?; + Ok(GetPromptResult { + description: None, + messages, + }) + }) + } +} + +// Implement GetPromptHandler for async functions returning Vec +impl GetPromptHandler)> for F +where + S: Sync, + F: FnOnce(&S, RequestContext) -> Fut + Send + 'static, + Fut: Future, crate::Error>> + Send + 'static, +{ + fn handle<'a>( + self, + context: PromptContext<'a, S>, + ) -> BoxFuture<'a, Result> + where + S: 'a, + { + Box::pin(async move { + let messages = (self)(context.server, context.context).await?; + Ok(GetPromptResult { + description: None, + messages, + }) + }) + } +} + +// Implement GetPromptHandler for async functions with parsed arguments returning Vec +impl GetPromptHandler, Vec)> for F +where + S: Sync, + F: FnOnce(&S, Arguments, RequestContext) -> Fut + Send + 'static, + Fut: Future, crate::Error>> + Send + 'static, + T: DeserializeOwned + 'static, +{ + fn handle<'a>( + self, + context: PromptContext<'a, S>, + ) -> BoxFuture<'a, Result> + where + S: 'a, + { + Box::pin(async move { + // Parse arguments if provided + let args = if let Some(args_map) = context.arguments { + let args_value = serde_json::Value::Object(args_map); + serde_json::from_value::(args_value).map_err(|e| { + crate::Error::invalid_params(format!("Failed to parse arguments: {}", e), None) + })? + } else { + // Try to deserialize from empty object for optional fields + serde_json::from_value::(serde_json::json!({})).map_err(|e| { + crate::Error::invalid_params(format!("Missing required arguments: {}", e), None) + })? + }; + + let messages = (self)(context.server, Arguments(args), context.context).await?; + Ok(GetPromptResult { + description: None, + messages, + }) + }) + } +} diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs index 7401b39f..23b15dde 100644 --- a/crates/rmcp/src/handler/server/router.rs +++ b/crates/rmcp/src/handler/server/router.rs @@ -1,18 +1,21 @@ use std::sync::Arc; +use prompt::{IntoPromptRoute, PromptRoute}; use tool::{IntoToolRoute, ToolRoute}; use super::ServerHandler; use crate::{ RoleServer, Service, - model::{ClientRequest, ListToolsResult, ServerResult}, + model::{ClientRequest, ListPromptsResult, ListToolsResult, ServerResult}, service::NotificationContext, }; +pub mod prompt; pub mod tool; pub struct Router { pub tool_router: tool::ToolRouter, + pub prompt_router: prompt::PromptRouter, pub service: Arc, } @@ -23,6 +26,7 @@ where pub fn new(service: S) -> Self { Self { tool_router: tool::ToolRouter::new(), + prompt_router: prompt::PromptRouter::new(), service: Arc::new(service), } } @@ -41,6 +45,21 @@ where } self } + + pub fn with_prompt(mut self, route: R) -> Self + where + R: IntoPromptRoute, + { + self.prompt_router.add_route(route.into_prompt_route()); + self + } + + pub fn with_prompts(mut self, routes: impl IntoIterator>) -> Self { + for route in routes { + self.prompt_router.add_route(route); + } + self + } } impl Service for Router @@ -86,6 +105,29 @@ where next_cursor: None, })) } + ClientRequest::GetPromptRequest(request) => { + if self.prompt_router.has_route(request.params.name.as_ref()) { + let prompt_context = crate::handler::server::prompt::PromptContext::new( + self.service.as_ref(), + request.params.name, + request.params.arguments, + context, + ); + let result = self.prompt_router.get_prompt(prompt_context).await?; + Ok(ServerResult::GetPromptResult(result)) + } else { + self.service + .handle_request(ClientRequest::GetPromptRequest(request), context) + .await + } + } + ClientRequest::ListPromptsRequest(_) => { + let prompts = self.prompt_router.list_all(); + Ok(ServerResult::ListPromptsResult(ListPromptsResult { + prompts, + next_cursor: None, + })) + } rest => self.service.handle_request(rest, context).await, } } diff --git a/crates/rmcp/src/handler/server/router/prompt.rs b/crates/rmcp/src/handler/server/router/prompt.rs index e69de29b..c05f15ec 100644 --- a/crates/rmcp/src/handler/server/router/prompt.rs +++ b/crates/rmcp/src/handler/server/router/prompt.rs @@ -0,0 +1,212 @@ +use std::{borrow::Cow, sync::Arc}; + +use futures::{FutureExt, future::BoxFuture}; + +use crate::{ + handler::server::prompt::{DynGetPromptHandler, GetPromptHandler, PromptContext}, + model::{GetPromptResult, Prompt}, +}; + +pub struct PromptRoute { + #[allow(clippy::type_complexity)] + pub get: Arc>, + pub attr: crate::model::Prompt, +} + +impl std::fmt::Debug for PromptRoute { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PromptRoute") + .field("name", &self.attr.name) + .field("description", &self.attr.description) + .field("arguments", &self.attr.arguments) + .finish() + } +} + +impl Clone for PromptRoute { + fn clone(&self) -> Self { + Self { + get: self.get.clone(), + attr: self.attr.clone(), + } + } +} + +impl PromptRoute { + pub fn new(attr: impl Into, handler: H) -> Self + where + H: GetPromptHandler + Send + Sync + Clone + 'static, + { + Self { + get: Arc::new(move |context: PromptContext| { + let handler = handler.clone(); + context.invoke(handler).boxed() + }), + attr: attr.into(), + } + } + + pub fn new_dyn(attr: impl Into, handler: H) -> Self + where + H: for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result> + + Send + + Sync + + 'static, + { + Self { + get: Arc::new(handler), + attr: attr.into(), + } + } + + pub fn name(&self) -> &str { + &self.attr.name + } +} + +pub trait IntoPromptRoute { + fn into_prompt_route(self) -> PromptRoute; +} + +impl IntoPromptRoute for (P, H) +where + S: Send + Sync + 'static, + A: 'static, + H: GetPromptHandler + Send + Sync + Clone + 'static, + P: Into, +{ + fn into_prompt_route(self) -> PromptRoute { + PromptRoute::new(self.0.into(), self.1) + } +} + +impl IntoPromptRoute for PromptRoute +where + S: Send + Sync + 'static, +{ + fn into_prompt_route(self) -> PromptRoute { + self + } +} + +/// Adapter for functions generated by the #[prompt] macro +pub struct PromptAttrGenerateFunctionAdapter; + +impl IntoPromptRoute for F +where + S: Send + Sync + 'static, + F: Fn() -> PromptRoute, +{ + fn into_prompt_route(self) -> PromptRoute { + (self)() + } +} + +#[derive(Debug)] +pub struct PromptRouter { + #[allow(clippy::type_complexity)] + pub map: std::collections::HashMap, PromptRoute>, +} + +impl Default for PromptRouter { + fn default() -> Self { + Self { + map: std::collections::HashMap::new(), + } + } +} + +impl Clone for PromptRouter { + fn clone(&self) -> Self { + Self { + map: self.map.clone(), + } + } +} + +impl IntoIterator for PromptRouter { + type Item = PromptRoute; + type IntoIter = std::collections::hash_map::IntoValues, PromptRoute>; + + fn into_iter(self) -> Self::IntoIter { + self.map.into_values() + } +} + +impl PromptRouter +where + S: Send + Sync + 'static, +{ + pub fn new() -> Self { + Self { + map: std::collections::HashMap::new(), + } + } + + pub fn with_route(mut self, route: R) -> Self + where + R: IntoPromptRoute, + { + self.add_route(route.into_prompt_route()); + self + } + + pub fn add_route(&mut self, item: PromptRoute) { + self.map.insert(item.attr.name.clone().into(), item); + } + + pub fn merge(&mut self, other: PromptRouter) { + for item in other.map.into_values() { + self.add_route(item); + } + } + + pub fn remove_route(&mut self, name: &str) { + self.map.remove(name); + } + + pub fn has_route(&self, name: &str) -> bool { + self.map.contains_key(name) + } + + pub async fn get_prompt( + &self, + context: PromptContext<'_, S>, + ) -> Result { + let item = self.map.get(context.name.as_str()).ok_or_else(|| { + crate::Error::invalid_params( + format!("prompt '{}' not found", context.name), + Some(serde_json::json!({ + "available_prompts": self.list_all().iter().map(|p| &p.name).collect::>() + })), + ) + })?; + (item.get)(context).await + } + + pub fn list_all(&self) -> Vec { + self.map.values().map(|item| item.attr.clone()).collect() + } +} + +impl std::ops::Add> for PromptRouter +where + S: Send + Sync + 'static, +{ + type Output = Self; + + fn add(mut self, other: PromptRouter) -> Self::Output { + self.merge(other); + self + } +} + +impl std::ops::AddAssign> for PromptRouter +where + S: Send + Sync + 'static, +{ + fn add_assign(&mut self, other: PromptRouter) { + self.merge(other); + } +} + diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index e1e72c9d..c70c2d6f 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -162,12 +162,12 @@ pub trait DynService: Send + Sync { &self, request: R::PeerReq, context: RequestContext, - ) -> BoxFuture>; + ) -> BoxFuture<'_, Result>; fn handle_notification( &self, notification: R::PeerNot, context: NotificationContext, - ) -> BoxFuture>; + ) -> BoxFuture<'_, Result<(), McpError>>; fn get_info(&self) -> R::Info; } @@ -176,14 +176,14 @@ impl> DynService for S { &self, request: R::PeerReq, context: RequestContext, - ) -> BoxFuture> { + ) -> BoxFuture<'_, Result> { Box::pin(self.handle_request(request, context)) } fn handle_notification( &self, notification: R::PeerNot, context: NotificationContext, - ) -> BoxFuture> { + ) -> BoxFuture<'_, Result<(), McpError>> { Box::pin(self.handle_notification(notification, context)) } fn get_info(&self) -> R::Info { diff --git a/crates/rmcp/tests/test_prompts.rs b/crates/rmcp/tests/test_prompts.rs new file mode 100644 index 00000000..ff6a1374 --- /dev/null +++ b/crates/rmcp/tests/test_prompts.rs @@ -0,0 +1,210 @@ +use rmcp::{ + handler::server::{ServerHandler, prompt::Arguments, router::Router}, + model::{GetPromptResult, PromptMessage, PromptMessageRole}, + service::{RequestContext, RoleServer}, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +/// Test prompt arguments for code review +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +struct CodeReviewArgs { + /// The file path to review + file_path: String, + /// Focus areas for the review + #[serde(skip_serializing_if = "Option::is_none")] + focus_areas: Option>, +} + +/// Test prompt arguments for debugging +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +struct DebugAssistantArgs { + /// The error message to debug + error_message: String, + /// The programming language + language: String, + /// Additional context + #[serde(skip_serializing_if = "Option::is_none")] + context: Option, +} + +struct TestPromptServer; + +impl ServerHandler for TestPromptServer {} + +/// A simple code review prompt +#[rmcp::prompt( + name = "code_review", + description = "Reviews code for best practices and potential issues" +)] +async fn code_review_prompt( + _server: &TestPromptServer, + Arguments(args): Arguments, + _ctx: RequestContext, +) -> Result, rmcp::Error> { + let mut messages = vec![PromptMessage::new_text( + PromptMessageRole::User, + format!("Please review the code in file: {}", args.file_path), + )]; + + if let Some(focus_areas) = args.focus_areas { + messages.push(PromptMessage::new_text( + PromptMessageRole::User, + format!("Focus on these areas: {}", focus_areas.join(", ")), + )); + } + + messages.push(PromptMessage::new_text( + PromptMessageRole::Assistant, + "I'll help you review this code. Let me analyze it for best practices, potential bugs, and improvement opportunities.", + )); + + Ok(messages) +} + +/// A debugging assistant prompt +#[rmcp::prompt(name = "debug_assistant")] +async fn debug_assistant_prompt( + _server: &TestPromptServer, + Arguments(args): Arguments, + _ctx: RequestContext, +) -> Result { + let mut messages = vec![PromptMessage::new_text( + PromptMessageRole::User, + format!( + "I'm getting this error in my {} code: {}", + args.language, args.error_message + ), + )]; + + if let Some(context) = args.context { + messages.push(PromptMessage::new_text( + PromptMessageRole::User, + format!("Additional context: {}", context), + )); + } + + messages.push(PromptMessage::new_text( + PromptMessageRole::Assistant, + format!( + "I'll help you debug this {} error. Let me analyze the error message and provide solutions.", + args.language + ), + )); + + Ok(GetPromptResult { + description: Some("Helps debug programming errors with detailed analysis".to_string()), + messages, + }) +} + +/// A simple greeting prompt without arguments +#[rmcp::prompt] +async fn greeting_prompt( + _server: &TestPromptServer, + _ctx: RequestContext, +) -> Result, rmcp::Error> { + Ok(vec![ + PromptMessage::new_text( + PromptMessageRole::User, + "Hello! I'd like to start a conversation.", + ), + PromptMessage::new_text( + PromptMessageRole::Assistant, + "Hello! I'm here to help. What would you like to discuss today?", + ), + ]) +} + +#[tokio::test] +async fn test_prompt_macro_basic() { + // Test that the prompt attribute functions are generated + let greeting = greeting_prompt_prompt_attr(); + assert_eq!(greeting.name, "greeting_prompt"); + assert_eq!( + greeting.description.as_deref(), + Some("A simple greeting prompt without arguments") + ); + assert!(greeting.arguments.is_none()); + + let code_review = code_review_prompt_prompt_attr(); + assert_eq!(code_review.name, "code_review"); + assert_eq!( + code_review.description.as_deref(), + Some("Reviews code for best practices and potential issues") + ); + assert!(code_review.arguments.is_some()); + + let debug_assistant = debug_assistant_prompt_prompt_attr(); + assert_eq!(debug_assistant.name, "debug_assistant"); + assert!(debug_assistant.arguments.is_some()); +} + +#[tokio::test] +async fn test_prompt_router() { + // Create prompt routes manually + let greeting_route = rmcp::handler::server::router::prompt::PromptRoute::new( + greeting_prompt_prompt_attr(), + greeting_prompt, + ); + let code_review_route = rmcp::handler::server::router::prompt::PromptRoute::new( + code_review_prompt_prompt_attr(), + code_review_prompt, + ); + let debug_assistant_route = rmcp::handler::server::router::prompt::PromptRoute::new( + debug_assistant_prompt_prompt_attr(), + debug_assistant_prompt, + ); + + let server = Router::new(TestPromptServer) + .with_prompt(greeting_route) + .with_prompt(code_review_route) + .with_prompt(debug_assistant_route); + + // Test list prompts + let prompts = server.prompt_router.list_all(); + assert_eq!(prompts.len(), 3); + + let prompt_names: Vec<_> = prompts.iter().map(|p| p.name.as_str()).collect(); + assert!(prompt_names.contains(&"greeting_prompt")); + assert!(prompt_names.contains(&"code_review")); + assert!(prompt_names.contains(&"debug_assistant")); +} + +#[tokio::test] +async fn test_prompt_arguments_schema() { + let code_review = code_review_prompt_prompt_attr(); + let args = code_review.arguments.unwrap(); + + // Should have two arguments: file_path (required) and focus_areas (optional) + assert_eq!(args.len(), 2); + + let file_path_arg = args.iter().find(|a| a.name == "file_path").unwrap(); + assert_eq!(file_path_arg.required, Some(true)); + assert_eq!( + file_path_arg.description.as_deref(), + Some("The file path to review") + ); + + let focus_areas_arg = args.iter().find(|a| a.name == "focus_areas").unwrap(); + assert_eq!(focus_areas_arg.required, Some(false)); + assert_eq!( + focus_areas_arg.description.as_deref(), + Some("Focus areas for the review") + ); +} + +#[tokio::test] +async fn test_prompt_route_creation() { + // Test that prompt routes can be created + let route = rmcp::handler::server::router::prompt::PromptRoute::new( + code_review_prompt_prompt_attr(), + code_review_prompt, + ); + + assert_eq!(route.name(), "code_review"); +} + +// Additional integration tests would require a full server setup +// These tests demonstrate the basic functionality of the prompt system + diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index b7d43c5d..a9b62bb7 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -75,6 +75,10 @@ path = "src/complex_auth_sse.rs" name = "servers_simple_auth_sse" path = "src/simple_auth_sse.rs" +[[example]] +name = "servers_prompt_stdio" +path = "src/prompt_stdio.rs" + [[example]] name = "counter_hyper_streamable_http" path = "src/counter_hyper_streamable_http.rs" diff --git a/examples/servers/README.md b/examples/servers/README.md index c80b130a..9b8f31a9 100644 --- a/examples/servers/README.md +++ b/examples/servers/README.md @@ -72,6 +72,16 @@ A simplified OAuth example showing basic token-based authentication. - Simplified authentication flow - Good starting point for adding authentication to MCP servers +### Prompt Standard I/O Server (`prompt_stdio.rs`) + +A server demonstrating the prompt framework capabilities. + +- Shows how to implement prompts in MCP servers +- Provides code review and debugging prompts +- Demonstrates prompt argument handling with JSON schema +- Uses standard I/O transport +- Good example of prompt implementation patterns + ## How to Run Each example can be run using Cargo: @@ -97,6 +107,9 @@ cargo run --example servers_complex_auth_sse # Run the simple OAuth SSE server cargo run --example servers_simple_auth_sse + +# Run the prompt standard I/O server +cargo run --example servers_prompt_stdio ``` ## Testing with MCP Inspector diff --git a/examples/servers/src/prompt_stdio.rs b/examples/servers/src/prompt_stdio.rs new file mode 100644 index 00000000..eab5fcae --- /dev/null +++ b/examples/servers/src/prompt_stdio.rs @@ -0,0 +1,145 @@ +use anyhow::Result; +use rmcp::{ + Error as McpError, RoleServer, ServerHandler, ServiceExt, + handler::server::prompt::arguments_from_schema, model::*, schemars, service::RequestContext, + transport::stdio, +}; +use serde::{Deserialize, Serialize}; +use tracing_subscriber::EnvFilter; + +#[derive(Debug, Serialize, Deserialize, schemars::JsonSchema)] +struct CodeReviewArgs { + /// The file path to review + file_path: String, + /// Language for syntax highlighting + #[serde(default = "default_language")] + language: String, +} + +fn default_language() -> String { + "rust".to_string() +} + +#[derive(Clone, Debug, Default)] +struct PromptExampleServer; + +impl ServerHandler for PromptExampleServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + server_info: Implementation { + name: "Prompt Example Server".to_string(), + version: "1.0.0".to_string(), + }, + instructions: Some( + concat!( + "This server demonstrates the prompt framework capabilities. ", + "It provides code review and debugging prompts." + ) + .to_string(), + ), + capabilities: ServerCapabilities::builder().enable_prompts().build(), + ..Default::default() + } + } + + async fn list_prompts( + &self, + _request: Option, + _: RequestContext, + ) -> Result { + Ok(ListPromptsResult { + next_cursor: None, + prompts: vec![ + Prompt { + name: "code_review".to_string(), + description: Some( + "Reviews code for best practices and potential issues".to_string(), + ), + arguments: arguments_from_schema::(), + }, + Prompt { + name: "debug_helper".to_string(), + description: Some("Interactive debugging assistant".to_string()), + arguments: None, + }, + ], + }) + } + + async fn get_prompt( + &self, + GetPromptRequestParam { name, arguments }: GetPromptRequestParam, + _: RequestContext, + ) -> Result { + match name.as_str() { + "code_review" => { + // Parse arguments + let args = if let Some(args_map) = arguments { + serde_json::from_value::(serde_json::Value::Object(args_map)) + .map_err(|e| { + McpError::invalid_params(format!("Invalid arguments: {}", e), None) + })? + } else { + return Err(McpError::invalid_params("Missing required arguments", None)); + }; + + Ok(GetPromptResult { + description: None, + messages: vec![ + PromptMessage::new_text( + PromptMessageRole::User, + format!( + "Please review the {} code in file: {}", + args.language, args.file_path + ), + ), + PromptMessage::new_text( + PromptMessageRole::Assistant, + "I'll analyze this code for best practices, potential bugs, and improvements.", + ), + ], + }) + } + "debug_helper" => Ok(GetPromptResult { + description: Some("Interactive debugging assistant".to_string()), + messages: vec![ + PromptMessage::new_text( + PromptMessageRole::Assistant, + "You are a helpful debugging assistant. Ask the user about their error and help them solve it.", + ), + PromptMessage::new_text( + PromptMessageRole::User, + "I need help debugging an issue in my code.", + ), + PromptMessage::new_text( + PromptMessageRole::Assistant, + "I'd be happy to help you debug your code! Please tell me:\n1. What error or issue are you experiencing?\n2. What programming language are you using?\n3. What were you trying to accomplish?", + ), + ], + }), + _ => Err(McpError::invalid_params( + format!("Unknown prompt: {}", name), + Some(serde_json::json!({ + "available_prompts": ["code_review", "debug_helper"] + })), + )), + } + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize the tracing subscriber + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) + .with_writer(std::io::stderr) + .init(); + + tracing::info!("Starting Prompt Example MCP server"); + + // Create and serve the prompt server + let service = PromptExampleServer.serve(stdio()).await?; + + service.waiting().await?; + Ok(()) +} From aab9e781f527bc0bb35018b6bb460567d1dc297c Mon Sep 17 00:00:00 2001 From: snowmead Date: Fri, 11 Jul 2025 21:37:10 -0400 Subject: [PATCH 2/3] fmt --- crates/rmcp-macros/src/common.rs | 1 - crates/rmcp-macros/src/prompt.rs | 1 - crates/rmcp-macros/src/prompt_handler.rs | 1 - crates/rmcp-macros/src/prompt_router.rs | 1 - crates/rmcp/src/handler/server/router/prompt.rs | 1 - crates/rmcp/tests/test_prompts.rs | 1 - 6 files changed, 6 deletions(-) diff --git a/crates/rmcp-macros/src/common.rs b/crates/rmcp-macros/src/common.rs index d009471f..96fa946b 100644 --- a/crates/rmcp-macros/src/common.rs +++ b/crates/rmcp-macros/src/common.rs @@ -38,4 +38,3 @@ pub fn extract_doc_line(existing_docs: Option, attr: &Attribute) -> Opti _ => None, } } - diff --git a/crates/rmcp-macros/src/prompt.rs b/crates/rmcp-macros/src/prompt.rs index c1a6b8c9..79ac03f8 100644 --- a/crates/rmcp-macros/src/prompt.rs +++ b/crates/rmcp-macros/src/prompt.rs @@ -201,4 +201,3 @@ mod test { Ok(()) } } - diff --git a/crates/rmcp-macros/src/prompt_handler.rs b/crates/rmcp-macros/src/prompt_handler.rs index 9d8fcd70..87dd85ac 100644 --- a/crates/rmcp-macros/src/prompt_handler.rs +++ b/crates/rmcp-macros/src/prompt_handler.rs @@ -139,4 +139,3 @@ mod test { Ok(()) } } - diff --git a/crates/rmcp-macros/src/prompt_router.rs b/crates/rmcp-macros/src/prompt_router.rs index 791a9912..206bcc1e 100644 --- a/crates/rmcp-macros/src/prompt_router.rs +++ b/crates/rmcp-macros/src/prompt_router.rs @@ -94,4 +94,3 @@ mod test { Ok(()) } } - diff --git a/crates/rmcp/src/handler/server/router/prompt.rs b/crates/rmcp/src/handler/server/router/prompt.rs index c05f15ec..e7c552dd 100644 --- a/crates/rmcp/src/handler/server/router/prompt.rs +++ b/crates/rmcp/src/handler/server/router/prompt.rs @@ -209,4 +209,3 @@ where self.merge(other); } } - diff --git a/crates/rmcp/tests/test_prompts.rs b/crates/rmcp/tests/test_prompts.rs index ff6a1374..b8fa9985 100644 --- a/crates/rmcp/tests/test_prompts.rs +++ b/crates/rmcp/tests/test_prompts.rs @@ -207,4 +207,3 @@ async fn test_prompt_route_creation() { // Additional integration tests would require a full server setup // These tests demonstrate the basic functionality of the prompt system - From 5795debf5713f1cce324045efb0da8f14dd7e195 Mon Sep 17 00:00:00 2001 From: snowmead Date: Mon, 14 Jul 2025 15:27:27 -0400 Subject: [PATCH 3/3] refactor: Unify parameter handling between tools and prompts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace Arguments with Parameters for consistent API - Create shared common module for tool/prompt utilities - Modernize async handling with futures::future::BoxFuture - Move cached_schema_for_type to common module for reuse - Update error types from rmcp::Error to rmcp::ErrorData - Add comprehensive trait implementations for parameter extraction 🤖 Generated with Claude Code Co-Authored-By: Claude --- crates/rmcp-macros/src/common.rs | 27 +- crates/rmcp-macros/src/lib.rs | 8 +- crates/rmcp-macros/src/prompt.rs | 44 +- crates/rmcp-macros/src/prompt_handler.rs | 4 +- crates/rmcp-macros/src/prompt_router.rs | 21 +- crates/rmcp-macros/src/tool.rs | 24 +- crates/rmcp/src/handler/server.rs | 1 + crates/rmcp/src/handler/server/common.rs | 158 +++++ crates/rmcp/src/handler/server/prompt.rs | 583 ++++++++++-------- .../rmcp/src/handler/server/router/prompt.rs | 12 +- crates/rmcp/src/handler/server/tool.rs | 177 ++---- crates/rmcp/tests/test_prompt_handler.rs | 143 +++++ .../tests/test_prompt_macro_annotations.rs | 291 +++++++++ crates/rmcp/tests/test_prompt_macros.rs | 383 ++++++++++++ crates/rmcp/tests/test_prompt_routers.rs | 105 ++++ crates/rmcp/tests/test_prompts.rs | 209 ------- examples/servers/Cargo.toml | 7 +- examples/servers/src/common/counter.rs | 209 +++++-- examples/servers/src/prompt_stdio.rs | 145 ----- 19 files changed, 1689 insertions(+), 862 deletions(-) create mode 100644 crates/rmcp/src/handler/server/common.rs create mode 100644 crates/rmcp/tests/test_prompt_handler.rs create mode 100644 crates/rmcp/tests/test_prompt_macro_annotations.rs create mode 100644 crates/rmcp/tests/test_prompt_macros.rs create mode 100644 crates/rmcp/tests/test_prompt_routers.rs delete mode 100644 crates/rmcp/tests/test_prompts.rs delete mode 100644 examples/servers/src/prompt_stdio.rs diff --git a/crates/rmcp-macros/src/common.rs b/crates/rmcp-macros/src/common.rs index 96fa946b..39d44f1e 100644 --- a/crates/rmcp-macros/src/common.rs +++ b/crates/rmcp-macros/src/common.rs @@ -1,7 +1,7 @@ //! Common utilities shared between different macro implementations use quote::quote; -use syn::{Attribute, Expr}; +use syn::{Attribute, Expr, FnArg, ImplItemFn, Signature, Type}; /// Parse a None expression pub fn none_expr() -> syn::Result { @@ -38,3 +38,28 @@ pub fn extract_doc_line(existing_docs: Option, attr: &Attribute) -> Opti _ => None, } } + +/// Find Parameters type in function signature +/// Returns the full Parameters type if found +pub fn find_parameters_type_in_sig(sig: &Signature) -> Option> { + sig.inputs.iter().find_map(|input| { + if let FnArg::Typed(pat_type) = input { + if let Type::Path(type_path) = &*pat_type.ty { + if type_path + .path + .segments + .last() + .is_some_and(|type_name| type_name.ident == "Parameters") + { + return Some(pat_type.ty.clone()); + } + } + } + None + }) +} + +/// Find Parameters type in ImplItemFn +pub fn find_parameters_type_impl(fn_item: &ImplItemFn) -> Option> { + find_parameters_type_in_sig(&fn_item.sig) +} diff --git a/crates/rmcp-macros/src/lib.rs b/crates/rmcp-macros/src/lib.rs index efc002a1..d827fe74 100644 --- a/crates/rmcp-macros/src/lib.rs +++ b/crates/rmcp-macros/src/lib.rs @@ -177,13 +177,13 @@ pub fn tool_handler(attr: TokenStream, input: TokenStream) -> TokenStream { /// | :- | :- | :- | /// | `name` | `String` | The name of the prompt. If not provided, it defaults to the function name. | /// | `description` | `String` | A description of the prompt. The document of this function will be used if not provided. | -/// | `arguments` | `Expr` | Arguments that can be passed to the prompt. If not provided, it will use arguments from `Arguments` or `PromptArguments` parameter type. | +/// | `arguments` | `Expr` | An expression that evaluates to `Option>` defining the prompt's arguments. If not provided, it will automatically generate arguments from the `Parameters` type found in the function signature. | /// /// ## Example /// /// ```rust,ignore /// #[prompt(name = "code_review", description = "Reviews code for best practices")] -/// pub async fn code_review_prompt(&self, Arguments(args): Arguments) -> Result> { +/// pub async fn code_review_prompt(&self, Parameters(args): Parameters) -> Result> { /// // Generate prompt messages based on arguments /// } /// ``` @@ -213,8 +213,8 @@ pub fn prompt(attr: TokenStream, input: TokenStream) -> TokenStream { /// #[prompt_router] /// impl MyPromptHandler { /// #[prompt] -/// pub async fn greeting_prompt(&self) -> Result, Error> { -/// // Generate greeting prompt +/// pub async fn greeting_prompt(&self, Parameters(args): Parameters) -> Result, Error> { +/// // Generate greeting prompt using args /// } /// /// pub fn new() -> Self { diff --git a/crates/rmcp-macros/src/prompt.rs b/crates/rmcp-macros/src/prompt.rs index 79ac03f8..68f853ab 100644 --- a/crates/rmcp-macros/src/prompt.rs +++ b/crates/rmcp-macros/src/prompt.rs @@ -12,7 +12,7 @@ pub struct PromptAttribute { pub name: Option, /// Optional description of what the prompt does pub description: Option, - /// Arguments that can be passed to customize the prompt + /// Arguments that can be passed to the prompt pub arguments: Option, } @@ -59,39 +59,17 @@ pub fn prompt(attr: TokenStream, input: TokenStream) -> syn::Result let prompt_attr_fn_ident = format_ident!("{}_prompt_attr", fn_ident); - // Try to find prompt arguments from function parameters + // Try to find prompt parameters from function parameters let arguments_expr = if let Some(arguments) = attribute.arguments { arguments } else { - // Look for a type named Arguments or PromptArguments in the function signature - let args_ty = fn_item.sig.inputs.iter().find_map(|input| { - if let syn::FnArg::Typed(pat_type) = input { - if let syn::Type::Path(type_path) = &*pat_type.ty { - if let Some(last_segment) = type_path.path.segments.last() { - if last_segment.ident == "Arguments" - || last_segment.ident == "PromptArguments" - { - // Extract the inner type from Arguments - if let syn::PathArguments::AngleBracketed(args) = - &last_segment.arguments - { - if let Some(syn::GenericArgument::Type(inner_ty)) = - args.args.first() - { - return Some(inner_ty.clone()); - } - } - } - } - } - } - None - }); + // Look for a type named Parameters in the function signature + let params_ty = crate::common::find_parameters_type_impl(&fn_item); - if let Some(args_ty) = args_ty { + if let Some(params_ty) = params_ty { // Generate arguments from the type's schema with caching syn::parse2::(quote! { - rmcp::handler::server::prompt::cached_arguments_from_schema::<#args_ty>() + rmcp::handler::server::prompt::cached_arguments_from_schema::<#params_ty>() })? } else { // No arguments @@ -112,10 +90,10 @@ pub fn prompt(attr: TokenStream, input: TokenStream) -> syn::Result }; let prompt_attr_fn = resolved_prompt_attr.into_fn(prompt_attr_fn_ident.clone())?; - // Modify the input function for async support + // Modify the input function for async support (same as tool macro) if fn_item.sig.asyncness.is_some() { // 1. remove asyncness from sig - // 2. make return type: `std::pin::Pin + Send + '_>>` + // 2. make return type: `futures::future::BoxFuture<'_, #ReturnType>` // 3. make body: { Box::pin(async move { #body }) } let new_output = syn::parse2::({ let mut lt = quote! { 'static }; @@ -130,10 +108,10 @@ pub fn prompt(attr: TokenStream, input: TokenStream) -> syn::Result } match &fn_item.sig.output { syn::ReturnType::Default => { - quote! { -> std::pin::Pin + Send + #lt>> } + quote! { -> futures::future::BoxFuture<#lt, ()> } } syn::ReturnType::Type(_, ty) => { - quote! { -> std::pin::Pin + Send + #lt>> } + quote! { -> futures::future::BoxFuture<#lt, #ty> } } } })?; @@ -163,7 +141,7 @@ mod test { description = "An example prompt" }; let input = quote! { - async fn example_prompt(&self, Arguments(args): Arguments) -> Result { + async fn example_prompt(&self, Parameters(args): Parameters) -> Result { Ok("Example prompt response".to_string()) } }; diff --git a/crates/rmcp-macros/src/prompt_handler.rs b/crates/rmcp-macros/src/prompt_handler.rs index 87dd85ac..f6f46d6f 100644 --- a/crates/rmcp-macros/src/prompt_handler.rs +++ b/crates/rmcp-macros/src/prompt_handler.rs @@ -29,7 +29,7 @@ pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> syn::Result, - ) -> Result { + ) -> Result { let prompt_context = rmcp::handler::server::prompt::PromptContext::new( self, request.name, @@ -46,7 +46,7 @@ pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> syn::Result, _context: RequestContext, - ) -> Result { + ) -> Result { let prompts = #router_expr.list_all(); Ok(ListPromptsResult { prompts, diff --git a/crates/rmcp-macros/src/prompt_router.rs b/crates/rmcp-macros/src/prompt_router.rs index 206bcc1e..6abd96c9 100644 --- a/crates/rmcp-macros/src/prompt_router.rs +++ b/crates/rmcp-macros/src/prompt_router.rs @@ -42,8 +42,23 @@ pub fn prompt_router(attr: TokenStream, input: TokenStream) -> syn::Result) -> Result, Error> { + async fn code_review_prompt(&self, Parameters(args): Parameters) -> Result, Error> { Ok(vec![]) } } diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index f083ce31..102a71c5 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -102,30 +102,16 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { input_schema } else { // try to find some parameters wrapper in the function - let params_ty = fn_item.sig.inputs.iter().find_map(|input| { - if let syn::FnArg::Typed(pat_type) = input { - if let syn::Type::Path(type_path) = &*pat_type.ty { - if type_path - .path - .segments - .last() - .is_some_and(|type_name| type_name.ident == "Parameters") - { - return Some(pat_type.ty.clone()); - } - } - } - None - }); + let params_ty = crate::common::find_parameters_type_impl(&fn_item); if let Some(params_ty) = params_ty { // if found, use the Parameters schema syn::parse2::(quote! { - rmcp::handler::server::tool::cached_schema_for_type::<#params_ty>() + rmcp::handler::server::common::cached_schema_for_type::<#params_ty>() })? } else { // if not found, use the default EmptyObject schema syn::parse2::(quote! { - rmcp::handler::server::tool::cached_schema_for_type::() + rmcp::handler::server::common::cached_schema_for_type::() })? } }; @@ -186,10 +172,10 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { } match &fn_item.sig.output { syn::ReturnType::Default => { - quote! { -> std::pin::Pin + Send + #lt>> } + quote! { -> futures::future::BoxFuture<#lt, ()> } } syn::ReturnType::Type(_, ty) => { - quote! { -> std::pin::Pin + Send + #lt>> } + quote! { -> futures::future::BoxFuture<#lt, #ty> } } } })?; diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index b182461e..06929a8f 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -4,6 +4,7 @@ use crate::{ service::{NotificationContext, RequestContext, RoleServer, Service, ServiceRole}, }; +pub mod common; pub mod prompt; mod resource; pub mod router; diff --git a/crates/rmcp/src/handler/server/common.rs b/crates/rmcp/src/handler/server/common.rs new file mode 100644 index 00000000..b5b0189f --- /dev/null +++ b/crates/rmcp/src/handler/server/common.rs @@ -0,0 +1,158 @@ +//! Common utilities shared between tool and prompt handlers + +use std::{any::TypeId, collections::HashMap, sync::Arc}; + +use schemars::JsonSchema; + +use crate::{ + RoleServer, model::JsonObject, schemars::generate::SchemaSettings, service::RequestContext, +}; + +/// A shortcut for generating a JSON schema for a type. +pub fn schema_for_type() -> JsonObject { + // explicitly to align json schema version to official specifications. + // https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json + // TODO: update to 2020-12 waiting for the mcp spec update + let mut settings = SchemaSettings::draft07(); + settings.transforms = vec![Box::new(schemars::transform::AddNullable::default())]; + let generator = settings.into_generator(); + let schema = generator.into_root_schema_for::(); + let object = serde_json::to_value(schema).expect("failed to serialize schema"); + match object { + serde_json::Value::Object(object) => object, + _ => panic!("unexpected schema value"), + } +} + +/// Call [`schema_for_type`] with a cache +pub fn cached_schema_for_type() -> Arc { + thread_local! { + static CACHE_FOR_TYPE: std::sync::RwLock>> = Default::default(); + }; + CACHE_FOR_TYPE.with(|cache| { + if let Some(x) = cache + .read() + .expect("schema cache lock poisoned") + .get(&TypeId::of::()) + { + x.clone() + } else { + let schema = schema_for_type::(); + let schema = Arc::new(schema); + cache + .write() + .expect("schema cache lock poisoned") + .insert(TypeId::of::(), schema.clone()); + schema + } + }) +} + +/// Trait for extracting parts from a context, unifying tool and prompt extraction +pub trait FromContextPart: Sized { + fn from_context_part(context: &mut C) -> Result; +} + +/// Parameter Extractor +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(transparent)] +pub struct Parameters

(pub P); + +impl JsonSchema for Parameters

{ + fn schema_name() -> std::borrow::Cow<'static, str> { + P::schema_name() + } + + fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema { + P::json_schema(generator) + } +} + +/// Common extractors that can be used by both tool and prompt handlers +impl FromContextPart for RequestContext +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + Ok(context.as_request_context().clone()) + } +} + +impl FromContextPart for tokio_util::sync::CancellationToken +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + Ok(context.as_request_context().ct.clone()) + } +} + +impl FromContextPart for crate::model::Extensions +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + Ok(context.as_request_context().extensions.clone()) + } +} + +pub struct Extension(pub T); + +impl FromContextPart for Extension +where + C: AsRequestContext, + T: Send + Sync + 'static + Clone, +{ + fn from_context_part(context: &mut C) -> Result { + let extension = context + .as_request_context() + .extensions + .get::() + .cloned() + .ok_or_else(|| { + crate::ErrorData::invalid_params( + format!("missing extension {}", std::any::type_name::()), + None, + ) + })?; + Ok(Extension(extension)) + } +} + +impl FromContextPart for crate::Peer +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + Ok(context.as_request_context().peer.clone()) + } +} + +impl FromContextPart for crate::model::Meta +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + let request_context = context.as_request_context_mut(); + let mut meta = crate::model::Meta::default(); + std::mem::swap(&mut meta, &mut request_context.meta); + Ok(meta) + } +} + +pub struct RequestId(pub crate::model::RequestId); + +impl FromContextPart for RequestId +where + C: AsRequestContext, +{ + fn from_context_part(context: &mut C) -> Result { + Ok(RequestId(context.as_request_context().id.clone())) + } +} + +/// Trait for types that can provide access to RequestContext +pub trait AsRequestContext { + fn as_request_context(&self) -> &RequestContext; + fn as_request_context_mut(&mut self) -> &mut RequestContext; +} diff --git a/crates/rmcp/src/handler/server/prompt.rs b/crates/rmcp/src/handler/server/prompt.rs index 181c6537..692e4dd8 100644 --- a/crates/rmcp/src/handler/server/prompt.rs +++ b/crates/rmcp/src/handler/server/prompt.rs @@ -4,15 +4,16 @@ //! in MCP servers. Prompts allow servers to provide reusable templates for LLM //! interactions with customizable arguments. -use std::{any::TypeId, collections::HashMap, future::Future, marker::PhantomData, pin::Pin}; +use std::{future::Future, marker::PhantomData}; -use futures::future::BoxFuture; -use schemars::{JsonSchema, schema_for}; +use futures::future::{BoxFuture, FutureExt}; use serde::de::DeserializeOwned; +use super::common::AsRequestContext; +pub use super::common::{Extension, Parameters, RequestId}; use crate::{ RoleServer, - model::{GetPromptResult, PromptArgument, PromptMessage}, + model::{GetPromptResult, PromptMessage}, service::RequestContext, }; @@ -38,309 +39,391 @@ impl<'a, S> PromptContext<'a, S> { context, } } +} + +impl AsRequestContext for PromptContext<'_, S> { + fn as_request_context(&self) -> &RequestContext { + &self.context + } - /// Invoke a prompt handler with parsed arguments - pub async fn invoke(self, handler: H) -> Result - where - H: GetPromptHandler, - S: 'a, - { - handler.handle(self).await + fn as_request_context_mut(&mut self) -> &mut RequestContext { + &mut self.context } } /// Trait for handling prompt retrieval pub trait GetPromptHandler { - fn handle<'a>( + fn handle( self, - context: PromptContext<'a, S>, - ) -> BoxFuture<'a, Result> - where - S: 'a; + context: PromptContext<'_, S>, + ) -> BoxFuture<'_, Result>; } /// Type alias for dynamic prompt handlers -pub type DynGetPromptHandler = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result> +pub type DynGetPromptHandler = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result> + Send + Sync; /// Adapter type for async methods that return Vec pub struct AsyncMethodAdapter(PhantomData); -/// Adapter type for async methods with arguments that return Vec +/// Adapter type for async methods with parameters that return Vec pub struct AsyncMethodWithArgsAdapter(PhantomData); -/// Wrapper for parsing prompt arguments -pub struct Arguments(pub T); +/// Adapter types for macro-generated implementations +#[allow(clippy::type_complexity)] +pub struct AsyncPromptAdapter(PhantomData fn(Fut) -> R>); +pub struct SyncPromptAdapter(PhantomData R>); +pub struct AsyncPromptMethodAdapter(PhantomData R>); +pub struct SyncPromptMethodAdapter(PhantomData R>); -/// Type alias for prompt arguments - matches tool's Parameters pattern -pub type PromptArguments = Arguments; - -impl Arguments { - pub fn into_inner(self) -> T { - self.0 - } +/// Trait for types that can be converted into GetPromptResult +pub trait IntoGetPromptResult { + fn into_get_prompt_result(self) -> Result; } -impl JsonSchema for Arguments { - fn schema_name() -> String { - T::schema_name() +impl IntoGetPromptResult for GetPromptResult { + fn into_get_prompt_result(self) -> Result { + Ok(self) } +} - fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema { - T::json_schema(generator) +impl IntoGetPromptResult for Vec { + fn into_get_prompt_result(self) -> Result { + Ok(GetPromptResult { + description: None, + messages: self, + }) } } -/// Convert a JSON schema into prompt arguments -pub fn arguments_from_schema() -> Option> { - let schema = schema_for!(T); - let schema_value = serde_json::to_value(schema).ok()?; - - // Extract properties from the schema - let properties = schema_value.get("properties")?.as_object()?; - - let required = schema_value - .get("required") - .and_then(|r| r.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str()) - .collect::>() - }) - .unwrap_or_default(); - - let mut arguments = Vec::new(); - for (name, prop_schema) in properties { - let description = prop_schema - .get("description") - .and_then(|d| d.as_str()) - .map(String::from); - - arguments.push(PromptArgument { - name: name.clone(), - description, - required: Some(required.contains(name.as_str())), - }); +impl IntoGetPromptResult for Result { + fn into_get_prompt_result(self) -> Result { + self.and_then(|v| v.into_get_prompt_result()) } +} - if arguments.is_empty() { - None - } else { - Some(arguments) +// Future wrapper that automatically handles IntoGetPromptResult conversion +pin_project_lite::pin_project! { + #[project = IntoGetPromptResultFutProj] + pub enum IntoGetPromptResultFut { + Pending { + #[pin] + fut: F, + _marker: PhantomData, + }, + Ready { + #[pin] + result: futures::future::Ready>, + } } } -/// Call [`arguments_from_schema`] with a cache -pub fn cached_arguments_from_schema() -> Option> +impl Future for IntoGetPromptResultFut +where + F: Future, + R: IntoGetPromptResult, { - thread_local! { - static CACHE_FOR_TYPE: std::sync::RwLock>>> = Default::default(); - }; - CACHE_FOR_TYPE.with(|cache| { - // Try to read from cache first - if let Ok(cache_read) = cache.read() { - if let Some(x) = cache_read.get(&TypeId::of::()) { - return x.clone(); - } + type Output = Result; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + match self.project() { + IntoGetPromptResultFutProj::Pending { fut, _marker } => fut + .poll(cx) + .map(IntoGetPromptResult::into_get_prompt_result), + IntoGetPromptResultFutProj::Ready { result } => result.poll(cx), } + } +} - // Compute the value - let args = arguments_from_schema::(); +/// Keep the original trait for backward compatibility +pub trait FromPromptContextPart: Sized { + fn from_prompt_context_part(context: &mut PromptContext) -> Result; +} - // Try to update cache, but don't fail if we can't - if let Ok(mut cache_write) = cache.write() { - cache_write.insert(TypeId::of::(), args.clone()); - } +// Implement for common extractors that use AsRequestContext +impl FromPromptContextPart for RequestContext { + fn from_prompt_context_part(context: &mut PromptContext) -> Result { + Ok(context.context.clone()) + } +} - args - }) +impl FromPromptContextPart for tokio_util::sync::CancellationToken { + fn from_prompt_context_part(context: &mut PromptContext) -> Result { + Ok(context.context.ct.clone()) + } } -// Implement GetPromptHandler for async functions returning GetPromptResult -impl GetPromptHandler for F -where - S: Sync, - F: FnOnce(&S, RequestContext) -> Fut + Send + 'static, - Fut: Future> + Send + 'static, -{ - fn handle<'a>( - self, - context: PromptContext<'a, S>, - ) -> BoxFuture<'a, Result> - where - S: 'a, - { - Box::pin(async move { (self)(context.server, context.context).await }) +impl FromPromptContextPart for crate::model::Extensions { + fn from_prompt_context_part(context: &mut PromptContext) -> Result { + Ok(context.context.extensions.clone()) } } -// Implement GetPromptHandler for async functions with parsed arguments -impl GetPromptHandler> for F +impl FromPromptContextPart for Extension where - S: Sync, - F: FnOnce(&S, Arguments, RequestContext) -> Fut + Send + 'static, - Fut: Future> + Send + 'static, - T: DeserializeOwned + 'static, + T: Send + Sync + 'static + Clone, { - fn handle<'a>( - self, - context: PromptContext<'a, S>, - ) -> BoxFuture<'a, Result> - where - S: 'a, - { - Box::pin(async move { - // Parse arguments if provided - let args = if let Some(args_map) = context.arguments { - let args_value = serde_json::Value::Object(args_map); - serde_json::from_value::(args_value).map_err(|e| { - crate::Error::invalid_params(format!("Failed to parse arguments: {}", e), None) - })? - } else { - // Try to deserialize from empty object for optional fields - serde_json::from_value::(serde_json::json!({})).map_err(|e| { - crate::Error::invalid_params(format!("Missing required arguments: {}", e), None) - })? - }; - - (self)(context.server, Arguments(args), context.context).await - }) + fn from_prompt_context_part(context: &mut PromptContext) -> Result { + let extension = context + .context + .extensions + .get::() + .cloned() + .ok_or_else(|| { + crate::ErrorData::invalid_params( + format!("missing extension {}", std::any::type_name::()), + None, + ) + })?; + Ok(Extension(extension)) } } -// Implement GetPromptHandler for async methods that return Pin> -impl GetPromptHandler>> for F -where - S: Sync + 'static, - F: for<'a> FnOnce( - &'a S, - RequestContext, - ) -> Pin< - Box, crate::Error>> + Send + 'a>, - > + Send - + 'static, -{ - fn handle<'a>( - self, - context: PromptContext<'a, S>, - ) -> BoxFuture<'a, Result> - where - S: 'a, - { - Box::pin(async move { - let messages = (self)(context.server, context.context).await?; - Ok(GetPromptResult { - description: None, - messages, - }) - }) +impl FromPromptContextPart for crate::Peer { + fn from_prompt_context_part(context: &mut PromptContext) -> Result { + Ok(context.context.peer.clone()) } } -// Implement GetPromptHandler for async methods with arguments that return Pin> -impl GetPromptHandler, Vec)>> - for F -where - S: Sync + 'static, - T: DeserializeOwned + 'static, - F: for<'a> FnOnce( - &'a S, - Arguments, - RequestContext, - ) -> Pin< - Box, crate::Error>> + Send + 'a>, - > + Send - + 'static, -{ - fn handle<'a>( - self, - context: PromptContext<'a, S>, - ) -> BoxFuture<'a, Result> - where - S: 'a, - { - Box::pin(async move { - // Parse arguments if provided - let args = if let Some(args_map) = context.arguments { - let args_value = serde_json::Value::Object(args_map); - serde_json::from_value::(args_value).map_err(|e| { - crate::Error::invalid_params(format!("Failed to parse arguments: {}", e), None) - })? - } else { - // Try to deserialize from empty object for optional fields - serde_json::from_value::(serde_json::json!({})).map_err(|e| { - crate::Error::invalid_params(format!("Missing required arguments: {}", e), None) - })? - }; - - let messages = (self)(context.server, Arguments(args), context.context).await?; - Ok(GetPromptResult { - description: None, - messages, - }) - }) +impl FromPromptContextPart for crate::model::Meta { + fn from_prompt_context_part(context: &mut PromptContext) -> Result { + let mut meta = crate::model::Meta::default(); + std::mem::swap(&mut meta, &mut context.context.meta); + Ok(meta) } } -// Implement GetPromptHandler for async functions returning Vec -impl GetPromptHandler)> for F -where - S: Sync, - F: FnOnce(&S, RequestContext) -> Fut + Send + 'static, - Fut: Future, crate::Error>> + Send + 'static, -{ - fn handle<'a>( - self, - context: PromptContext<'a, S>, - ) -> BoxFuture<'a, Result> - where - S: 'a, - { - Box::pin(async move { - let messages = (self)(context.server, context.context).await?; - Ok(GetPromptResult { - description: None, - messages, - }) - }) +impl FromPromptContextPart for RequestId { + fn from_prompt_context_part(context: &mut PromptContext) -> Result { + Ok(RequestId(context.context.id.clone())) + } +} + +// Prompt-specific extractor for prompt name +pub struct PromptName(pub String); + +impl FromPromptContextPart for PromptName { + fn from_prompt_context_part(context: &mut PromptContext) -> Result { + Ok(Self(context.name.clone())) } } -// Implement GetPromptHandler for async functions with parsed arguments returning Vec -impl GetPromptHandler, Vec)> for F +// Special implementation for Parameters that handles prompt arguments +impl FromPromptContextPart for Parameters

where - S: Sync, - F: FnOnce(&S, Arguments, RequestContext) -> Fut + Send + 'static, - Fut: Future, crate::Error>> + Send + 'static, - T: DeserializeOwned + 'static, + P: DeserializeOwned, { - fn handle<'a>( - self, - context: PromptContext<'a, S>, - ) -> BoxFuture<'a, Result> - where - S: 'a, - { - Box::pin(async move { - // Parse arguments if provided - let args = if let Some(args_map) = context.arguments { - let args_value = serde_json::Value::Object(args_map); - serde_json::from_value::(args_value).map_err(|e| { - crate::Error::invalid_params(format!("Failed to parse arguments: {}", e), None) - })? - } else { - // Try to deserialize from empty object for optional fields - serde_json::from_value::(serde_json::json!({})).map_err(|e| { - crate::Error::invalid_params(format!("Missing required arguments: {}", e), None) - })? - }; - - let messages = (self)(context.server, Arguments(args), context.context).await?; - Ok(GetPromptResult { - description: None, - messages, + fn from_prompt_context_part(context: &mut PromptContext) -> Result { + let params = if let Some(args_map) = context.arguments.take() { + let args_value = serde_json::Value::Object(args_map); + serde_json::from_value::

(args_value).map_err(|e| { + crate::ErrorData::invalid_params(format!("Failed to parse parameters: {}", e), None) + })? + } else { + // Try to deserialize from empty object for optional fields + serde_json::from_value::

(serde_json::json!({})).map_err(|e| { + crate::ErrorData::invalid_params( + format!("Missing required parameters: {}", e), + None, + ) + })? + }; + Ok(Parameters(params)) + } +} + +// Macro to generate GetPromptHandler implementations for various parameter combinations +macro_rules! impl_prompt_handler_for { + ($($T: ident)*) => { + impl_prompt_handler_for!([] [$($T)*]); + }; + // finished + ([$($Tn: ident)*] []) => { + impl_prompt_handler_for!(@impl $($Tn)*); + }; + ([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => { + impl_prompt_handler_for!(@impl $($Tn)*); + impl_prompt_handler_for!([$($Tn)* $Tn_1] [$($Rest)*]); + }; + (@impl $($Tn: ident)*) => { + // Implementation for async methods (transformed by #[prompt] macro) + impl<$($Tn,)* S, F, R> GetPromptHandler for F + where + $( + $Tn: FromPromptContextPart + Send, + )* + F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R> + Send, + R: IntoGetPromptResult + Send + 'static, + S: Send + Sync + 'static, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn handle( + self, + mut context: PromptContext<'_, S>, + ) -> BoxFuture<'_, Result> + { + $( + let result = $Tn::from_prompt_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + let service = context.server; + let fut = self(service, $($Tn,)*); + async move { + let result = fut.await; + result.into_get_prompt_result() + }.boxed() + } + } + + + // Implementation for sync methods + impl<$($Tn,)* S, F, R> GetPromptHandler> for F + where + $( + $Tn: FromPromptContextPart + Send, + )* + F: FnOnce(&S, $($Tn,)*) -> R + Send, + R: IntoGetPromptResult + Send, + S: Send + Sync, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn handle( + self, + mut context: PromptContext<'_, S>, + ) -> BoxFuture<'_, Result> + { + $( + let result = $Tn::from_prompt_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + let service = context.server; + let result = self(service, $($Tn,)*); + std::future::ready(result.into_get_prompt_result()).boxed() + } + } + + + // AsyncPromptAdapter - for standalone functions returning GetPromptResult + impl<$($Tn,)* S, F, Fut, R> GetPromptHandler> for F + where + $( + $Tn: FromPromptContextPart + Send + 'static, + )* + F: FnOnce($($Tn,)*) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + R: IntoGetPromptResult + Send + 'static, + S: Send + Sync + 'static, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn handle( + self, + mut context: PromptContext<'_, S>, + ) -> BoxFuture<'_, Result> + { + // Extract all parameters before moving into the async block + $( + let result = $Tn::from_prompt_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + + // Since we're dealing with standalone functions that don't take &S, + // we can return a 'static future + Box::pin(async move { + let result = self($($Tn,)*).await?; + result.into_get_prompt_result() + }) + } + } + + + // SyncPromptAdapter - for standalone sync functions returning Result + impl<$($Tn,)* S, F, R> GetPromptHandler> for F + where + $( + $Tn: FromPromptContextPart + Send + 'static, + )* + F: FnOnce($($Tn,)*) -> Result + Send + 'static, + R: IntoGetPromptResult + Send + 'static, + S: Send + Sync, + { + #[allow(unused_variables, non_snake_case, unused_mut)] + fn handle( + self, + mut context: PromptContext<'_, S>, + ) -> BoxFuture<'_, Result> + { + $( + let result = $Tn::from_prompt_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)).boxed(), + }; + )* + let result = self($($Tn,)*); + std::future::ready(result.and_then(|r| r.into_get_prompt_result())).boxed() + } + } + + }; +} + +// Invoke the macro to generate implementations for up to 16 parameters +impl_prompt_handler_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + +/// Extract prompt arguments from a type's JSON schema +/// This function analyzes the schema of a type and extracts the properties +/// as PromptArgument entries with name, description, and required status +pub fn cached_arguments_from_schema() +-> Option> { + let schema = super::common::cached_schema_for_type::(); + let schema_value = serde_json::Value::Object((*schema).clone()); + + let properties = schema_value.get("properties").and_then(|p| p.as_object()); + + if let Some(props) = properties { + let required = schema_value + .get("required") + .and_then(|r| r.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str()) + .collect::>() }) - }) + .unwrap_or_default(); + + let mut arguments = Vec::new(); + for (name, prop_schema) in props { + let description = prop_schema + .get("description") + .and_then(|d| d.as_str()) + .map(|s| s.to_string()); + + arguments.push(crate::model::PromptArgument { + name: name.clone(), + description, + required: Some(required.contains(name.as_str())), + }); + } + + if arguments.is_empty() { + None + } else { + Some(arguments) + } + } else { + None } } diff --git a/crates/rmcp/src/handler/server/router/prompt.rs b/crates/rmcp/src/handler/server/router/prompt.rs index e7c552dd..67491c5c 100644 --- a/crates/rmcp/src/handler/server/router/prompt.rs +++ b/crates/rmcp/src/handler/server/router/prompt.rs @@ -1,6 +1,6 @@ use std::{borrow::Cow, sync::Arc}; -use futures::{FutureExt, future::BoxFuture}; +use futures::future::BoxFuture; use crate::{ handler::server::prompt::{DynGetPromptHandler, GetPromptHandler, PromptContext}, @@ -40,7 +40,7 @@ impl PromptRoute { Self { get: Arc::new(move |context: PromptContext| { let handler = handler.clone(); - context.invoke(handler).boxed() + handler.handle(context) }), attr: attr.into(), } @@ -48,7 +48,9 @@ impl PromptRoute { pub fn new_dyn(attr: impl Into, handler: H) -> Self where - H: for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result> + H: for<'a> Fn( + PromptContext<'a, S>, + ) -> BoxFuture<'a, Result> + Send + Sync + 'static, @@ -172,9 +174,9 @@ where pub async fn get_prompt( &self, context: PromptContext<'_, S>, - ) -> Result { + ) -> Result { let item = self.map.get(context.name.as_str()).ok_or_else(|| { - crate::Error::invalid_params( + crate::ErrorData::invalid_params( format!("prompt '{}' not found", context.name), Some(serde_json::json!({ "available_prompts": self.list_all().iter().map(|p| &p.name).collect::>() diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index cea0e9cc..9fd697ff 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -1,58 +1,22 @@ use std::{ - any::TypeId, borrow::Cow, collections::HashMap, future::Ready, marker::PhantomData, sync::Arc, + borrow::Cow, + future::{Future, Ready}, + marker::PhantomData, }; use futures::future::{BoxFuture, FutureExt}; -use schemars::{JsonSchema, transform::AddNullable}; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use tokio_util::sync::CancellationToken; +use serde::de::DeserializeOwned; +use super::common::{AsRequestContext, FromContextPart}; +pub use super::common::{ + Extension, Parameters, RequestId, cached_schema_for_type, schema_for_type, +}; pub use super::router::tool::{ToolRoute, ToolRouter}; use crate::{ RoleServer, model::{CallToolRequestParam, CallToolResult, IntoContents, JsonObject}, - schemars::generate::SchemaSettings, service::RequestContext, }; -/// A shortcut for generating a JSON schema for a type. -pub fn schema_for_type() -> JsonObject { - // explicitly to align json schema version to official specifications. - // https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json - // TODO: update to 2020-12 waiting for the mcp spec update - let mut settings = SchemaSettings::draft07(); - settings.transforms = vec![Box::new(AddNullable::default())]; - let generator = settings.into_generator(); - let schema = generator.into_root_schema_for::(); - let object = serde_json::to_value(schema).expect("failed to serialize schema"); - match object { - serde_json::Value::Object(object) => object, - _ => panic!("unexpected schema value"), - } -} - -/// Call [`schema_for_type`] with a cache -pub fn cached_schema_for_type() -> Arc { - thread_local! { - static CACHE_FOR_TYPE: std::sync::RwLock>> = Default::default(); - }; - CACHE_FOR_TYPE.with(|cache| { - if let Some(x) = cache - .read() - .expect("schema cache lock poisoned") - .get(&TypeId::of::()) - { - x.clone() - } else { - let schema = schema_for_type::(); - let schema = Arc::new(schema); - cache - .write() - .expect("schema cache lock poisoned") - .insert(TypeId::of::(), schema.clone()); - schema - } - }) -} /// Deserialize a JSON object into a type pub fn parse_json_object(input: JsonObject) -> Result { @@ -91,6 +55,17 @@ impl<'s, S> ToolCallContext<'s, S> { } } +impl AsRequestContext for ToolCallContext<'_, S> { + fn as_request_context(&self) -> &RequestContext { + &self.request_context + } + + fn as_request_context_mut(&mut self) -> &mut RequestContext { + &mut self.request_context + } +} + +// Keep the original trait for backward compatibility pub trait FromToolCallContextPart: Sized { fn from_tool_call_context_part( context: &mut ToolCallContext, @@ -177,47 +152,21 @@ pub type DynCallToolHandler = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFu + Send + Sync; -/// Parameter Extractor -/// -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(transparent)] -pub struct Parameters

(pub P); - -impl JsonSchema for Parameters

{ - fn schema_name() -> Cow<'static, str> { - P::schema_name() - } - - fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema { - P::json_schema(generator) - } -} - -impl FromToolCallContextPart for CancellationToken { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { - Ok(context.request_context.ct.clone()) - } -} - +// Tool-specific extractor for tool name pub struct ToolName(pub Cow<'static, str>); -impl FromToolCallContextPart for ToolName { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { +impl FromContextPart> for ToolName { + fn from_context_part(context: &mut ToolCallContext) -> Result { Ok(Self(context.name.clone())) } } -impl FromToolCallContextPart for Parameters

+// Special implementation for Parameters that handles tool arguments +impl FromContextPart> for Parameters

where P: DeserializeOwned, { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { + fn from_context_part(context: &mut ToolCallContext) -> Result { let arguments = context.arguments.take().unwrap_or_default(); let value: P = serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| { @@ -230,81 +179,31 @@ where } } -impl FromToolCallContextPart for JsonObject { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { - let object = context.arguments.take().unwrap_or_default(); - Ok(object) - } -} - -impl FromToolCallContextPart for crate::model::Extensions { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { - let extensions = context.request_context.extensions.clone(); - Ok(extensions) - } -} - -pub struct Extension(pub T); - -impl FromToolCallContextPart for Extension +// Also implement the old trait directly for Parameters to support macro-generated code +impl FromToolCallContextPart for Parameters

where - T: Send + Sync + 'static + Clone, + P: DeserializeOwned, { fn from_tool_call_context_part( context: &mut ToolCallContext, ) -> Result { - let extension = context - .request_context - .extensions - .get::() - .cloned() - .ok_or_else(|| { + let arguments = context.arguments.take().unwrap_or_default(); + let value: P = + serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| { crate::ErrorData::invalid_params( - format!("missing extension {}", std::any::type_name::()), + format!("failed to deserialize parameters: {error}", error = e), None, ) })?; - Ok(Extension(extension)) - } -} - -impl FromToolCallContextPart for crate::Peer { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { - let peer = context.request_context.peer.clone(); - Ok(peer) - } -} - -impl FromToolCallContextPart for crate::model::Meta { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { - let mut meta = crate::model::Meta::default(); - std::mem::swap(&mut meta, &mut context.request_context.meta); - Ok(meta) - } -} - -pub struct RequestId(pub crate::model::RequestId); -impl FromToolCallContextPart for RequestId { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { - Ok(RequestId(context.request_context.id.clone())) + Ok(Parameters(value)) } } -impl FromToolCallContextPart for RequestContext { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { - Ok(context.request_context.clone()) +// Special implementation for JsonObject that takes tool arguments +impl FromContextPart> for JsonObject { + fn from_context_part(context: &mut ToolCallContext) -> Result { + let object = context.arguments.take().unwrap_or_default(); + Ok(object) } } diff --git a/crates/rmcp/tests/test_prompt_handler.rs b/crates/rmcp/tests/test_prompt_handler.rs new file mode 100644 index 00000000..1eef0a1f --- /dev/null +++ b/crates/rmcp/tests/test_prompt_handler.rs @@ -0,0 +1,143 @@ +//cargo test --test test_prompt_handler --features "client server" +// Tests for verifying that the #[prompt_handler] macro correctly generates +// the ServerHandler trait implementation methods. +#![allow(dead_code)] + +use rmcp::{ + RoleServer, ServerHandler, + handler::server::router::prompt::PromptRouter, + model::{GetPromptRequestParam, GetPromptResult, ListPromptsResult, PaginatedRequestParam}, + prompt_handler, + service::RequestContext, +}; + +#[derive(Debug, Clone)] +pub struct TestPromptServer { + prompt_router: PromptRouter, +} + +impl TestPromptServer { + pub fn new() -> Self { + Self { + prompt_router: PromptRouter::new(), + } + } +} + +#[prompt_handler] +impl ServerHandler for TestPromptServer {} + +#[derive(Debug, Clone)] +pub struct CustomRouterServer { + custom_router: PromptRouter, +} + +impl CustomRouterServer { + pub fn new() -> Self { + Self { + custom_router: PromptRouter::new(), + } + } + + pub fn get_custom_router(&self) -> &PromptRouter { + &self.custom_router + } +} + +#[prompt_handler(router = self.custom_router)] +impl ServerHandler for CustomRouterServer {} + +#[derive(Debug, Clone)] +pub struct GenericPromptServer { + prompt_router: PromptRouter, + _marker: std::marker::PhantomData, +} + +impl GenericPromptServer { + pub fn new() -> Self { + Self { + prompt_router: PromptRouter::new(), + _marker: std::marker::PhantomData, + } + } +} + +#[prompt_handler] +impl ServerHandler for GenericPromptServer {} + +#[test] +fn test_prompt_handler_basic() { + let server = TestPromptServer::new(); + + // Test that the server implements ServerHandler + fn assert_server_handler(_: &T) {} + assert_server_handler(&server); + + // Test that the prompt router is accessible + assert_eq!(server.prompt_router.list_all().len(), 0); +} + +#[test] +fn test_prompt_handler_custom_router() { + let server = CustomRouterServer::new(); + + // Test that the server implements ServerHandler + fn assert_server_handler(_: &T) {} + assert_server_handler(&server); + + // Test that the custom router is used + assert_eq!(server.custom_router.list_all().len(), 0); +} + +#[test] +fn test_prompt_handler_with_generics() { + let server = GenericPromptServer::::new(); + + // Test that generic server implements ServerHandler + fn assert_server_handler(_: &T) {} + assert_server_handler(&server); + + // Test with a different generic type + let server2 = GenericPromptServer::::new(); + assert_server_handler(&server2); +} + +#[test] +fn test_prompt_handler_trait_implementation() { + // This test verifies that the prompt_handler macro generates proper ServerHandler implementation + // The actual method signatures are tested through the ServerHandler trait bound + fn compile_time_check() {} + + compile_time_check::(); + compile_time_check::(); + compile_time_check::>(); +} + +// Test that the macro works with different server configurations +mod nested { + use super::*; + + #[derive(Debug, Clone)] + pub struct NestedServer { + prompt_router: PromptRouter, + } + + impl NestedServer { + pub fn new() -> Self { + Self { + prompt_router: PromptRouter::new(), + } + } + } + + #[prompt_handler] + impl ServerHandler for NestedServer {} + + #[test] + fn test_nested_prompt_handler() { + let server = NestedServer::new(); + // Verify it implements ServerHandler + fn assert_server_handler(_: &T) {} + assert_server_handler(&server); + } +} diff --git a/crates/rmcp/tests/test_prompt_macro_annotations.rs b/crates/rmcp/tests/test_prompt_macro_annotations.rs new file mode 100644 index 00000000..752070b5 --- /dev/null +++ b/crates/rmcp/tests/test_prompt_macro_annotations.rs @@ -0,0 +1,291 @@ +//cargo test --test test_prompt_macro_annotations --features "client server" +#![allow(dead_code)] + +use rmcp::{ + ServerHandler, + handler::server::prompt::Parameters, + model::{GetPromptResult, Prompt, PromptMessage, PromptMessageRole}, + prompt, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone)] +struct TestServer; + +impl ServerHandler for TestServer {} + +#[derive(Serialize, Deserialize, JsonSchema)] +struct TestArgs { + /// The input text to process + input: String, + /// Optional configuration + #[serde(skip_serializing_if = "Option::is_none")] + config: Option, +} + +#[derive(Serialize, Deserialize, JsonSchema)] +struct ComplexArgs { + /// Required field + required_field: String, + /// Optional string field + #[schemars(description = "An optional string parameter")] + optional_string: Option, + /// Optional number field + optional_number: Option, + /// Array field + items: Vec, +} + +// Test basic prompt attribute generation +#[prompt] +async fn basic_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Basic response", + )] +} + +// Test prompt with custom name +#[prompt(name = "custom_name")] +async fn named_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Named response", + )] +} + +// Test prompt with custom description +#[prompt(description = "This is a custom description")] +async fn described_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Described response", + )] +} + +// Test prompt with both name and description +#[prompt(name = "full_custom", description = "Fully customized prompt")] +async fn fully_custom_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Fully custom response", + )] +} + +// Test prompt with doc comments +/// This is a doc comment description +/// that spans multiple lines +#[prompt] +async fn doc_comment_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Doc comment response", + )] +} + +// Test prompt with doc comments and explicit description (explicit wins) +/// This is a doc comment +#[prompt(description = "This overrides the doc comment")] +async fn override_doc_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Override response", + )] +} + +// Test prompt with arguments +#[prompt] +async fn args_prompt(_server: &TestServer, _args: Parameters) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Args response", + )] +} + +// Test prompt with complex arguments +#[prompt] +async fn complex_args_prompt( + _server: &TestServer, + _args: Parameters, +) -> GetPromptResult { + GetPromptResult { + description: Some("Complex args result".to_string()), + messages: vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Complex response", + )], + } +} + +// Test sync prompt +#[prompt] +fn sync_prompt(_server: &TestServer) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Sync response", + )] +} + +#[test] +fn test_basic_prompt_attr() { + let attr = basic_prompt_prompt_attr(); + assert_eq!(attr.name, "basic_prompt"); + assert_eq!(attr.description, None); + assert!(attr.arguments.is_none()); +} + +#[test] +fn test_named_prompt_attr() { + let attr = named_prompt_prompt_attr(); + assert_eq!(attr.name, "custom_name"); + assert_eq!(attr.description, None); + assert!(attr.arguments.is_none()); +} + +#[test] +fn test_described_prompt_attr() { + let attr = described_prompt_prompt_attr(); + assert_eq!(attr.name, "described_prompt"); + assert_eq!( + attr.description.as_deref(), + Some("This is a custom description") + ); + assert!(attr.arguments.is_none()); +} + +#[test] +fn test_fully_custom_prompt_attr() { + let attr = fully_custom_prompt_prompt_attr(); + assert_eq!(attr.name, "full_custom"); + assert_eq!(attr.description.as_deref(), Some("Fully customized prompt")); + assert!(attr.arguments.is_none()); +} + +#[test] +fn test_doc_comment_prompt_attr() { + let attr = doc_comment_prompt_prompt_attr(); + assert_eq!(attr.name, "doc_comment_prompt"); + assert!(attr.description.is_some()); + let desc = attr.description.unwrap(); + assert!(desc.contains("This is a doc comment description")); + assert!(desc.contains("that spans multiple lines")); +} + +#[test] +fn test_override_doc_prompt_attr() { + let attr = override_doc_prompt_prompt_attr(); + assert_eq!(attr.name, "override_doc_prompt"); + assert_eq!( + attr.description.as_deref(), + Some("This overrides the doc comment") + ); +} + +#[test] +fn test_args_prompt_attr() { + let attr = args_prompt_prompt_attr(); + assert_eq!(attr.name, "args_prompt"); + + let args = attr.arguments.as_ref().unwrap(); + assert_eq!(args.len(), 2); + + // Check input field + let input_arg = args.iter().find(|a| a.name == "input").unwrap(); + assert_eq!(input_arg.required, Some(true)); + assert_eq!( + input_arg.description.as_deref(), + Some("The input text to process") + ); + + // Check config field + let config_arg = args.iter().find(|a| a.name == "config").unwrap(); + assert_eq!(config_arg.required, Some(false)); + assert_eq!( + config_arg.description.as_deref(), + Some("Optional configuration") + ); +} + +#[test] +fn test_complex_args_prompt_attr() { + let attr = complex_args_prompt_prompt_attr(); + assert_eq!(attr.name, "complex_args_prompt"); + + let args = attr.arguments.as_ref().unwrap(); + assert_eq!(args.len(), 4); + + // Check required_field + let required_arg = args.iter().find(|a| a.name == "required_field").unwrap(); + assert_eq!(required_arg.required, Some(true)); + assert_eq!(required_arg.description.as_deref(), Some("Required field")); + + // Check optional_string + let optional_string_arg = args.iter().find(|a| a.name == "optional_string").unwrap(); + assert_eq!(optional_string_arg.required, Some(false)); + assert_eq!( + optional_string_arg.description.as_deref(), + Some("An optional string parameter") + ); + + // Check optional_number + let optional_number_arg = args.iter().find(|a| a.name == "optional_number").unwrap(); + assert_eq!(optional_number_arg.required, Some(false)); + assert_eq!( + optional_number_arg.description.as_deref(), + Some("Optional number field") + ); + + // Check items + let items_arg = args.iter().find(|a| a.name == "items").unwrap(); + assert_eq!(items_arg.required, Some(true)); + assert_eq!(items_arg.description.as_deref(), Some("Array field")); +} + +#[test] +fn test_sync_prompt_attr() { + let attr = sync_prompt_prompt_attr(); + assert_eq!(attr.name, "sync_prompt"); + assert!(attr.arguments.is_none()); +} + +#[test] +fn test_prompt_attr_function_type() { + // Test that the generated function returns the correct type + fn assert_prompt_attr_fn(_: impl Fn() -> Prompt) {} + + assert_prompt_attr_fn(basic_prompt_prompt_attr); + assert_prompt_attr_fn(named_prompt_prompt_attr); + assert_prompt_attr_fn(described_prompt_prompt_attr); + assert_prompt_attr_fn(fully_custom_prompt_prompt_attr); + assert_prompt_attr_fn(doc_comment_prompt_prompt_attr); + assert_prompt_attr_fn(override_doc_prompt_prompt_attr); + assert_prompt_attr_fn(args_prompt_prompt_attr); + assert_prompt_attr_fn(complex_args_prompt_prompt_attr); + assert_prompt_attr_fn(sync_prompt_prompt_attr); +} + +// Test generic prompts +#[derive(Debug, Clone)] +struct GenericServer { + _marker: std::marker::PhantomData, +} + +impl ServerHandler for GenericServer {} + +#[prompt] +async fn generic_prompt( + _server: &GenericServer, +) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Generic response", + )] +} + +#[test] +fn test_generic_prompt_attr() { + let attr = generic_prompt_prompt_attr(); + assert_eq!(attr.name, "generic_prompt"); + assert!(attr.arguments.is_none()); +} diff --git a/crates/rmcp/tests/test_prompt_macros.rs b/crates/rmcp/tests/test_prompt_macros.rs new file mode 100644 index 00000000..60725faa --- /dev/null +++ b/crates/rmcp/tests/test_prompt_macros.rs @@ -0,0 +1,383 @@ +//cargo test --test test_prompt_macros --features "client server" +#![allow(dead_code)] +use std::sync::Arc; + +use rmcp::{ + ClientHandler, RoleServer, ServerHandler, ServiceExt, + handler::server::{prompt::Parameters, router::prompt::PromptRouter}, + model::{ + ClientInfo, GetPromptRequestParam, GetPromptResult, ListPromptsResult, + PaginatedRequestParam, PromptMessage, PromptMessageRole, + }, + prompt, prompt_handler, prompt_router, + service::RequestContext, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct CodeReviewRequest { + pub file_path: String, + pub language: String, +} + +#[prompt_handler(router = self.prompt_router)] +impl ServerHandler for Server {} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct Server { + prompt_router: PromptRouter, +} + +impl Default for Server { + fn default() -> Self { + Self::new() + } +} + +#[prompt_router] +impl Server { + pub fn new() -> Self { + Self { + prompt_router: Self::prompt_router(), + } + } + + /// This prompt is used to review code for best practices. + #[prompt( + name = "code-review", + description = "Review code for best practices and issues." + )] + pub async fn code_review(&self, params: Parameters) -> Vec { + vec![ + PromptMessage::new_text( + PromptMessageRole::User, + format!( + "Please review the {} code in: {}", + params.0.language, params.0.file_path + ), + ), + PromptMessage::new_text( + PromptMessageRole::Assistant, + "I'll review this code for best practices and potential issues.".to_string(), + ), + ] + } + + #[prompt] + async fn empty_param(&self) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "This is a prompt with no parameters.".to_string(), + )] + } +} + +// define generic service trait +pub trait DataService: Send + Sync + 'static { + fn get_context(&self) -> String; +} + +// mock service for test +#[derive(Clone)] +struct MockDataService; +impl DataService for MockDataService { + fn get_context(&self) -> String { + "mock context data".to_string() + } +} + +// define generic server +#[derive(Debug, Clone)] +pub struct GenericServer { + data_service: Arc, + prompt_router: PromptRouter, +} + +#[prompt_router] +impl GenericServer { + pub fn new(data_service: DS) -> Self { + Self { + data_service: Arc::new(data_service), + prompt_router: Self::prompt_router(), + } + } + + #[prompt(description = "Get contextual help from the service")] + async fn get_help(&self) -> GetPromptResult { + let context = self.data_service.get_context(); + GetPromptResult { + description: Some("Contextual help based on service data".to_string()), + messages: vec![ + PromptMessage::new_text( + PromptMessageRole::User, + "I need help with the current context.".to_string(), + ), + PromptMessage::new_text( + PromptMessageRole::Assistant, + format!( + "Based on the context '{}', here's how I can help...", + context + ), + ), + ], + } + } +} + +#[prompt_handler] +impl ServerHandler for GenericServer {} + +#[tokio::test] +async fn test_prompt_macros() { + let server = Server::new(); + let _attr = Server::code_review_prompt_attr(); + let _code_review_prompt_attr_fn = Server::code_review_prompt_attr; + let _code_review_fn = Server::code_review; + let result = server + .code_review(Parameters(CodeReviewRequest { + file_path: "/src/main.rs".into(), + language: "rust".into(), + })) + .await; + assert_eq!(result.len(), 2); + assert_eq!(result[0].role, PromptMessageRole::User); + assert_eq!(result[1].role, PromptMessageRole::Assistant); +} + +#[tokio::test] +async fn test_prompt_macros_with_empty_param() { + let _attr = Server::empty_param_prompt_attr(); + println!("{_attr:?}"); + assert!( + _attr.arguments.is_none(), + "Empty param prompt should have no arguments" + ); +} + +#[tokio::test] +async fn test_prompt_macros_with_generics() { + let mock_service = MockDataService; + let server = GenericServer::new(mock_service); + let _attr = GenericServer::::get_help_prompt_attr(); + let _get_help_call_fn = GenericServer::::get_help; + let _get_help_fn = GenericServer::::get_help; + let result = server.get_help().await; + assert!(result.description.is_some()); + assert_eq!(result.messages.len(), 2); + match &result.messages[1].content { + rmcp::model::PromptMessageContent::Text { text } => { + assert!(text.contains("mock context data")); + } + _ => panic!("Expected text content"), + } +} + +#[tokio::test] +async fn test_prompt_macros_with_optional_param() { + let _attr = Server::code_review_prompt_attr(); + let arguments = _attr.arguments.as_ref().unwrap(); + + // Check that we have the expected number of arguments + assert_eq!(arguments.len(), 2); + + // Verify file_path is required + let file_path_arg = arguments.iter().find(|a| a.name == "file_path").unwrap(); + assert_eq!(file_path_arg.required, Some(true)); + + // Verify language is required + let language_arg = arguments.iter().find(|a| a.name == "language").unwrap(); + assert_eq!(language_arg.required, Some(true)); +} + +impl CodeReviewRequest {} + +// Struct defined for testing optional field schema generation +#[derive(Debug, Deserialize, Serialize, JsonSchema)] +pub struct OptionalFieldTestSchema { + #[schemars(description = "An optional description field")] + pub description: Option, +} + +// Struct defined for testing optional i64 field schema generation and null handling +#[derive(Debug, Deserialize, Serialize, JsonSchema)] +pub struct OptionalI64TestSchema { + #[schemars(description = "An optional i64 field")] + pub count: Option, + pub mandatory_field: String, // Added to ensure non-empty object schema +} + +// Dummy struct to host the test prompt method +#[derive(Debug, Clone)] +pub struct OptionalSchemaTester { + prompt_router: PromptRouter, +} + +impl Default for OptionalSchemaTester { + fn default() -> Self { + Self::new() + } +} + +impl OptionalSchemaTester { + pub fn new() -> Self { + Self { + prompt_router: Self::prompt_router(), + } + } +} + +#[prompt_router] +impl OptionalSchemaTester { + // Dummy prompt function using the test schema as an aggregated parameter + #[prompt(description = "A prompt to test optional schema generation")] + async fn test_optional(&self, _req: Parameters) -> Vec { + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Testing optional fields".to_string(), + )] + } + + // Prompt function to test optional i64 handling + #[prompt(description = "A prompt to test optional i64 schema generation")] + async fn test_optional_i64( + &self, + Parameters(req): Parameters, + ) -> GetPromptResult { + let message = match req.count { + Some(c) => format!("Received count: {}", c), + None => "Received null count".to_string(), + }; + + GetPromptResult { + description: Some("Test result for optional i64".to_string()), + messages: vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + message, + )], + } + } +} + +#[prompt_handler] +// Implement ServerHandler to route prompt calls for OptionalSchemaTester +impl ServerHandler for OptionalSchemaTester {} + +#[test] +fn test_optional_field_schema_generation_via_macro() { + // tests https://github.com/modelcontextprotocol/rust-sdk/issues/135 + + // Get the attributes generated by the #[prompt] macro helper + let prompt_attr = OptionalSchemaTester::test_optional_prompt_attr(); + + // Print the actual generated schema for debugging + println!( + "Actual arguments generated by macro: {:#?}", + prompt_attr.arguments + ); + + // Verify the schema generated for the aggregated OptionalFieldTestSchema + let arguments = prompt_attr.arguments.expect("Should have arguments"); + + // Check that we have an argument for the optional description field + let description_arg = arguments + .iter() + .find(|arg| arg.name == "description") + .expect("Should have description argument"); + + // Assert that optional fields are marked as not required + assert_eq!( + description_arg.required, + Some(false), + "Optional fields should be marked as not required" + ); + + // Check the description is correct + assert_eq!( + description_arg.description.as_deref(), + Some("An optional description field") + ); +} + +// Define a dummy client handler +#[derive(Debug, Clone, Default)] +struct DummyClientHandler {} + +impl ClientHandler for DummyClientHandler { + fn get_info(&self) -> ClientInfo { + ClientInfo::default() + } +} + +#[tokio::test] +async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Server setup + let server = OptionalSchemaTester::new(); + let server_handle = tokio::spawn(async move { + server.serve(server_transport).await?.waiting().await?; + anyhow::Ok(()) + }); + + // Create a simple client handler that just forwards prompt calls + let client_handler = DummyClientHandler::default(); + let client = client_handler.serve(client_transport).await?; + + // Test null case + let result = client + .get_prompt(GetPromptRequestParam { + name: "test_optional_i64".into(), + arguments: Some( + serde_json::json!({ + "count": null, + "mandatory_field": "test_null" + }) + .as_object() + .unwrap() + .clone(), + ), + }) + .await?; + + let result_text = match &result.messages.first().unwrap().content { + rmcp::model::PromptMessageContent::Text { text } => text.as_str(), + _ => panic!("Expected text content"), + }; + + assert_eq!( + result_text, "Received null count", + "Null case should return expected message" + ); + + // Test Some case + let some_result = client + .get_prompt(GetPromptRequestParam { + name: "test_optional_i64".into(), + arguments: Some( + serde_json::json!({ + "count": 42, + "mandatory_field": "test_some" + }) + .as_object() + .unwrap() + .clone(), + ), + }) + .await?; + + let some_result_text = match &some_result.messages.first().unwrap().content { + rmcp::model::PromptMessageContent::Text { text } => text.as_str(), + _ => panic!("Expected text content"), + }; + + assert_eq!( + some_result_text, "Received count: 42", + "Some case should return expected message" + ); + + client.cancel().await?; + server_handle.await??; + Ok(()) +} diff --git a/crates/rmcp/tests/test_prompt_routers.rs b/crates/rmcp/tests/test_prompt_routers.rs new file mode 100644 index 00000000..64bcdb58 --- /dev/null +++ b/crates/rmcp/tests/test_prompt_routers.rs @@ -0,0 +1,105 @@ +use std::collections::HashMap; + +use futures::future::BoxFuture; +use rmcp::{ + ServerHandler, + handler::server::prompt::Parameters, + model::{GetPromptResult, PromptMessage, PromptMessageRole}, +}; + +#[derive(Debug, Default)] +pub struct TestHandler { + pub _marker: std::marker::PhantomData, +} + +impl ServerHandler for TestHandler {} + +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Request { + pub fields: HashMap, +} + +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Sum { + pub a: i32, + pub b: i32, +} + +#[rmcp::prompt_router(router = "test_router")] +impl TestHandler { + #[rmcp::prompt] + async fn async_method( + &self, + Parameters(Request { fields }): Parameters, + ) -> Vec { + drop(fields); + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Async method response", + )] + } + + #[rmcp::prompt] + fn sync_method( + &self, + Parameters(Request { fields }): Parameters, + ) -> Vec { + drop(fields); + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Sync method response", + )] + } +} + +#[rmcp::prompt] +async fn async_function(Parameters(Request { fields }): Parameters) -> Vec { + drop(fields); + vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Async function response", + )] +} + +#[rmcp::prompt] +fn async_function2(_callee: &TestHandler) -> BoxFuture<'_, GetPromptResult> { + Box::pin(async move { + GetPromptResult { + description: Some("Async function 2".to_string()), + messages: vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Async function 2 response", + )], + } + }) +} + +#[test] +fn test_prompt_router() { + let test_prompt_router = TestHandler::<()>::test_router() + .with_route(rmcp::handler::server::router::prompt::PromptRoute::new_dyn( + async_function_prompt_attr(), + |mut context| { + Box::pin(async move { + use rmcp::handler::server::prompt::{ + FromPromptContextPart, IntoGetPromptResult, + }; + let params = Parameters::::from_prompt_context_part(&mut context)?; + let result = async_function(params).await; + result.into_get_prompt_result() + }) + }, + )) + .with_route(rmcp::handler::server::router::prompt::PromptRoute::new_dyn( + async_function2_prompt_attr(), + |context| { + Box::pin(async move { + use rmcp::handler::server::prompt::IntoGetPromptResult; + let result = async_function2(context.server).await; + result.into_get_prompt_result() + }) + }, + )); + let prompts = test_prompt_router.list_all(); + assert_eq!(prompts.len(), 4); +} diff --git a/crates/rmcp/tests/test_prompts.rs b/crates/rmcp/tests/test_prompts.rs deleted file mode 100644 index b8fa9985..00000000 --- a/crates/rmcp/tests/test_prompts.rs +++ /dev/null @@ -1,209 +0,0 @@ -use rmcp::{ - handler::server::{ServerHandler, prompt::Arguments, router::Router}, - model::{GetPromptResult, PromptMessage, PromptMessageRole}, - service::{RequestContext, RoleServer}, -}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; - -/// Test prompt arguments for code review -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -struct CodeReviewArgs { - /// The file path to review - file_path: String, - /// Focus areas for the review - #[serde(skip_serializing_if = "Option::is_none")] - focus_areas: Option>, -} - -/// Test prompt arguments for debugging -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -struct DebugAssistantArgs { - /// The error message to debug - error_message: String, - /// The programming language - language: String, - /// Additional context - #[serde(skip_serializing_if = "Option::is_none")] - context: Option, -} - -struct TestPromptServer; - -impl ServerHandler for TestPromptServer {} - -/// A simple code review prompt -#[rmcp::prompt( - name = "code_review", - description = "Reviews code for best practices and potential issues" -)] -async fn code_review_prompt( - _server: &TestPromptServer, - Arguments(args): Arguments, - _ctx: RequestContext, -) -> Result, rmcp::Error> { - let mut messages = vec![PromptMessage::new_text( - PromptMessageRole::User, - format!("Please review the code in file: {}", args.file_path), - )]; - - if let Some(focus_areas) = args.focus_areas { - messages.push(PromptMessage::new_text( - PromptMessageRole::User, - format!("Focus on these areas: {}", focus_areas.join(", ")), - )); - } - - messages.push(PromptMessage::new_text( - PromptMessageRole::Assistant, - "I'll help you review this code. Let me analyze it for best practices, potential bugs, and improvement opportunities.", - )); - - Ok(messages) -} - -/// A debugging assistant prompt -#[rmcp::prompt(name = "debug_assistant")] -async fn debug_assistant_prompt( - _server: &TestPromptServer, - Arguments(args): Arguments, - _ctx: RequestContext, -) -> Result { - let mut messages = vec![PromptMessage::new_text( - PromptMessageRole::User, - format!( - "I'm getting this error in my {} code: {}", - args.language, args.error_message - ), - )]; - - if let Some(context) = args.context { - messages.push(PromptMessage::new_text( - PromptMessageRole::User, - format!("Additional context: {}", context), - )); - } - - messages.push(PromptMessage::new_text( - PromptMessageRole::Assistant, - format!( - "I'll help you debug this {} error. Let me analyze the error message and provide solutions.", - args.language - ), - )); - - Ok(GetPromptResult { - description: Some("Helps debug programming errors with detailed analysis".to_string()), - messages, - }) -} - -/// A simple greeting prompt without arguments -#[rmcp::prompt] -async fn greeting_prompt( - _server: &TestPromptServer, - _ctx: RequestContext, -) -> Result, rmcp::Error> { - Ok(vec![ - PromptMessage::new_text( - PromptMessageRole::User, - "Hello! I'd like to start a conversation.", - ), - PromptMessage::new_text( - PromptMessageRole::Assistant, - "Hello! I'm here to help. What would you like to discuss today?", - ), - ]) -} - -#[tokio::test] -async fn test_prompt_macro_basic() { - // Test that the prompt attribute functions are generated - let greeting = greeting_prompt_prompt_attr(); - assert_eq!(greeting.name, "greeting_prompt"); - assert_eq!( - greeting.description.as_deref(), - Some("A simple greeting prompt without arguments") - ); - assert!(greeting.arguments.is_none()); - - let code_review = code_review_prompt_prompt_attr(); - assert_eq!(code_review.name, "code_review"); - assert_eq!( - code_review.description.as_deref(), - Some("Reviews code for best practices and potential issues") - ); - assert!(code_review.arguments.is_some()); - - let debug_assistant = debug_assistant_prompt_prompt_attr(); - assert_eq!(debug_assistant.name, "debug_assistant"); - assert!(debug_assistant.arguments.is_some()); -} - -#[tokio::test] -async fn test_prompt_router() { - // Create prompt routes manually - let greeting_route = rmcp::handler::server::router::prompt::PromptRoute::new( - greeting_prompt_prompt_attr(), - greeting_prompt, - ); - let code_review_route = rmcp::handler::server::router::prompt::PromptRoute::new( - code_review_prompt_prompt_attr(), - code_review_prompt, - ); - let debug_assistant_route = rmcp::handler::server::router::prompt::PromptRoute::new( - debug_assistant_prompt_prompt_attr(), - debug_assistant_prompt, - ); - - let server = Router::new(TestPromptServer) - .with_prompt(greeting_route) - .with_prompt(code_review_route) - .with_prompt(debug_assistant_route); - - // Test list prompts - let prompts = server.prompt_router.list_all(); - assert_eq!(prompts.len(), 3); - - let prompt_names: Vec<_> = prompts.iter().map(|p| p.name.as_str()).collect(); - assert!(prompt_names.contains(&"greeting_prompt")); - assert!(prompt_names.contains(&"code_review")); - assert!(prompt_names.contains(&"debug_assistant")); -} - -#[tokio::test] -async fn test_prompt_arguments_schema() { - let code_review = code_review_prompt_prompt_attr(); - let args = code_review.arguments.unwrap(); - - // Should have two arguments: file_path (required) and focus_areas (optional) - assert_eq!(args.len(), 2); - - let file_path_arg = args.iter().find(|a| a.name == "file_path").unwrap(); - assert_eq!(file_path_arg.required, Some(true)); - assert_eq!( - file_path_arg.description.as_deref(), - Some("The file path to review") - ); - - let focus_areas_arg = args.iter().find(|a| a.name == "focus_areas").unwrap(); - assert_eq!(focus_areas_arg.required, Some(false)); - assert_eq!( - focus_areas_arg.description.as_deref(), - Some("Focus areas for the review") - ); -} - -#[tokio::test] -async fn test_prompt_route_creation() { - // Test that prompt routes can be created - let route = rmcp::handler::server::router::prompt::PromptRoute::new( - code_review_prompt_prompt_attr(), - code_review_prompt, - ); - - assert_eq!(route.name(), "code_review"); -} - -// Additional integration tests would require a full server setup -// These tests demonstrate the basic functionality of the prompt system diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index a9b62bb7..25ba96de 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -33,7 +33,7 @@ tracing-subscriber = { version = "0.3", features = [ futures = "0.3" rand = { version = "0.9", features = ["std"] } axum = { version = "0.8", features = ["macros"] } -schemars = { version = "1.0", optional = true } +schemars = { version = "1.0" } reqwest = { version = "0.12", features = ["json"] } chrono = "0.4" uuid = { version = "1.6", features = ["v4", "serde"] } @@ -42,6 +42,7 @@ askama = { version = "0.14" } tower-http = { version = "0.6", features = ["cors"] } hyper = { version = "1" } hyper-util = { version = "0", features = ["server"] } +tokio-util = { version = "0.7" } [dev-dependencies] tokio-stream = { version = "0.1" } @@ -86,3 +87,7 @@ path = "src/counter_hyper_streamable_http.rs" [[example]] name = "servers_sampling_stdio" path = "src/sampling_stdio.rs" + +[[example]] +name = "servers_prompt_with_extractors" +path = "src/prompt_with_extractors.rs" diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index 4acacb8b..3bf2c534 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -3,9 +3,13 @@ use std::sync::Arc; use rmcp::{ ErrorData as McpError, RoleServer, ServerHandler, - handler::server::{router::tool::ToolRouter, tool::Parameters}, + handler::server::{ + prompt::Parameters as PromptParameters, + router::{prompt::PromptRouter, tool::ToolRouter}, + tool::Parameters as ToolParameters, + }, model::*, - schemars, + prompt, prompt_handler, prompt_router, schemars, service::RequestContext, tool, tool_handler, tool_router, }; @@ -18,10 +22,26 @@ pub struct StructRequest { pub b: i32, } +#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)] +pub struct ExamplePromptArgs { + /// A message to put in the prompt + pub message: String, +} + +#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)] +pub struct CounterAnalysisArgs { + /// The target value you're trying to reach + pub goal: i32, + /// Preferred strategy: 'fast' or 'careful' + #[serde(skip_serializing_if = "Option::is_none")] + pub strategy: Option, +} + #[derive(Clone)] pub struct Counter { counter: Arc>, tool_router: ToolRouter, + prompt_router: PromptRouter, } #[tool_router] @@ -31,6 +51,7 @@ impl Counter { Self { counter: Arc::new(Mutex::new(0)), tool_router: Self::tool_router(), + prompt_router: Self::prompt_router(), } } @@ -70,7 +91,10 @@ impl Counter { } #[tool(description = "Repeat what you say")] - fn echo(&self, Parameters(object): Parameters) -> Result { + fn echo( + &self, + ToolParameters(object): ToolParameters, + ) -> Result { Ok(CallToolResult::success(vec![Content::text( serde_json::Value::Object(object).to_string(), )])) @@ -79,14 +103,70 @@ impl Counter { #[tool(description = "Calculate the sum of two numbers")] fn sum( &self, - Parameters(StructRequest { a, b }): Parameters, + ToolParameters(StructRequest { a, b }): ToolParameters, ) -> Result { Ok(CallToolResult::success(vec![Content::text( (a + b).to_string(), )])) } } + +#[prompt_router] +impl Counter { + /// This is an example prompt that takes one required argument, message + #[prompt(name = "example_prompt")] + async fn example_prompt( + &self, + PromptParameters(args): PromptParameters, + _ctx: RequestContext, + ) -> Result, McpError> { + let prompt = format!( + "This is an example prompt with your message here: '{}'", + args.message + ); + Ok(vec![PromptMessage { + role: PromptMessageRole::User, + content: PromptMessageContent::text(prompt), + }]) + } + + /// Analyze the current counter value and suggest next steps + #[prompt(name = "counter_analysis")] + async fn counter_analysis( + &self, + PromptParameters(args): PromptParameters, + _ctx: RequestContext, + ) -> Result { + let strategy = args.strategy.unwrap_or_else(|| "careful".to_string()); + let current_value = *self.counter.lock().await; + let difference = args.goal - current_value; + + let messages = vec![ + PromptMessage::new_text( + PromptMessageRole::Assistant, + "I'll analyze the counter situation and suggest the best approach.", + ), + PromptMessage::new_text( + PromptMessageRole::User, + format!( + "Current counter value: {}\nGoal value: {}\nDifference: {}\nStrategy preference: {}\n\nPlease analyze the situation and suggest the best approach to reach the goal.", + current_value, args.goal, difference, strategy + ), + ), + ]; + + Ok(GetPromptResult { + description: Some(format!( + "Counter analysis for reaching {} from {}", + args.goal, current_value + )), + messages, + }) + } +} + #[tool_handler] +#[prompt_handler] impl ServerHandler for Counter { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -97,7 +177,7 @@ impl ServerHandler for Counter { .enable_tools() .build(), server_info: Implementation::from_build_env(), - instructions: Some("This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string()), + instructions: Some("This server provides counter tools and prompts. Tools: increment, decrement, get_value, say_hello, echo, sum. Prompts: example_prompt (takes a message), counter_analysis (analyzes counter state with a goal).".to_string()), } } @@ -142,52 +222,6 @@ impl ServerHandler for Counter { } } - async fn list_prompts( - &self, - _request: Option, - _: RequestContext, - ) -> Result { - Ok(ListPromptsResult { - next_cursor: None, - prompts: vec![Prompt::new( - "example_prompt", - Some("This is an example prompt that takes one required argument, message"), - Some(vec![PromptArgument { - name: "message".to_string(), - description: Some("A message to put in the prompt".to_string()), - required: Some(true), - }]), - )], - }) - } - - async fn get_prompt( - &self, - GetPromptRequestParam { name, arguments }: GetPromptRequestParam, - _: RequestContext, - ) -> Result { - match name.as_str() { - "example_prompt" => { - let message = arguments - .and_then(|json| json.get("message")?.as_str().map(|s| s.to_string())) - .ok_or_else(|| { - McpError::invalid_params("No message provided to example_prompt", None) - })?; - - let prompt = - format!("This is an example prompt with your message here: '{message}'"); - Ok(GetPromptResult { - description: None, - messages: vec![PromptMessage { - role: PromptMessageRole::User, - content: PromptMessageContent::text(prompt), - }], - }) - } - _ => Err(McpError::invalid_params("prompt not found", None)), - } - } - async fn list_resource_templates( &self, _request: Option, @@ -212,3 +246,76 @@ impl ServerHandler for Counter { Ok(self.get_info()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_prompt_attributes_generated() { + // Verify that the prompt macros generate the expected attributes + let example_attr = Counter::example_prompt_prompt_attr(); + assert_eq!(example_attr.name, "example_prompt"); + assert!(example_attr.description.is_some()); + assert!(example_attr.arguments.is_some()); + + let args = example_attr.arguments.unwrap(); + assert_eq!(args.len(), 1); + assert_eq!(args[0].name, "message"); + assert_eq!(args[0].required, Some(true)); + + let analysis_attr = Counter::counter_analysis_prompt_attr(); + assert_eq!(analysis_attr.name, "counter_analysis"); + assert!(analysis_attr.description.is_some()); + assert!(analysis_attr.arguments.is_some()); + + let args = analysis_attr.arguments.unwrap(); + assert_eq!(args.len(), 2); + assert_eq!(args[0].name, "goal"); + assert_eq!(args[0].required, Some(true)); + assert_eq!(args[1].name, "strategy"); + assert_eq!(args[1].required, Some(false)); + } + + #[tokio::test] + async fn test_prompt_router_has_routes() { + let router = Counter::prompt_router(); + assert!(router.has_route("example_prompt")); + assert!(router.has_route("counter_analysis")); + + let prompts = router.list_all(); + assert_eq!(prompts.len(), 2); + } + + #[tokio::test] + async fn test_example_prompt_execution() { + let counter = Counter::new(); + let context = rmcp::handler::server::prompt::PromptContext::new( + &counter, + "example_prompt".to_string(), + Some({ + let mut map = serde_json::Map::new(); + map.insert( + "message".to_string(), + serde_json::Value::String("Test message".to_string()), + ); + map + }), + RequestContext { + meta: Default::default(), + ct: tokio_util::sync::CancellationToken::new(), + id: rmcp::model::NumberOrString::String("test-1".to_string()), + peer: Default::default(), + extensions: Default::default(), + }, + ); + + let router = Counter::prompt_router(); + let result = router.get_prompt(context).await; + assert!(result.is_ok()); + + let prompt_result = result.unwrap(); + assert_eq!(prompt_result.messages.len(), 1); + assert_eq!(prompt_result.messages[0].role, PromptMessageRole::User); + } +} diff --git a/examples/servers/src/prompt_stdio.rs b/examples/servers/src/prompt_stdio.rs deleted file mode 100644 index eab5fcae..00000000 --- a/examples/servers/src/prompt_stdio.rs +++ /dev/null @@ -1,145 +0,0 @@ -use anyhow::Result; -use rmcp::{ - Error as McpError, RoleServer, ServerHandler, ServiceExt, - handler::server::prompt::arguments_from_schema, model::*, schemars, service::RequestContext, - transport::stdio, -}; -use serde::{Deserialize, Serialize}; -use tracing_subscriber::EnvFilter; - -#[derive(Debug, Serialize, Deserialize, schemars::JsonSchema)] -struct CodeReviewArgs { - /// The file path to review - file_path: String, - /// Language for syntax highlighting - #[serde(default = "default_language")] - language: String, -} - -fn default_language() -> String { - "rust".to_string() -} - -#[derive(Clone, Debug, Default)] -struct PromptExampleServer; - -impl ServerHandler for PromptExampleServer { - fn get_info(&self) -> ServerInfo { - ServerInfo { - server_info: Implementation { - name: "Prompt Example Server".to_string(), - version: "1.0.0".to_string(), - }, - instructions: Some( - concat!( - "This server demonstrates the prompt framework capabilities. ", - "It provides code review and debugging prompts." - ) - .to_string(), - ), - capabilities: ServerCapabilities::builder().enable_prompts().build(), - ..Default::default() - } - } - - async fn list_prompts( - &self, - _request: Option, - _: RequestContext, - ) -> Result { - Ok(ListPromptsResult { - next_cursor: None, - prompts: vec![ - Prompt { - name: "code_review".to_string(), - description: Some( - "Reviews code for best practices and potential issues".to_string(), - ), - arguments: arguments_from_schema::(), - }, - Prompt { - name: "debug_helper".to_string(), - description: Some("Interactive debugging assistant".to_string()), - arguments: None, - }, - ], - }) - } - - async fn get_prompt( - &self, - GetPromptRequestParam { name, arguments }: GetPromptRequestParam, - _: RequestContext, - ) -> Result { - match name.as_str() { - "code_review" => { - // Parse arguments - let args = if let Some(args_map) = arguments { - serde_json::from_value::(serde_json::Value::Object(args_map)) - .map_err(|e| { - McpError::invalid_params(format!("Invalid arguments: {}", e), None) - })? - } else { - return Err(McpError::invalid_params("Missing required arguments", None)); - }; - - Ok(GetPromptResult { - description: None, - messages: vec![ - PromptMessage::new_text( - PromptMessageRole::User, - format!( - "Please review the {} code in file: {}", - args.language, args.file_path - ), - ), - PromptMessage::new_text( - PromptMessageRole::Assistant, - "I'll analyze this code for best practices, potential bugs, and improvements.", - ), - ], - }) - } - "debug_helper" => Ok(GetPromptResult { - description: Some("Interactive debugging assistant".to_string()), - messages: vec![ - PromptMessage::new_text( - PromptMessageRole::Assistant, - "You are a helpful debugging assistant. Ask the user about their error and help them solve it.", - ), - PromptMessage::new_text( - PromptMessageRole::User, - "I need help debugging an issue in my code.", - ), - PromptMessage::new_text( - PromptMessageRole::Assistant, - "I'd be happy to help you debug your code! Please tell me:\n1. What error or issue are you experiencing?\n2. What programming language are you using?\n3. What were you trying to accomplish?", - ), - ], - }), - _ => Err(McpError::invalid_params( - format!("Unknown prompt: {}", name), - Some(serde_json::json!({ - "available_prompts": ["code_review", "debug_helper"] - })), - )), - } - } -} - -#[tokio::main] -async fn main() -> Result<()> { - // Initialize the tracing subscriber - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) - .with_writer(std::io::stderr) - .init(); - - tracing::info!("Starting Prompt Example MCP server"); - - // Create and serve the prompt server - let service = PromptExampleServer.serve(stdio()).await?; - - service.waiting().await?; - Ok(()) -}