diff --git a/engine/function-macros/src/lib.rs b/engine/function-macros/src/lib.rs
index a346f0867..7962d4ad8 100644
--- a/engine/function-macros/src/lib.rs
+++ b/engine/function-macros/src/lib.rs
@@ -6,7 +6,7 @@
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
-use quote::{format_ident, quote};
+use quote::{ToTokens, format_ident, quote};
use syn::{ItemFn, Meta, Token, parse::Parser, parse_macro_input, punctuated::Punctuated};
type AttrArgs = Punctuated;
@@ -160,6 +160,73 @@ fn extract_success_type(return_type: &syn::Type) -> Option {
Some(quote! { #success_ty })
}
+fn type_contains_ident(ty: &syn::Type, name: &str) -> bool {
+ match ty {
+ syn::Type::Path(type_path) => type_path.path.segments.iter().any(|seg| {
+ if seg.ident == name {
+ return true;
+ }
+ if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
+ args.args.iter().any(|arg| {
+ if let syn::GenericArgument::Type(inner_ty) = arg {
+ type_contains_ident(inner_ty, name)
+ } else {
+ false
+ }
+ })
+ } else {
+ false
+ }
+ }),
+ _ => false,
+ }
+}
+
+/// Checks whether `ty` is exactly `Option>`, tolerating
+/// fully-qualified paths such as `std::sync::Arc` or `::std::sync::Arc`.
+fn is_option_arc_session(ty: &syn::Type) -> bool {
+ let type_path = match ty {
+ syn::Type::Path(tp) => tp,
+ _ => return false,
+ };
+
+ let option_seg = match type_path.path.segments.last() {
+ Some(seg) if seg.ident == "Option" => seg,
+ _ => return false,
+ };
+
+ let option_args = match &option_seg.arguments {
+ syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => args,
+ _ => return false,
+ };
+
+ let arc_ty = match option_args.args.first() {
+ Some(syn::GenericArgument::Type(syn::Type::Path(tp))) => tp,
+ _ => return false,
+ };
+
+ let arc_seg = match arc_ty.path.segments.last() {
+ Some(seg) if seg.ident == "Arc" => seg,
+ _ => return false,
+ };
+
+ let arc_args = match &arc_seg.arguments {
+ syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => args,
+ _ => return false,
+ };
+
+ let session_ty = match arc_args.args.first() {
+ Some(syn::GenericArgument::Type(syn::Type::Path(tp))) => tp,
+ _ => return false,
+ };
+
+ session_ty
+ .path
+ .segments
+ .last()
+ .map_or(false, |seg| seg.ident == "Session")
+}
+
#[proc_macro_attribute]
pub fn function(_attr: TokenStream, item: TokenStream) -> TokenStream {
let func = parse_macro_input!(item as ItemFn);
@@ -197,12 +264,14 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
let description = extract_optional(&metas, "description");
let method_ident = method.sig.ident.clone();
- let input_type = match method
+ let non_self_params: Vec<_> = method
.sig
.inputs
.iter()
- .find(|arg| !matches!(arg, syn::FnArg::Receiver(_)))
- {
+ .filter(|arg| !matches!(arg, syn::FnArg::Receiver(_)))
+ .collect();
+
+ let input_type = match non_self_params.first() {
Some(syn::FnArg::Typed(pat_type)) => {
let ty = &*pat_type.ty;
quote! { #ty }
@@ -210,6 +279,24 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
_ => quote! { () },
};
+ let has_session_param =
+ if let Some(syn::FnArg::Typed(pat_type)) = non_self_params.get(1) {
+ if is_option_arc_session(&pat_type.ty) {
+ true
+ } else if type_contains_ident(&pat_type.ty, "Session") {
+ let actual = pat_type.ty.to_token_stream().to_string();
+ panic!(
+ "Session parameter on `{}` must be typed as \
+ `Option>`, found: `{}`",
+ method_ident, actual
+ );
+ } else {
+ false
+ }
+ } else {
+ false
+ };
+
// Extract return type
let return_type = match &method.sig.output {
syn::ReturnType::Type(_, ty) => &**ty,
@@ -224,9 +311,15 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
let handler_ident = format_ident!("{}_handler", method_ident);
let description = quote!(Some(#description.into()));
+ let method_call = if has_session_param {
+ quote! { this.#method_ident(input, session).await }
+ } else {
+ quote! { this.#method_ident(input).await }
+ };
+
let result_handling = if needs_serialization {
quote! {
- let result = this.#method_ident(input).await;
+ let result = #method_call;
match result {
FunctionResult::Success(value) => {
match serde_json::to_value(&value) {
@@ -251,9 +344,7 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
}
}
} else {
- quote! {
- this.#method_ident(input).await
- }
+ method_call
};
// Generate request_format schema
@@ -284,8 +375,47 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
None => quote! { None },
};
- generated.push(quote! {
- {
+ let handler_and_registration = if has_session_param {
+ quote! {
+ let this = self.clone();
+ let #handler_ident = SessionHandler::new(move |input: Value, session: Option<::std::sync::Arc>| {
+ let this = this.clone();
+
+ async move {
+ let parsed: Result<#input_type, _> = serde_json::from_value(input);
+ let input = match parsed {
+ Ok(v) => v,
+ Err(err) => {
+ eprintln!(
+ "[warning] Failed to deserialize input for {}: {}",
+ #id,
+ err
+ );
+ return FunctionResult::Failure(ErrorBody {
+ code: "deserialization_error".into(),
+ message: format!("Failed to deserialize input for {}: {}", #id, err.to_string()),
+ stacktrace: None,
+ });
+ }
+ };
+
+ #result_handling
+ }
+ });
+
+ engine.register_function_handler_with_session(
+ RegisterFunctionRequest {
+ function_id: #id.into(),
+ description: #description,
+ request_format: #request_format_expr,
+ response_format: #response_format_expr,
+ metadata: None,
+ },
+ #handler_ident,
+ );
+ }
+ } else {
+ quote! {
let this = self.clone();
let #handler_ident = Handler::new(move |input: Value| {
let this = this.clone();
@@ -323,6 +453,12 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
#handler_ident,
);
}
+ };
+
+ generated.push(quote! {
+ {
+ #handler_and_registration
+ }
});
}
Err(e) => panic!("failed to parse attributes: {}", e),
diff --git a/engine/src/condition.rs b/engine/src/condition.rs
index 7cac7b59f..8c5f27305 100644
--- a/engine/src/condition.rs
+++ b/engine/src/condition.rs
@@ -75,11 +75,17 @@ mod tests {
_req: crate::engine::RegisterFunctionRequest,
_handler: crate::engine::Handler,
) where
- H: Fn(Value) -> F + Send + Sync + 'static,
- F: std::future::Future<
- Output = crate::function::FunctionResult