Skip to content

feat: Add prompt support with typed argument handling #314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions crates/rmcp-macros/src/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
//! Common utilities shared between different macro implementations

use quote::quote;
use syn::{Attribute, Expr, FnArg, ImplItemFn, Signature, Type};

/// Parse a None expression
pub fn none_expr() -> syn::Result<Expr> {
syn::parse2::<Expr>(quote! { None })
}

/// Extract documentation from doc attributes
pub fn extract_doc_line(existing_docs: Option<String>, attr: &Attribute) -> Option<String> {
if !attr.path().is_ident("doc") {
return None;
}

let syn::Meta::NameValue(name_value) = &attr.meta else {
return None;
};

let syn::Expr::Lit(expr_lit) = &name_value.value else {
return None;
};

let syn::Lit::Str(lit_str) = &expr_lit.lit else {
return None;
};

let content = lit_str.value().trim().to_string();
match (existing_docs, content) {
(Some(mut existing_docs), content) if !content.is_empty() => {
existing_docs.push('\n');
existing_docs.push_str(&content);
Some(existing_docs)
}
(Some(existing_docs), _) => Some(existing_docs),
(None, content) if !content.is_empty() => Some(content),
_ => None,
}
}

/// Find Parameters<T> type in function signature
/// Returns the full Parameters<T> type if found
pub fn find_parameters_type_in_sig(sig: &Signature) -> Option<Box<Type>> {
sig.inputs.iter().find_map(|input| {
if let FnArg::Typed(pat_type) = input {
if let Type::Path(type_path) = &*pat_type.ty {
if type_path
.path
.segments
.last()
.is_some_and(|type_name| type_name.ident == "Parameters")
{
return Some(pat_type.ty.clone());
}
}
}
None
})
}

/// Find Parameters<T> type in ImplItemFn
pub fn find_parameters_type_impl(fn_item: &ImplItemFn) -> Option<Box<Type>> {
find_parameters_type_in_sig(&fn_item.sig)
}
103 changes: 103 additions & 0 deletions crates/rmcp-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#[allow(unused_imports)]
use proc_macro::TokenStream;

mod common;
mod prompt;
mod prompt_handler;
mod prompt_router;
mod tool;
mod tool_handler;
mod tool_router;
Expand Down Expand Up @@ -160,3 +164,102 @@ pub fn tool_handler(attr: TokenStream, input: TokenStream) -> TokenStream {
.unwrap_or_else(|err| err.to_compile_error())
.into()
}

/// # prompt
///
/// This macro is used to mark a function as a prompt handler.
///
/// This will generate a function that returns the attribute of this prompt, with type `rmcp::model::Prompt`.
///
/// ## Usage
///
/// | field | type | usage |
/// | :- | :- | :- |
/// | `name` | `String` | The name of the prompt. If not provided, it defaults to the function name. |
/// | `description` | `String` | A description of the prompt. The document of this function will be used if not provided. |
/// | `arguments` | `Expr` | An expression that evaluates to `Option<Vec<PromptArgument>>` defining the prompt's arguments. If not provided, it will automatically generate arguments from the `Parameters<T>` type found in the function signature. |
///
/// ## Example
///
/// ```rust,ignore
/// #[prompt(name = "code_review", description = "Reviews code for best practices")]
/// pub async fn code_review_prompt(&self, Parameters(args): Parameters<CodeReviewArgs>) -> Result<Vec<PromptMessage>> {
/// // Generate prompt messages based on arguments
/// }
/// ```
#[proc_macro_attribute]
pub fn prompt(attr: TokenStream, input: TokenStream) -> TokenStream {
prompt::prompt(attr.into(), input.into())
.unwrap_or_else(|err| err.to_compile_error())
.into()
}

/// # prompt_router
///
/// This macro generates a prompt router based on functions marked with `#[rmcp::prompt]` in an implementation block.
///
/// It creates a function that returns a `PromptRouter` instance.
///
/// ## Usage
///
/// | field | type | usage |
/// | :- | :- | :- |
/// | `router` | `Ident` | The name of the router function to be generated. Defaults to `prompt_router`. |
/// | `vis` | `Visibility` | The visibility of the generated router function. Defaults to empty. |
///
/// ## Example
///
/// ```rust,ignore
/// #[prompt_router]
/// impl MyPromptHandler {
/// #[prompt]
/// pub async fn greeting_prompt(&self, Parameters(args): Parameters<GreetingArgs>) -> Result<Vec<PromptMessage>, Error> {
/// // Generate greeting prompt using args
/// }
///
/// pub fn new() -> Self {
/// Self {
/// // the default name of prompt router will be `prompt_router`
/// prompt_router: Self::prompt_router(),
/// }
/// }
/// }
/// ```
#[proc_macro_attribute]
pub fn prompt_router(attr: TokenStream, input: TokenStream) -> TokenStream {
prompt_router::prompt_router(attr.into(), input.into())
.unwrap_or_else(|err| err.to_compile_error())
.into()
}

/// # prompt_handler
///
/// This macro generates handler methods for `get_prompt` and `list_prompts` in the implementation block, using an existing `PromptRouter` instance.
///
/// ## Usage
///
/// | field | type | usage |
/// | :- | :- | :- |
/// | `router` | `Expr` | The expression to access the `PromptRouter` instance. Defaults to `self.prompt_router`. |
///
/// ## Example
/// ```rust,ignore
/// #[prompt_handler]
/// impl ServerHandler for MyPromptHandler {
/// // ...implement other handler methods
/// }
/// ```
///
/// or using a custom router expression:
/// ```rust,ignore
/// #[prompt_handler(router = self.get_prompt_router())]
/// impl ServerHandler for MyPromptHandler {
/// // ...implement other handler methods
/// }
/// ```
#[proc_macro_attribute]
pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> TokenStream {
prompt_handler::prompt_handler(attr.into(), input.into())
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
181 changes: 181 additions & 0 deletions crates/rmcp-macros/src/prompt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
use darling::{FromMeta, ast::NestedMeta};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{Expr, Ident, ImplItemFn, ReturnType};

use crate::common::{extract_doc_line, none_expr};

#[derive(FromMeta, Default, Debug)]
#[darling(default)]
pub struct PromptAttribute {
/// The name of the prompt
pub name: Option<String>,
/// Optional description of what the prompt does
pub description: Option<String>,
/// Arguments that can be passed to the prompt
pub arguments: Option<Expr>,
}

pub struct ResolvedPromptAttribute {
pub name: String,
pub description: Option<String>,
pub arguments: Expr,
}

impl ResolvedPromptAttribute {
pub fn into_fn(self, fn_ident: Ident) -> syn::Result<ImplItemFn> {
let Self {
name,
description,
arguments,
} = self;
let description = if let Some(description) = description {
quote! { Some(#description.into()) }
} else {
quote! { None }
};
let tokens = quote! {
pub fn #fn_ident() -> rmcp::model::Prompt {
rmcp::model::Prompt {
name: #name.into(),
description: #description,
arguments: #arguments,
}
}
};
syn::parse2::<ImplItemFn>(tokens)
}
}

pub fn prompt(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
let attribute = if attr.is_empty() {
Default::default()
} else {
let attr_args = NestedMeta::parse_meta_list(attr)?;
PromptAttribute::from_list(&attr_args)?
};
let mut fn_item = syn::parse2::<ImplItemFn>(input.clone())?;
let fn_ident = &fn_item.sig.ident;

let prompt_attr_fn_ident = format_ident!("{}_prompt_attr", fn_ident);

// Try to find prompt parameters from function parameters
let arguments_expr = if let Some(arguments) = attribute.arguments {
arguments
} else {
// Look for a type named Parameters in the function signature
let params_ty = crate::common::find_parameters_type_impl(&fn_item);

if let Some(params_ty) = params_ty {
// Generate arguments from the type's schema with caching
syn::parse2::<Expr>(quote! {
rmcp::handler::server::prompt::cached_arguments_from_schema::<#params_ty>()
})?
} else {
// No arguments
none_expr()?
}
};

let name = attribute.name.unwrap_or_else(|| fn_ident.to_string());
let description = attribute
.description
.or_else(|| fn_item.attrs.iter().fold(None, extract_doc_line));
let arguments = arguments_expr;

let resolved_prompt_attr = ResolvedPromptAttribute {
name: name.clone(),
description: description.clone(),
arguments: arguments.clone(),
};
let prompt_attr_fn = resolved_prompt_attr.into_fn(prompt_attr_fn_ident.clone())?;

// Modify the input function for async support (same as tool macro)
if fn_item.sig.asyncness.is_some() {
// 1. remove asyncness from sig
// 2. make return type: `futures::future::BoxFuture<'_, #ReturnType>`
// 3. make body: { Box::pin(async move { #body }) }
let new_output = syn::parse2::<ReturnType>({
let mut lt = quote! { 'static };
if let Some(receiver) = fn_item.sig.receiver() {
if let Some((_, receiver_lt)) = receiver.reference.as_ref() {
if let Some(receiver_lt) = receiver_lt {
lt = quote! { #receiver_lt };
} else {
lt = quote! { '_ };
}
}
}
match &fn_item.sig.output {
syn::ReturnType::Default => {
quote! { -> futures::future::BoxFuture<#lt, ()> }
}
syn::ReturnType::Type(_, ty) => {
quote! { -> futures::future::BoxFuture<#lt, #ty> }
}
}
})?;
let prev_block = &fn_item.block;
let new_block = syn::parse2::<syn::Block>(quote! {
{ Box::pin(async move #prev_block ) }
})?;
fn_item.sig.asyncness = None;
fn_item.sig.output = new_output;
fn_item.block = new_block;
}

Ok(quote! {
#prompt_attr_fn
#fn_item
})
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_prompt_macro() -> syn::Result<()> {
let attr = quote! {
name = "example-prompt",
description = "An example prompt"
};
let input = quote! {
async fn example_prompt(&self, Parameters(args): Parameters<ExampleArgs>) -> Result<String> {
Ok("Example prompt response".to_string())
}
};
let result = prompt(attr, input)?;

// Verify the output contains both the attribute function and the modified function
let result_str = result.to_string();
assert!(result_str.contains("example_prompt_prompt_attr"));
assert!(
result_str.contains("rmcp")
&& result_str.contains("model")
&& result_str.contains("Prompt")
);

Ok(())
}

#[test]
fn test_doc_comment_description() -> syn::Result<()> {
let attr = quote! {}; // No explicit description
let input = quote! {
/// This is a test prompt description
/// with multiple lines
fn test_prompt(&self) -> Result<String> {
Ok("Test".to_string())
}
};
let result = prompt(attr, input)?;

// The output should contain the description from doc comments
let result_str = result.to_string();
assert!(result_str.contains("This is a test prompt description"));
assert!(result_str.contains("with multiple lines"));

Ok(())
}
}
Loading