diff --git a/engine/function-macros/src/lib.rs b/engine/function-macros/src/lib.rs index a346f0867..7962d4ad8 100644 --- a/engine/function-macros/src/lib.rs +++ b/engine/function-macros/src/lib.rs @@ -6,7 +6,7 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; -use quote::{format_ident, quote}; +use quote::{ToTokens, format_ident, quote}; use syn::{ItemFn, Meta, Token, parse::Parser, parse_macro_input, punctuated::Punctuated}; type AttrArgs = Punctuated; @@ -160,6 +160,73 @@ fn extract_success_type(return_type: &syn::Type) -> Option { Some(quote! { #success_ty }) } +fn type_contains_ident(ty: &syn::Type, name: &str) -> bool { + match ty { + syn::Type::Path(type_path) => type_path.path.segments.iter().any(|seg| { + if seg.ident == name { + return true; + } + if let syn::PathArguments::AngleBracketed(args) = &seg.arguments { + args.args.iter().any(|arg| { + if let syn::GenericArgument::Type(inner_ty) = arg { + type_contains_ident(inner_ty, name) + } else { + false + } + }) + } else { + false + } + }), + _ => false, + } +} + +/// Checks whether `ty` is exactly `Option>`, tolerating +/// fully-qualified paths such as `std::sync::Arc` or `::std::sync::Arc`. +fn is_option_arc_session(ty: &syn::Type) -> bool { + let type_path = match ty { + syn::Type::Path(tp) => tp, + _ => return false, + }; + + let option_seg = match type_path.path.segments.last() { + Some(seg) if seg.ident == "Option" => seg, + _ => return false, + }; + + let option_args = match &option_seg.arguments { + syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => args, + _ => return false, + }; + + let arc_ty = match option_args.args.first() { + Some(syn::GenericArgument::Type(syn::Type::Path(tp))) => tp, + _ => return false, + }; + + let arc_seg = match arc_ty.path.segments.last() { + Some(seg) if seg.ident == "Arc" => seg, + _ => return false, + }; + + let arc_args = match &arc_seg.arguments { + syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => args, + _ => return false, + }; + + let session_ty = match arc_args.args.first() { + Some(syn::GenericArgument::Type(syn::Type::Path(tp))) => tp, + _ => return false, + }; + + session_ty + .path + .segments + .last() + .map_or(false, |seg| seg.ident == "Session") +} + #[proc_macro_attribute] pub fn function(_attr: TokenStream, item: TokenStream) -> TokenStream { let func = parse_macro_input!(item as ItemFn); @@ -197,12 +264,14 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream { let description = extract_optional(&metas, "description"); let method_ident = method.sig.ident.clone(); - let input_type = match method + let non_self_params: Vec<_> = method .sig .inputs .iter() - .find(|arg| !matches!(arg, syn::FnArg::Receiver(_))) - { + .filter(|arg| !matches!(arg, syn::FnArg::Receiver(_))) + .collect(); + + let input_type = match non_self_params.first() { Some(syn::FnArg::Typed(pat_type)) => { let ty = &*pat_type.ty; quote! { #ty } @@ -210,6 +279,24 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream { _ => quote! { () }, }; + let has_session_param = + if let Some(syn::FnArg::Typed(pat_type)) = non_self_params.get(1) { + if is_option_arc_session(&pat_type.ty) { + true + } else if type_contains_ident(&pat_type.ty, "Session") { + let actual = pat_type.ty.to_token_stream().to_string(); + panic!( + "Session parameter on `{}` must be typed as \ + `Option>`, found: `{}`", + method_ident, actual + ); + } else { + false + } + } else { + false + }; + // Extract return type let return_type = match &method.sig.output { syn::ReturnType::Type(_, ty) => &**ty, @@ -224,9 +311,15 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream { let handler_ident = format_ident!("{}_handler", method_ident); let description = quote!(Some(#description.into())); + let method_call = if has_session_param { + quote! { this.#method_ident(input, session).await } + } else { + quote! { this.#method_ident(input).await } + }; + let result_handling = if needs_serialization { quote! { - let result = this.#method_ident(input).await; + let result = #method_call; match result { FunctionResult::Success(value) => { match serde_json::to_value(&value) { @@ -251,9 +344,7 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream { } } } else { - quote! { - this.#method_ident(input).await - } + method_call }; // Generate request_format schema @@ -284,8 +375,47 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream { None => quote! { None }, }; - generated.push(quote! { - { + let handler_and_registration = if has_session_param { + quote! { + let this = self.clone(); + let #handler_ident = SessionHandler::new(move |input: Value, session: Option<::std::sync::Arc>| { + let this = this.clone(); + + async move { + let parsed: Result<#input_type, _> = serde_json::from_value(input); + let input = match parsed { + Ok(v) => v, + Err(err) => { + eprintln!( + "[warning] Failed to deserialize input for {}: {}", + #id, + err + ); + return FunctionResult::Failure(ErrorBody { + code: "deserialization_error".into(), + message: format!("Failed to deserialize input for {}: {}", #id, err.to_string()), + stacktrace: None, + }); + } + }; + + #result_handling + } + }); + + engine.register_function_handler_with_session( + RegisterFunctionRequest { + function_id: #id.into(), + description: #description, + request_format: #request_format_expr, + response_format: #response_format_expr, + metadata: None, + }, + #handler_ident, + ); + } + } else { + quote! { let this = self.clone(); let #handler_ident = Handler::new(move |input: Value| { let this = this.clone(); @@ -323,6 +453,12 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream { #handler_ident, ); } + }; + + generated.push(quote! { + { + #handler_and_registration + } }); } Err(e) => panic!("failed to parse attributes: {}", e), diff --git a/engine/src/condition.rs b/engine/src/condition.rs index 7cac7b59f..8c5f27305 100644 --- a/engine/src/condition.rs +++ b/engine/src/condition.rs @@ -75,11 +75,17 @@ mod tests { _req: crate::engine::RegisterFunctionRequest, _handler: crate::engine::Handler, ) where - H: Fn(Value) -> F + Send + Sync + 'static, - F: std::future::Future< - Output = crate::function::FunctionResult, ErrorBody>, - > + Send - + 'static, + H: crate::engine::HandlerFn, + F: std::future::Future + Send + 'static, + { + } + fn register_function_handler_with_session( + &self, + _req: crate::engine::RegisterFunctionRequest, + _handler: crate::engine::SessionHandler, + ) where + H: crate::engine::SessionHandlerFn, + F: std::future::Future + Send + 'static, { } } diff --git a/engine/src/engine/mod.rs b/engine/src/engine/mod.rs index f0a9a4083..9e385305a 100644 --- a/engine/src/engine/mod.rs +++ b/engine/src/engine/mod.rs @@ -22,6 +22,7 @@ use uuid::Uuid; use crate::{ function::{Function, FunctionHandler, FunctionResult, FunctionsRegistry}, invocation::{InvocationHandler, http_function::HttpFunctionConfig}, + modules::worker::rbac_session::Session, modules::{ engine_fn::TRIGGER_WORKERS_AVAILABLE, http_functions::HttpFunctionsModule, @@ -136,14 +137,40 @@ pub struct RegisterFunctionRequest { pub metadata: Option, } +pub type HandlerOutput = FunctionResult, ErrorBody>; + +pub trait HandlerFn + Send + 'static>: + Fn(Value) -> F + Send + Sync + 'static +{ +} + +impl HandlerFn for H +where + H: Fn(Value) -> F + Send + Sync + 'static, + F: Future + Send + 'static, +{ +} + +pub trait SessionHandlerFn + Send + 'static>: + Fn(Value, Option>) -> F + Send + Sync + 'static +{ +} + +impl SessionHandlerFn for H +where + H: Fn(Value, Option>) -> F + Send + Sync + 'static, + F: Future + Send + 'static, +{ +} + pub struct Handler { - f: H, + pub f: H, } impl Handler where H: Fn(Value) -> F + Send + Sync + 'static, - F: Future, ErrorBody>> + Send + 'static, + F: Future + Send + 'static, { pub fn new(f: H) -> Self { Self { f } @@ -154,6 +181,20 @@ where } } +pub struct SessionHandler { + pub f: H, +} + +impl SessionHandler +where + H: Fn(Value, Option>) -> F + Send + Sync + 'static, + F: Future + Send + 'static, +{ + pub fn new(f: H) -> Self { + Self { f } + } +} + #[allow(async_fn_in_trait)] pub trait EngineTrait: Send + Sync { async fn call( @@ -172,8 +213,15 @@ pub trait EngineTrait: Send + Sync { request: RegisterFunctionRequest, handler: Handler, ) where - H: Fn(Value) -> F + Send + Sync + 'static, - F: Future, ErrorBody>> + Send + 'static; + H: HandlerFn, + F: Future + Send + 'static; + fn register_function_handler_with_session( + &self, + request: RegisterFunctionRequest, + handler: SessionHandler, + ) where + H: SessionHandlerFn, + F: Future + Send + 'static; } #[derive(Clone)] @@ -247,6 +295,8 @@ impl Engine { worker.add_invocation(invocation_id).await; } + let session = worker.session.clone(); + self.invocations .handle_invocation( invocation_id, @@ -256,6 +306,7 @@ impl Engine { function, traceparent, baggage, + session, ) .await } else { @@ -645,43 +696,45 @@ impl Engine { } if let Some(middleware_id) = &session.config.middleware_function_id { - let inv_id = (*invocation_id).unwrap_or_else(Uuid::new_v4); - let middleware_input = serde_json::json!({ - "function_id": function_id, - "payload": data, - "action": action, - "context": session.context, - }); - let engine = self.clone(); - let w = worker.clone(); - let middleware_id = middleware_id.clone(); - let function_id = function_id.clone(); - let traceparent = traceparent.clone(); - let baggage = baggage.clone(); - - tokio::spawn(async move { - let response = match engine.call(&middleware_id, middleware_input).await - { - Ok(result) => Message::InvocationResult { - invocation_id: inv_id, - function_id, - result, - error: None, - traceparent, - baggage, - }, - Err(err) => Message::InvocationResult { - invocation_id: inv_id, - function_id, - result: None, - error: Some(err), - traceparent, - baggage, - }, - }; - engine.send_msg(&w, response).await; - }); - return Ok(()); + if !function_id.starts_with("engine::") { + let inv_id = (*invocation_id).unwrap_or_else(Uuid::new_v4); + let middleware_input = serde_json::json!({ + "function_id": function_id, + "payload": data, + "action": action, + "context": session.context, + }); + let engine = self.clone(); + let w = worker.clone(); + let middleware_id = middleware_id.clone(); + let function_id = function_id.clone(); + let traceparent = traceparent.clone(); + let baggage = baggage.clone(); + + tokio::spawn(async move { + let response = + match engine.call(&middleware_id, middleware_input).await { + Ok(result) => Message::InvocationResult { + invocation_id: inv_id, + function_id, + result, + error: None, + traceparent, + baggage, + }, + Err(err) => Message::InvocationResult { + invocation_id: inv_id, + function_id, + result: None, + error: Some(err), + traceparent, + baggage, + }, + }; + engine.send_msg(&w, response).await; + }); + return Ok(()); + } } } @@ -1283,6 +1336,7 @@ impl EngineTrait for Engine { function, traceparent, baggage, + None, ) .await; @@ -1337,7 +1391,7 @@ impl EngineTrait for Engine { let handler_function_id = function_id.clone(); let function = Function { - handler: Arc::new(move |invocation_id, input| { + handler: Arc::new(move |invocation_id, input, _session| { let handler = handler_arc.clone(); let path = handler_function_id.clone(); Box::pin(async move { handler.handle_function(invocation_id, path, input).await }) @@ -1355,13 +1409,13 @@ impl EngineTrait for Engine { fn register_function_handler(&self, request: RegisterFunctionRequest, handler: Handler) where - H: Fn(Value) -> F + Send + Sync + 'static, - F: Future, ErrorBody>> + Send + 'static, + H: HandlerFn, + F: Future + Send + 'static, { let handler_arc: Arc = Arc::new(handler.f); let function = Function { - handler: Arc::new(move |_id, input| { + handler: Arc::new(move |_id, input, _session| { let handler = handler_arc.clone(); Box::pin(async move { handler(input).await }) }), @@ -1375,6 +1429,33 @@ impl EngineTrait for Engine { self.functions .register_function(request.function_id, function); } + + fn register_function_handler_with_session( + &self, + request: RegisterFunctionRequest, + handler: SessionHandler, + ) where + H: SessionHandlerFn, + F: Future + Send + 'static, + { + let handler_arc: Arc = Arc::new(handler.f); + + let function = Function { + handler: Arc::new(move |_id, input, session| { + let handler = handler_arc.clone(); + let session = session.clone(); + Box::pin(async move { handler(input, session).await }) + }), + _function_id: request.function_id.clone(), + _description: request.description, + request_format: request.request_format, + response_format: request.response_format, + metadata: request.metadata, + }; + + self.functions + .register_function(request.function_id, function); + } } #[cfg(test)] diff --git a/engine/src/function.rs b/engine/src/function.rs index 49ad11bc4..393bb8562 100644 --- a/engine/src/function.rs +++ b/engine/src/function.rs @@ -12,6 +12,7 @@ use futures::Future; use serde_json::Value; use uuid::Uuid; +use crate::modules::worker::rbac_session::Session; use crate::protocol::*; pub enum FunctionResult { @@ -21,7 +22,8 @@ pub enum FunctionResult { NoResult, } type HandlerFuture = Pin, ErrorBody>> + Send>>; -pub type HandlerFn = dyn Fn(Option, Value) -> HandlerFuture + Send + Sync; +pub type HandlerFn = + dyn Fn(Option, Value, Option>) -> HandlerFuture + Send + Sync; #[derive(Clone)] pub struct Function { @@ -38,8 +40,9 @@ impl Function { self, invocation_id: Option, data: Value, + session: Option>, ) -> FunctionResult, ErrorBody> { - (self.handler)(invocation_id, data.clone()).await + (self.handler)(invocation_id, data.clone(), session).await } } @@ -121,7 +124,7 @@ mod tests { /// Helper: create a dummy function with a simple handler fn make_function(id: &str) -> Function { Function { - handler: Arc::new(|_invocation_id, _input| { + handler: Arc::new(|_invocation_id, _input, _session| { Box::pin(async { FunctionResult::Success(Some(serde_json::json!({"ok": true}))) }) }), _function_id: id.to_string(), @@ -269,7 +272,7 @@ mod tests { fn registry_overwrite_existing_function() { let reg = FunctionsRegistry::new(); let func1 = Function { - handler: Arc::new(|_, _| Box::pin(async { FunctionResult::Success(None) })), + handler: Arc::new(|_, _, _| Box::pin(async { FunctionResult::Success(None) })), _function_id: "fn".to_string(), _description: Some("version 1".to_string()), request_format: None, @@ -277,7 +280,7 @@ mod tests { metadata: None, }; let func2 = Function { - handler: Arc::new(|_, _| Box::pin(async { FunctionResult::Success(None) })), + handler: Arc::new(|_, _, _| Box::pin(async { FunctionResult::Success(None) })), _function_id: "fn".to_string(), _description: Some("version 2".to_string()), request_format: None, @@ -293,7 +296,7 @@ mod tests { #[test] fn function_metadata_and_formats() { let func = Function { - handler: Arc::new(|_, _| Box::pin(async { FunctionResult::Success(None) })), + handler: Arc::new(|_, _, _| Box::pin(async { FunctionResult::Success(None) })), _function_id: "fn".to_string(), _description: None, request_format: Some(json!({"type": "object"})), @@ -313,7 +316,7 @@ mod tests { #[tokio::test] async fn call_handler_returns_success() { let func = make_function("test"); - let result = func.call_handler(None, json!({})).await; + let result = func.call_handler(None, json!({}), None).await; match result { FunctionResult::Success(Some(val)) => { assert_eq!(val, json!({"ok": true})); @@ -326,7 +329,7 @@ mod tests { async fn call_handler_with_invocation_id() { let invocation_id = Uuid::new_v4(); let func = Function { - handler: Arc::new(move |inv_id, _input| { + handler: Arc::new(move |inv_id, _input, _session| { Box::pin(async move { if inv_id.is_some() { FunctionResult::Success(Some(json!({"has_id": true}))) @@ -341,7 +344,9 @@ mod tests { response_format: None, metadata: None, }; - let result = func.call_handler(Some(invocation_id), json!({})).await; + let result = func + .call_handler(Some(invocation_id), json!({}), None) + .await; match result { FunctionResult::Success(Some(val)) => { assert_eq!(val["has_id"], true); @@ -353,7 +358,7 @@ mod tests { #[tokio::test] async fn call_handler_failure() { let func = Function { - handler: Arc::new(|_, _| { + handler: Arc::new(|_, _, _| { Box::pin(async { FunctionResult::Failure(ErrorBody { code: "test_error".to_string(), @@ -368,7 +373,7 @@ mod tests { response_format: None, metadata: None, }; - let result = func.call_handler(None, json!({})).await; + let result = func.call_handler(None, json!({}), None).await; match result { FunctionResult::Failure(e) => { assert_eq!(e.code, "test_error"); diff --git a/engine/src/invocation/mod.rs b/engine/src/invocation/mod.rs index 0249f8388..7297b7660 100644 --- a/engine/src/invocation/mod.rs +++ b/engine/src/invocation/mod.rs @@ -18,6 +18,7 @@ use crate::telemetry::SpanExt; use crate::{ function::{Function, FunctionResult}, modules::observability::metrics::get_engine_metrics, + modules::worker::rbac_session::Session, protocol::ErrorBody, }; @@ -80,6 +81,7 @@ impl InvocationHandler { function_handler: Function, traceparent: Option, baggage: Option, + session: Option>, ) -> Result, ErrorBody>, RecvError> { // Create span with dynamic name using the function_id // Using OTEL semantic conventions for FaaS (Function as a Service) @@ -121,7 +123,7 @@ impl InvocationHandler { let metrics = get_engine_metrics(); let result = function_handler - .call_handler(Some(invocation_id), body) + .call_handler(Some(invocation_id), body, session) .await; // Calculate duration diff --git a/engine/src/modules/bridge_client/mod.rs b/engine/src/modules/bridge_client/mod.rs index 79699e432..c86ba94a5 100644 --- a/engine/src/modules/bridge_client/mod.rs +++ b/engine/src/modules/bridge_client/mod.rs @@ -467,7 +467,7 @@ mod tests { .expect("bridge.invoke handler"); match invoke .clone() - .call_handler(None, json!({ "bad": true })) + .call_handler(None, json!({ "bad": true }), None) .await { FunctionResult::Failure(err) => assert_eq!(err.code, "deserialization_error"), @@ -481,6 +481,7 @@ mod tests { "data": { "hello": "world" }, "timeout_ms": 1 }), + None, ) .await { @@ -502,6 +503,7 @@ mod tests { "function_id": "remote.echo", "data": { "hello": "world" } }), + None, ) .await { @@ -513,7 +515,10 @@ mod tests { .functions .get("forward.echo") .expect("forward handler"); - match forward.call_handler(None, json!({ "value": 1 })).await { + match forward + .call_handler(None, json!({ "value": 1 }), None) + .await + { FunctionResult::Failure(err) => { assert_eq!(err.code, "bridge_error"); assert!(!err.message.is_empty()); @@ -568,7 +573,7 @@ mod tests { .get("bridge.invoke_async") .expect("bridge.invoke_async handler"); match invoke_async - .call_handler(None, json!({ "bad": true })) + .call_handler(None, json!({ "bad": true }), None) .await { FunctionResult::Failure(err) => assert_eq!(err.code, "deserialization_error"), diff --git a/engine/src/modules/engine_fn/mod.rs b/engine/src/modules/engine_fn/mod.rs index 84edbb69f..6f3407160 100644 --- a/engine/src/modules/engine_fn/mod.rs +++ b/engine/src/modules/engine_fn/mod.rs @@ -14,9 +14,10 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use crate::{ - engine::{Engine, EngineTrait, Handler, RegisterFunctionRequest}, + engine::{Engine, EngineTrait, Handler, RegisterFunctionRequest, SessionHandler}, function::FunctionResult, modules::module::Module, + modules::worker::rbac_session::Session, protocol::{ErrorBody, StreamChannelRef, WorkerMetrics}, trigger::{Trigger, TriggerRegistrator, TriggerType}, workers::WorkerTelemetryMeta, @@ -467,6 +468,7 @@ impl EngineFunctionsModule { pub async fn get_functions( &self, input: FunctionsListInput, + session: Option>, ) -> FunctionResult { let mut functions = self.list_functions(); @@ -474,6 +476,19 @@ impl EngineFunctionsModule { functions.retain(|f| !f.function_id.starts_with("engine::")); } + if let Some(session) = &session { + functions.retain(|f| { + let function = self.engine.functions.get(&f.function_id); + crate::modules::worker::rbac_config::is_function_allowed( + &f.function_id, + session.config.rbac.clone(), + &session.allowed_functions, + &session.forbidden_functions, + function.as_ref(), + ) + }); + } + FunctionResult::Success(FunctionsListResult { functions }) } @@ -1095,9 +1110,12 @@ mod tests { } let filtered = module - .get_functions(FunctionsListInput { - include_internal: None, - }) + .get_functions( + FunctionsListInput { + include_internal: None, + }, + None, + ) .await; match filtered { FunctionResult::Success(result) => { @@ -1108,9 +1126,12 @@ mod tests { } let all = module - .get_functions(FunctionsListInput { - include_internal: Some(true), - }) + .get_functions( + FunctionsListInput { + include_internal: Some(true), + }, + None, + ) .await; match all { FunctionResult::Success(result) => { diff --git a/engine/src/modules/queue/adapters/builtin/adapter.rs b/engine/src/modules/queue/adapters/builtin/adapter.rs index 3b7016fec..f96818d18 100644 --- a/engine/src/modules/queue/adapters/builtin/adapter.rs +++ b/engine/src/modules/queue/adapters/builtin/adapter.rs @@ -637,7 +637,7 @@ mod tests { fn register_test_function(engine: &Arc, function_id: &str, success: bool) { let function = Function { - handler: Arc::new(move |_invocation_id, _input| { + handler: Arc::new(move |_invocation_id, _input, _session| { Box::pin(async move { if success { FunctionResult::Success(Some(json!({ "ok": true }))) diff --git a/engine/src/modules/queue/queue.rs b/engine/src/modules/queue/queue.rs index 1a1d818b7..4d95b28db 100644 --- a/engine/src/modules/queue/queue.rs +++ b/engine/src/modules/queue/queue.rs @@ -1692,7 +1692,7 @@ mod tests { let counter = call_count.clone(); let function = crate::function::Function { - handler: Arc::new(move |_invocation_id, _input| { + handler: Arc::new(move |_invocation_id, _input, _session| { let counter = counter.clone(); Box::pin(async move { counter.fetch_add(1, Ordering::SeqCst); @@ -1747,7 +1747,7 @@ mod tests { let counter = call_count.clone(); let function = crate::function::Function { - handler: Arc::new(move |_invocation_id, _input| { + handler: Arc::new(move |_invocation_id, _input, _session| { let counter = counter.clone(); Box::pin(async move { counter.fetch_add(1, Ordering::SeqCst); @@ -1805,7 +1805,7 @@ mod tests { let order_ref = invocation_order.clone(); let function = crate::function::Function { - handler: Arc::new(move |_invocation_id, input| { + handler: Arc::new(move |_invocation_id, input, _session| { let order_ref = order_ref.clone(); Box::pin(async move { let txn_id = input @@ -1873,7 +1873,7 @@ mod tests { let ts_ref = timestamps.clone(); let function = crate::function::Function { - handler: Arc::new(move |_invocation_id, input| { + handler: Arc::new(move |_invocation_id, input, _session| { let ts_ref = ts_ref.clone(); Box::pin(async move { let task_id = input diff --git a/engine/src/modules/telemetry/mod.rs b/engine/src/modules/telemetry/mod.rs index c5e14b8c9..a05a45e46 100644 --- a/engine/src/modules/telemetry/mod.rs +++ b/engine/src/modules/telemetry/mod.rs @@ -855,8 +855,9 @@ mod tests { } fn register_test_function(engine: &Arc, function_id: &str) { - let handler: Arc = - Arc::new(|_invocation_id, _input| Box::pin(async { FunctionResult::NoResult })); + let handler: Arc = Arc::new(|_invocation_id, _input, _session| { + Box::pin(async { FunctionResult::NoResult }) + }); engine.functions.register_function( function_id.to_string(), Function { @@ -1588,7 +1589,7 @@ mod tests { fn test_collect_functions_and_triggers_filters_engine_prefix() { let engine = make_test_engine(); - let handler: Arc = Arc::new(|_inv_id, _input| { + let handler: Arc = Arc::new(|_inv_id, _input, _session| { Box::pin(async { crate::function::FunctionResult::NoResult }) }); engine.functions.register_function( diff --git a/engine/src/modules/worker/rbac_session.rs b/engine/src/modules/worker/rbac_session.rs index 83b618a73..638fe6c30 100644 --- a/engine/src/modules/worker/rbac_session.rs +++ b/engine/src/modules/worker/rbac_session.rs @@ -45,7 +45,7 @@ pub struct Session { pub function_registration_prefix: Option, } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] pub(crate) struct AuthResult { #[serde(default)] allowed_functions: Vec, diff --git a/engine/tests/common/queue_helpers.rs b/engine/tests/common/queue_helpers.rs index c718f541a..09a410683 100644 --- a/engine/tests/common/queue_helpers.rs +++ b/engine/tests/common/queue_helpers.rs @@ -72,7 +72,7 @@ pub fn register_counting_function( counter: Arc, ) { let function = Function { - handler: Arc::new(move |_invocation_id, _input| { + handler: Arc::new(move |_invocation_id, _input, _session| { let count = counter.clone(); Box::pin(async move { count.fetch_add(1, Ordering::SeqCst); @@ -99,7 +99,7 @@ pub fn register_order_recording_function( record: Arc>>, ) { let function = Function { - handler: Arc::new(move |_invocation_id, input| { + handler: Arc::new(move |_invocation_id, input, _session| { let rec = record.clone(); Box::pin(async move { let value = input.get(field_name).cloned().unwrap_or(Value::Null); @@ -127,7 +127,7 @@ pub fn register_slow_function( timestamps: Arc>>, ) { let function = Function { - handler: Arc::new(move |_invocation_id, _input| { + handler: Arc::new(move |_invocation_id, _input, _session| { let ts = timestamps.clone(); let d = delay; Box::pin(async move { @@ -159,7 +159,7 @@ pub fn register_group_order_recording_function( processing_delay: Duration, ) { let function = Function { - handler: Arc::new(move |_invocation_id, input| { + handler: Arc::new(move |_invocation_id, input, _session| { let rec = records.clone(); let delay = processing_delay; Box::pin(async move { @@ -193,7 +193,7 @@ pub fn register_failing_function( call_count: Arc, ) { let function = Function { - handler: Arc::new(move |_invocation_id, _input| { + handler: Arc::new(move |_invocation_id, _input, _session| { let count = call_count.clone(); Box::pin(async move { count.fetch_add(1, Ordering::SeqCst); @@ -225,7 +225,7 @@ pub fn register_failing_function_with_timestamps( timestamps: Arc>>, ) { let function = Function { - handler: Arc::new(move |_invocation_id, _input| { + handler: Arc::new(move |_invocation_id, _input, _session| { let count = call_count.clone(); let ts = timestamps.clone(); Box::pin(async move { @@ -257,7 +257,7 @@ pub fn register_payload_capturing_function( captured: Arc>>, ) { let function = Function { - handler: Arc::new(move |_invocation_id, input| { + handler: Arc::new(move |_invocation_id, input, _session| { let store = captured.clone(); Box::pin(async move { store.lock().await.push(input); @@ -285,7 +285,7 @@ pub fn register_condition_function( ) { let expected = expected_value.clone(); let function = Function { - handler: Arc::new(move |_invocation_id, input| { + handler: Arc::new(move |_invocation_id, input, _session| { let exp = expected.clone(); Box::pin(async move { let matches = input.get(field).map(|v| *v == exp).unwrap_or(false); @@ -311,7 +311,7 @@ pub fn register_panicking_function( success_count: Arc, ) { let function = Function { - handler: Arc::new(move |_invocation_id, input| { + handler: Arc::new(move |_invocation_id, input, _session| { let count = success_count.clone(); Box::pin(async move { let should_panic = input diff --git a/engine/tests/dlq_redrive_e2e.rs b/engine/tests/dlq_redrive_e2e.rs index 0efe9c4b6..20c2d6284 100644 --- a/engine/tests/dlq_redrive_e2e.rs +++ b/engine/tests/dlq_redrive_e2e.rs @@ -100,13 +100,13 @@ async fn invoke_redrive( .get("iii::queue::redrive") .expect("iii::queue::redrive should be registered"); function - .call_handler(None, json!({ "queue": queue_name })) + .call_handler(None, json!({ "queue": queue_name }), None) .await } fn register_failing_function(engine: &Arc, function_id: &str, call_count: Arc) { let function = Function { - handler: Arc::new(move |_invocation_id, _input| { + handler: Arc::new(move |_invocation_id, _input, _session| { let count = call_count.clone(); Box::pin(async move { count.fetch_add(1, Ordering::SeqCst); @@ -130,7 +130,7 @@ fn register_failing_function(engine: &Arc, function_id: &str, call_count fn register_counting_function(engine: &Arc, function_id: &str, counter: Arc) { let function = Function { - handler: Arc::new(move |_invocation_id, _input| { + handler: Arc::new(move |_invocation_id, _input, _session| { let count = counter.clone(); Box::pin(async move { count.fetch_add(1, Ordering::SeqCst); @@ -157,7 +157,7 @@ fn register_flaky_function( failures_before_success: u64, ) { let function = Function { - handler: Arc::new(move |_invocation_id, _input| { + handler: Arc::new(move |_invocation_id, _input, _session| { let fc = fail_count.clone(); let sc = success_count.clone(); let threshold = failures_before_success; diff --git a/engine/tests/queue_e2e_fanout.rs b/engine/tests/queue_e2e_fanout.rs index a61b172b6..2d99c05c1 100644 --- a/engine/tests/queue_e2e_fanout.rs +++ b/engine/tests/queue_e2e_fanout.rs @@ -20,7 +20,7 @@ fn register_capturing_function(engine: &Arc, function_id: &str) -> Arc>> = Arc::new(Mutex::new(Vec::new())); let cap = captured.clone(); let function = Function { - handler: Arc::new(move |_invocation_id, input| { + handler: Arc::new(move |_invocation_id, input, _session| { let rec = cap.clone(); Box::pin(async move { rec.lock().await.push(input); @@ -43,7 +43,7 @@ fn register_counting_fn(engine: &Arc, function_id: &str) -> Arc>> = Arc::new(Mutex::new(Vec::new())); let cap = captured.clone(); let function = Function { - handler: Arc::new(move |_invocation_id, input| { + handler: Arc::new(move |_invocation_id, input, _session| { let rec = cap.clone(); Box::pin(async move { rec.lock().await.push(input); @@ -1148,7 +1148,7 @@ fn register_rmq_counting_fn(engine: &Arc, function_id: &str) -> Arc { } }) + it('should only list allowed functions for valid-token worker', async () => { + const iiiClient = registerWorker(EW_URL, { + headers: { 'x-test-token': 'valid-token' }, + otel: { enabled: false }, + }) + + try { + await sleep(1000) + + const functions = await iiiClient.listFunctions() + const functionIds = functions.map((f) => f.function_id) + + expect(functionIds).toContain('test::ew::valid-token-echo') + expect(functionIds).toContain('test::ew::public::echo') + expect(functionIds).toContain('test::ew::meta-public') + + expect(functionIds).not.toContain('test::ew::private') + expect(functionIds).not.toContain('test::rbac-worker::auth') + } finally { + await iiiClient.shutdown() + } + }) + + it('should only list exposed functions for restricted-token worker', async () => { + const iiiClient = registerWorker(EW_URL, { + headers: { 'x-test-token': 'restricted-token' }, + otel: { enabled: false }, + }) + + try { + await sleep(1000) + + const functions = await iiiClient.listFunctions() + const functionIds = functions.map((f) => f.function_id) + + expect(functionIds).toContain('test::ew::public::echo') + expect(functionIds).toContain('test::ew::meta-public') + + expect(functionIds).not.toContain('test::ew::valid-token-echo') + expect(functionIds).not.toContain('test::ew::private') + expect(functionIds).not.toContain('test::rbac-worker::auth') + } finally { + await iiiClient.shutdown() + } + }) + it('should apply function_registration_prefix and strip on invocation', async () => { const iiiClient = registerWorker(EW_URL, { headers: { 'x-test-token': 'prefix-token' }, diff --git a/sdk/packages/python/iii/tests/test_rbac_workers.py b/sdk/packages/python/iii/tests/test_rbac_workers.py index 22efb6ea8..8427e3ac8 100644 --- a/sdk/packages/python/iii/tests/test_rbac_workers.py +++ b/sdk/packages/python/iii/tests/test_rbac_workers.py @@ -275,6 +275,48 @@ def test_should_deny_function_registration_via_hook(self, iii_server): finally: iii_client.shutdown() + def test_list_functions_only_returns_allowed_for_valid_token(self, iii_server): + iii_client = register_worker( + EW_URL, + InitOptions(otel={"enabled": False}, headers={"x-test-token": "valid-token"}), + ) + + try: + time.sleep(1.0) + + functions = iii_client.list_functions() + function_ids = [f.function_id for f in functions] + + assert "test::ew::valid-token-echo" in function_ids + assert "test::ew::public::echo" in function_ids + assert "test::ew::meta-public" in function_ids + + assert "test::ew::private" not in function_ids + assert "test::rbac-worker::auth" not in function_ids + finally: + iii_client.shutdown() + + def test_list_functions_only_returns_exposed_for_restricted_token(self, iii_server): + iii_client = register_worker( + EW_URL, + InitOptions(otel={"enabled": False}, headers={"x-test-token": "restricted-token"}), + ) + + try: + time.sleep(1.0) + + functions = iii_client.list_functions() + function_ids = [f.function_id for f in functions] + + assert "test::ew::public::echo" in function_ids + assert "test::ew::meta-public" in function_ids + + assert "test::ew::valid-token-echo" not in function_ids + assert "test::ew::private" not in function_ids + assert "test::rbac-worker::auth" not in function_ids + finally: + iii_client.shutdown() + def test_function_registration_prefix(self, iii_server): iii_client = register_worker( EW_URL, diff --git a/sdk/packages/rust/iii/tests/rbac_workers.rs b/sdk/packages/rust/iii/tests/rbac_workers.rs index 51758c3c4..eb05a6355 100644 --- a/sdk/packages/rust/iii/tests/rbac_workers.rs +++ b/sdk/packages/rust/iii/tests/rbac_workers.rs @@ -541,3 +541,105 @@ async fn should_apply_function_registration_prefix_and_strip_on_invocation() { iii_client.shutdown_async().await; } + +#[tokio::test(flavor = "current_thread")] +#[serial] +async fn should_only_list_allowed_functions_for_valid_token() { + ensure_functions_registered(); + common::settle().await; + tokio::time::sleep(Duration::from_millis(700)).await; + + let mut headers = HashMap::new(); + headers.insert("x-test-token".to_string(), "valid-token".to_string()); + + let iii_client = register_worker( + &ew_url(), + InitOptions { + headers: Some(headers), + ..Default::default() + }, + ); + + tokio::time::sleep(Duration::from_millis(1000)).await; + + let functions = iii_client + .list_functions() + .await + .expect("list_functions should succeed"); + let ids: Vec<&str> = functions.iter().map(|f| f.function_id.as_str()).collect(); + + assert!( + ids.contains(&"test::ew::valid-token-echo"), + "should contain allowed function" + ); + assert!( + ids.contains(&"test::ew::public::echo"), + "should contain exposed public function" + ); + assert!( + ids.contains(&"test::ew::meta-public"), + "should contain metadata-matched function" + ); + + assert!( + !ids.contains(&"test::ew::private"), + "should not contain private function" + ); + assert!( + !ids.contains(&"test::rbac-worker::auth"), + "should not contain auth function" + ); + + iii_client.shutdown_async().await; +} + +#[tokio::test(flavor = "current_thread")] +#[serial] +async fn should_only_list_exposed_functions_for_restricted_token() { + ensure_functions_registered(); + common::settle().await; + tokio::time::sleep(Duration::from_millis(700)).await; + + let mut headers = HashMap::new(); + headers.insert("x-test-token".to_string(), "restricted-token".to_string()); + + let iii_client = register_worker( + &ew_url(), + InitOptions { + headers: Some(headers), + ..Default::default() + }, + ); + + tokio::time::sleep(Duration::from_millis(1000)).await; + + let functions = iii_client + .list_functions() + .await + .expect("list_functions should succeed"); + let ids: Vec<&str> = functions.iter().map(|f| f.function_id.as_str()).collect(); + + assert!( + ids.contains(&"test::ew::public::echo"), + "should contain exposed public function" + ); + assert!( + ids.contains(&"test::ew::meta-public"), + "should contain metadata-matched function" + ); + + assert!( + !ids.contains(&"test::ew::valid-token-echo"), + "should not contain valid-token-only function" + ); + assert!( + !ids.contains(&"test::ew::private"), + "should not contain private function" + ); + assert!( + !ids.contains(&"test::rbac-worker::auth"), + "should not contain auth function" + ); + + iii_client.shutdown_async().await; +}