Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 14 additions & 29 deletions docs/how-to/worker-rbac.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ modules:

## 3. Write a Middleware Function (Optional)

A middleware sits between the worker and the target function. Use it for validation, rate limiting, audit logging, or enriching the payload with auth context. The function receives a `MiddlewareFunctionInput`.
A middleware sits between the worker and the target function. Use it for validation, rate limiting, audit logging, or enriching the payload with auth context. The function receives a `MiddlewareFunctionInput` and returns the (possibly enriched) payload. The engine then invokes the target function with the returned payload, preserving the caller's RBAC session.

<Tabs>
<Tab title="Node / TypeScript">
Expand All @@ -172,16 +172,11 @@ iii.registerFunction(
async (input: MiddlewareFunctionInput) => {
console.log(`[audit] user=${input.context.user_id} invoking ${input.function_id}`)

const enrichedPayload = {
return {
...input.payload,
_caller_id: input.context.user_id,
_caller_role: input.context.role,
}

return iii.trigger({
function_id: input.function_id,
payload: enrichedPayload,
})
},
)
```
Expand All @@ -195,48 +190,38 @@ async def middleware_function(data: dict) -> dict:

print(f"[audit] user={mid.context['user_id']} invoking {mid.function_id}")

enriched_payload = {
return {
**mid.payload,
'_caller_id': mid.context['user_id'],
'_caller_role': mid.context['role'],
}

return await iii.trigger_async({
'function_id': mid.function_id,
'payload': enriched_payload,
})

iii.register_function({'id': 'my-project::middleware-function'}, middleware_function)
```
</Tab>
<Tab title="Rust">
```rust
use iii_sdk::{MiddlewareFunctionInput, RegisterFunction, TriggerRequest};
use iii_sdk::{MiddlewareFunctionInput, RegisterFunction};
use serde_json::json;

let iii_clone = iii.clone();
iii.register_function(RegisterFunction::new_async(
"my-project::middleware-function",
move |input: MiddlewareFunctionInput| {
let iii = iii_clone.clone();
async move {
let mut enriched = input.payload.as_object().cloned().unwrap_or_default();
enriched.insert("_caller_id".into(), json!(input.context.get("user_id")));
enriched.insert("_caller_role".into(), json!(input.context.get("role")));

iii.trigger(TriggerRequest {
function_id: input.function_id,
payload: json!(enriched),
action: None,
timeout_ms: None,
}).await
}
|input: MiddlewareFunctionInput| async move {
let mut enriched = input.payload.as_object().cloned().unwrap_or_default();
enriched.insert("_caller_id".into(), json!(input.context.get("user_id")));
enriched.insert("_caller_role".into(), json!(input.context.get("role")));

Ok::<_, iii_sdk::IIIError>(json!(enriched))
},
));
```
</Tab>
</Tabs>

<Warning>
**Do not call `trigger()` on the target function from inside middleware.** The middleware runs on the internal engine bridge worker, not on the RBAC worker. If the middleware calls `trigger()` itself, the target function executes under the bridge worker's session, which has no RBAC restrictions. This bypasses `expose_functions` filtering and means functions like `listFunctions` will return every registered function regardless of the caller's permissions. Always return the enriched payload from the middleware and let the engine invoke the target function with the original caller's session.
</Warning>

Add it to your config. Note that `middleware_function_id` sits at the `config` level, not inside `rbac`:

```yaml title="iii-config.yaml"
Expand Down
99 changes: 90 additions & 9 deletions engine/function-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,28 @@ fn extract_success_type(return_type: &syn::Type) -> Option<TokenStream2> {
Some(quote! { #success_ty })
}

fn type_contains_ident(ty: &syn::Type, name: &str) -> bool {
match ty {
syn::Type::Path(type_path) => type_path.path.segments.iter().any(|seg| {
if seg.ident == name {
return true;
}
if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
args.args.iter().any(|arg| {
if let syn::GenericArgument::Type(inner_ty) = arg {
type_contains_ident(inner_ty, name)
} else {
false
}
})
} else {
false
}
}),
_ => false,
}
}

#[proc_macro_attribute]
pub fn function(_attr: TokenStream, item: TokenStream) -> TokenStream {
let func = parse_macro_input!(item as ItemFn);
Expand Down Expand Up @@ -197,19 +219,29 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
let description = extract_optional(&metas, "description");
let method_ident = method.sig.ident.clone();

let input_type = match method
let non_self_params: Vec<_> = method
.sig
.inputs
.iter()
.find(|arg| !matches!(arg, syn::FnArg::Receiver(_)))
{
.filter(|arg| !matches!(arg, syn::FnArg::Receiver(_)))
.collect();

let input_type = match non_self_params.first() {
Some(syn::FnArg::Typed(pat_type)) => {
let ty = &*pat_type.ty;
quote! { #ty }
}
_ => quote! { () },
};

let has_session_param = non_self_params.get(1).map_or(false, |arg| {
if let syn::FnArg::Typed(pat_type) = arg {
type_contains_ident(&pat_type.ty, "Session")
} else {
false
}
});

// Extract return type
let return_type = match &method.sig.output {
syn::ReturnType::Type(_, ty) => &**ty,
Expand All @@ -224,9 +256,15 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
let handler_ident = format_ident!("{}_handler", method_ident);
let description = quote!(Some(#description.into()));

let method_call = if has_session_param {
quote! { this.#method_ident(input, session).await }
} else {
quote! { this.#method_ident(input).await }
};

let result_handling = if needs_serialization {
quote! {
let result = this.#method_ident(input).await;
let result = #method_call;
match result {
FunctionResult::Success(value) => {
match serde_json::to_value(&value) {
Expand All @@ -251,9 +289,7 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
}
}
} else {
quote! {
this.#method_ident(input).await
}
method_call
};

// Generate request_format schema
Expand Down Expand Up @@ -284,8 +320,47 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
None => quote! { None },
};

generated.push(quote! {
{
let handler_and_registration = if has_session_param {
quote! {
let this = self.clone();
let #handler_ident = SessionHandler::new(move |input: Value, session: Option<::std::sync::Arc<Session>>| {
let this = this.clone();

async move {
let parsed: Result<#input_type, _> = serde_json::from_value(input);
let input = match parsed {
Ok(v) => v,
Err(err) => {
eprintln!(
"[warning] Failed to deserialize input for {}: {}",
#id,
err
);
return FunctionResult::Failure(ErrorBody {
code: "deserialization_error".into(),
message: format!("Failed to deserialize input for {}: {}", #id, err.to_string()),
stacktrace: None,
});
}
};

#result_handling
}
});

engine.register_function_handler_with_session(
RegisterFunctionRequest {
function_id: #id.into(),
description: #description,
request_format: #request_format_expr,
response_format: #response_format_expr,
metadata: None,
},
#handler_ident,
);
}
} else {
quote! {
let this = self.clone();
let #handler_ident = Handler::new(move |input: Value| {
let this = this.clone();
Expand Down Expand Up @@ -323,6 +398,12 @@ pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
#handler_ident,
);
}
};

generated.push(quote! {
{
#handler_and_registration
}
});
}
Err(e) => panic!("failed to parse attributes: {}", e),
Expand Down
18 changes: 18 additions & 0 deletions engine/src/condition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,24 @@ mod tests {
+ 'static,
{
}
fn register_function_handler_with_session<H, F>(
&self,
_req: crate::engine::RegisterFunctionRequest,
_handler: crate::engine::SessionHandler<H>,
) where
H: Fn(
Value,
Option<std::sync::Arc<crate::modules::worker::rbac_session::Session>>,
) -> F
+ Send
+ Sync
+ 'static,
F: std::future::Future<
Output = crate::function::FunctionResult<Option<Value>, ErrorBody>,
> + Send
+ 'static,
{
}
}

#[tokio::test]
Expand Down
Loading
Loading