Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 146 additions & 10 deletions engine/function-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use quote::{ToTokens, format_ident, quote};
use syn::{ItemFn, Meta, Token, parse::Parser, parse_macro_input, punctuated::Punctuated};

type AttrArgs = Punctuated<Meta, Token![,]>;
Expand Down Expand Up @@ -160,6 +160,73 @@ fn extract_success_type(return_type: &syn::Type) -> Option<TokenStream2> {
Some(quote! { #success_ty })
}

fn type_contains_ident(ty: &syn::Type, name: &str) -> bool {
match ty {
syn::Type::Path(type_path) => type_path.path.segments.iter().any(|seg| {
if seg.ident == name {
return true;
}
if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
args.args.iter().any(|arg| {
if let syn::GenericArgument::Type(inner_ty) = arg {
type_contains_ident(inner_ty, name)
} else {
false
}
})
} else {
false
}
}),
_ => false,
}
}

/// Checks whether `ty` is exactly `Option<Arc<Session>>`, tolerating
/// fully-qualified paths such as `std::sync::Arc` or `::std::sync::Arc`.
fn is_option_arc_session(ty: &syn::Type) -> bool {
let type_path = match ty {
syn::Type::Path(tp) => tp,
_ => return false,
};

let option_seg = match type_path.path.segments.last() {
Some(seg) if seg.ident == "Option" => seg,
_ => return false,
};

let option_args = match &option_seg.arguments {
syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => args,
_ => return false,
};

let arc_ty = match option_args.args.first() {
Some(syn::GenericArgument::Type(syn::Type::Path(tp))) => tp,
_ => return false,
};

let arc_seg = match arc_ty.path.segments.last() {
Some(seg) if seg.ident == "Arc" => seg,
_ => return false,
};

let arc_args = match &arc_seg.arguments {
syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => args,
_ => return false,
};

let session_ty = match arc_args.args.first() {
Some(syn::GenericArgument::Type(syn::Type::Path(tp))) => tp,
_ => return false,
};

session_ty
.path
.segments
.last()
.map_or(false, |seg| seg.ident == "Session")
}

#[proc_macro_attribute]
pub fn function(_attr: TokenStream, item: TokenStream) -> TokenStream {
let func = parse_macro_input!(item as ItemFn);
Expand Down Expand Up @@ -197,19 +264,39 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
let description = extract_optional(&metas, "description");
let method_ident = method.sig.ident.clone();

let input_type = match method
let non_self_params: Vec<_> = method
.sig
.inputs
.iter()
.find(|arg| !matches!(arg, syn::FnArg::Receiver(_)))
{
.filter(|arg| !matches!(arg, syn::FnArg::Receiver(_)))
.collect();

let input_type = match non_self_params.first() {
Some(syn::FnArg::Typed(pat_type)) => {
let ty = &*pat_type.ty;
quote! { #ty }
}
_ => quote! { () },
};

let has_session_param =
if let Some(syn::FnArg::Typed(pat_type)) = non_self_params.get(1) {
if is_option_arc_session(&pat_type.ty) {
true
} else if type_contains_ident(&pat_type.ty, "Session") {
let actual = pat_type.ty.to_token_stream().to_string();
panic!(
"Session parameter on `{}` must be typed as \
`Option<Arc<Session>>`, found: `{}`",
method_ident, actual
);
} else {
false
}
} else {
false
};

// Extract return type
let return_type = match &method.sig.output {
syn::ReturnType::Type(_, ty) => &**ty,
Expand All @@ -224,9 +311,15 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
let handler_ident = format_ident!("{}_handler", method_ident);
let description = quote!(Some(#description.into()));

let method_call = if has_session_param {
quote! { this.#method_ident(input, session).await }
} else {
quote! { this.#method_ident(input).await }
};

let result_handling = if needs_serialization {
quote! {
let result = this.#method_ident(input).await;
let result = #method_call;
match result {
FunctionResult::Success(value) => {
match serde_json::to_value(&value) {
Expand All @@ -251,9 +344,7 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
}
}
} else {
quote! {
this.#method_ident(input).await
}
method_call
};

// Generate request_format schema
Expand Down Expand Up @@ -284,8 +375,47 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
None => quote! { None },
};

generated.push(quote! {
{
let handler_and_registration = if has_session_param {
quote! {
let this = self.clone();
let #handler_ident = SessionHandler::new(move |input: Value, session: Option<::std::sync::Arc<Session>>| {
let this = this.clone();

async move {
let parsed: Result<#input_type, _> = serde_json::from_value(input);
let input = match parsed {
Ok(v) => v,
Err(err) => {
eprintln!(
"[warning] Failed to deserialize input for {}: {}",
#id,
err
);
return FunctionResult::Failure(ErrorBody {
code: "deserialization_error".into(),
message: format!("Failed to deserialize input for {}: {}", #id, err.to_string()),
stacktrace: None,
});
}
};

#result_handling
}
});

engine.register_function_handler_with_session(
RegisterFunctionRequest {
function_id: #id.into(),
description: #description,
request_format: #request_format_expr,
response_format: #response_format_expr,
metadata: None,
},
#handler_ident,
);
}
} else {
quote! {
let this = self.clone();
let #handler_ident = Handler::new(move |input: Value| {
let this = this.clone();
Expand Down Expand Up @@ -323,6 +453,12 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
#handler_ident,
);
}
};

generated.push(quote! {
{
#handler_and_registration
}
});
}
Err(e) => panic!("failed to parse attributes: {}", e),
Expand Down
16 changes: 11 additions & 5 deletions engine/src/condition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,17 @@ mod tests {
_req: crate::engine::RegisterFunctionRequest,
_handler: crate::engine::Handler<H>,
) where
H: Fn(Value) -> F + Send + Sync + 'static,
F: std::future::Future<
Output = crate::function::FunctionResult<Option<Value>, ErrorBody>,
> + Send
+ 'static,
H: crate::engine::HandlerFn<F>,
F: std::future::Future<Output = crate::engine::HandlerOutput> + Send + 'static,
{
}
fn register_function_handler_with_session<H, F>(
&self,
_req: crate::engine::RegisterFunctionRequest,
_handler: crate::engine::SessionHandler<H>,
) where
H: crate::engine::SessionHandlerFn<F>,
F: std::future::Future<Output = crate::engine::HandlerOutput> + Send + 'static,
{
}
}
Expand Down
Loading
Loading