Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions dapr_agents/workflow/decorators/messaging.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import logging
import warnings
from copy import deepcopy
from typing import Any, Callable, Optional, get_type_hints

from dapr_agents.workflow.utils.core import is_valid_routable_model
from dapr_agents.workflow.utils.messaging import extract_message_models

logger = logging.getLogger(__name__)

_MESSAGE_ROUTER_DEPRECATION_MESSAGE = (
"@message_router (legacy version from dapr_agents.workflow.decorators.messaging) "
"is deprecated and will be removed in a future release. "
"Please migrate to the updated decorator in "
"`dapr_agents.workflow.decorators.routers`, which supports "
"Union types, forward references, and explicit Dapr workflow integration."
)
Comment on lines +11 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i vote we just remove since we are not v1.0 and beyond, we can make breaking changes and remove the message_router as long as we let folks know in the next release announcement :)



def message_router(
func: Optional[Callable[..., Any]] = None,
Expand All @@ -16,7 +26,8 @@ def message_router(
broadcast: bool = False,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Decorator for registering message handlers by inspecting type hints on the 'message' argument.
[DEPRECATED] Legacy decorator for registering message handlers by inspecting type hints
on the 'message' argument.

This decorator:
- Extracts the expected message model type from function annotations.
Expand All @@ -36,6 +47,12 @@ def message_router(
"""

def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
warnings.warn(
_MESSAGE_ROUTER_DEPRECATION_MESSAGE,
DeprecationWarning,
stacklevel=2,
)

is_workflow = hasattr(f, "_is_workflow")
workflow_name = getattr(f, "_workflow_name", None)

Expand All @@ -56,7 +73,9 @@ def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
)

logger.debug(
f"@message_router: '{f.__name__}' => models {[m.__name__ for m in message_models]}"
"@message_router (legacy): '%s' => models %s",
f.__name__,
[m.__name__ for m in message_models],
)

# Attach metadata for later registration
Expand Down
113 changes: 113 additions & 0 deletions dapr_agents/workflow/decorators/routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import inspect
import logging
from copy import deepcopy
from typing import (
Any,
Callable,
Optional,
get_type_hints,
)

from dapr_agents.workflow.utils.core import is_supported_model
from dapr_agents.workflow.utils.routers import extract_message_models

logger = logging.getLogger(__name__)


def message_router(
func: Optional[Callable[..., Any]] = None,
*,
pubsub: Optional[str] = None,
topic: Optional[str] = None,
dead_letter_topic: Optional[str] = None,
broadcast: bool = False,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Decorate a message handler with routing metadata.
The handler must accept a parameter named `message`. Its type hint defines the
expected payload model(s), e.g.:
@message_router(pubsub="pubsub", topic="orders")
def on_order(message: OrderCreated): ...
@message_router(pubsub="pubsub", topic="events")
def on_event(message: Union[Foo, Bar]): ...
Args:
func: (optional) bare-decorator form support.
pubsub: Name of the Dapr pub/sub component (required when used with args).
topic: Topic name to subscribe to (required when used with args).
dead_letter_topic: Optional dead-letter topic (defaults to f"{topic}_DEAD").
broadcast: Optional flag you can use downstream for fan-out semantics.
Comment on lines +20 to +44
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some of these say required in the docs string but the code shows optional. Are func, pubsub, topic the ones that we require?

Returns:
The original function tagged with `_message_router_data`.
"""

def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
# Validate required kwargs only when decorator is used with args
if pubsub is None or topic is None:
raise ValueError(
"`pubsub` and `topic` are required when using @message_router with arguments."
)

sig = inspect.signature(f)
if "message" not in sig.parameters:
raise ValueError(f"'{f.__name__}' must have a 'message' parameter.")

# Resolve forward refs under PEP 563 / future annotations
try:
hints = get_type_hints(f, globalns=f.__globals__)
except Exception:
logger.debug(
"Failed to fully resolve type hints for %s", f.__name__, exc_info=True
)
hints = getattr(f, "__annotations__", {}) or {}

raw_hint = hints.get("message")
if raw_hint is None:
raise TypeError(
f"'{f.__name__}' must type-hint the 'message' parameter "
"(e.g., 'message: MyModel' or 'message: Union[A, B]')"
)

models = extract_message_models(raw_hint)
if not models:
raise TypeError(
f"Unsupported or unresolved message type for '{f.__name__}': {raw_hint!r}"
)

# Optional early validation of supported schema kinds
for m in models:
if not is_supported_model(m):
raise TypeError(f"Unsupported model type in '{f.__name__}': {m!r}")

data = {
"pubsub": pubsub,
"topic": topic,
"dead_letter_topic": dead_letter_topic
or (f"{topic}_DEAD" if topic else None),
"is_broadcast": broadcast,
"message_schemas": models, # list[type]
"message_types": [m.__name__ for m in models], # list[str]
}

# Attach metadata; deepcopy for defensive isolation
setattr(f, "_is_message_handler", True)
setattr(f, "_message_router_data", deepcopy(data))

logger.debug(
"@message_router: '%s' => models %s (topic=%s, pubsub=%s, broadcast=%s)",
f.__name__,
[m.__name__ for m in models],
topic,
pubsub,
broadcast,
)
return f

# Support both @message_router(...) and bare @message_router usage
return decorator if func is None else decorator(func)
156 changes: 156 additions & 0 deletions dapr_agents/workflow/utils/registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from __future__ import annotations

import asyncio
import inspect
import logging
from typing import Any, Callable, Iterable, List, Optional, Type

from dapr.clients import DaprClient
from dapr.clients.grpc._response import TopicEventResponse
from dapr.common.pubsub.subscription import SubscriptionMessage

from dapr_agents.workflow.utils.messaging import (
extract_cloudevent_data,
validate_message_model,
)

logger = logging.getLogger(__name__)


def register_message_handlers(
targets: Iterable[Any],
dapr_client: DaprClient,
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> List[Callable[[], None]]:
"""Discover and subscribe handlers decorated with `@message_router`.

Scans each target:
- If the target itself is a decorated function (has `_message_router_data`), it is registered.
- If the target is an object, all its attributes are scanned for decorated callables.

Subscriptions use Dapr's streaming API (`subscribe_with_handler`) which invokes your handler
on a background thread. This function returns a list of "closer" callables. Invoking a closer
will unsubscribe the corresponding handler.

Args:
targets: Functions and/or instances to inspect for `_message_router_data`.
dapr_client: Active Dapr client used to create subscriptions.
loop: Event loop to await async handlers. If omitted, uses the running loop
or falls back to `asyncio.get_event_loop()`.

Returns:
A list of callables. Each callable, when invoked, closes the associated subscription.
"""
# Resolve loop strategy once up front.
if loop is None:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.get_event_loop()

closers: List[Callable[[], None]] = []

def _iter_handlers(obj: Any):
"""Yield (owner, fn) pairs for decorated handlers on `obj`.

If `obj` is itself a decorated function, yield (None, obj).
If `obj` is an instance, scan its attributes for decorated callables.
"""
meta = getattr(obj, "_message_router_data", None)
if callable(obj) and meta:
yield None, obj
return

for name in dir(obj):
fn = getattr(obj, name)
if callable(fn) and getattr(fn, "_message_router_data", None):
yield obj, fn

for target in targets:
for owner, handler in _iter_handlers(target):
meta = getattr(handler, "_message_router_data")
schemas: List[Type[Any]] = meta.get("message_schemas") or []

# Bind method to instance if needed (descriptor protocol).
bound = (
handler if owner is None else handler.__get__(owner, owner.__class__)
)

async def _invoke(
bound_handler: Callable[..., Any],
parsed: Any,
) -> TopicEventResponse:
"""Invoke the user handler (sync or async) and normalize the result."""
result = bound_handler(parsed)
if inspect.iscoroutine(result):
result = await result
if isinstance(result, TopicEventResponse):
return result
# Treat any truthy/None return as success unless user explicitly returns a response.
return TopicEventResponse("success")

def _make_handler(
bound_handler: Callable[..., Any],
) -> Callable[[SubscriptionMessage], TopicEventResponse]:
"""Create a Dapr-compatible handler for a single decorated function."""

def handler_fn(message: SubscriptionMessage) -> TopicEventResponse:
try:
# 1) Extract payload + CloudEvent metadata (bytes/str/dict are also supported by the extractor)
event_data, metadata = extract_cloudevent_data(message)

# 2) Validate against the first matching schema (or dict as fallback)
parsed = None
for model in schemas or [dict]:
try:
parsed = validate_message_model(model, event_data)
break
except Exception:
# Try the next schema; log at debug for signal without noise.
logger.debug(
"Schema %r did not match payload; trying next.",
model,
exc_info=True,
)
continue

if parsed is None:
# Permanent schema mismatch → drop (DLQ if configured by Dapr)
logger.warning(
"No matching schema for message on topic %r; dropping. Raw payload: %r",
meta["topic"],
event_data,
)
return TopicEventResponse("drop")

# 3) Attach CE metadata for downstream consumers
if isinstance(parsed, dict):
parsed["_message_metadata"] = metadata
else:
setattr(parsed, "_message_metadata", metadata)

# 4) Bridge worker thread → event loop
if loop and loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
_invoke(bound_handler, parsed), loop
)
return fut.result()
return asyncio.run(_invoke(bound_handler, parsed))

except Exception:
# Transient failure (I/O, handler crash, etc.) → retry
logger.exception("Message handler error; requesting retry.")
return TopicEventResponse("retry")

return handler_fn

close_fn = dapr_client.subscribe_with_handler(
pubsub_name=meta["pubsub"],
topic=meta["topic"],
handler_fn=_make_handler(bound),
dead_letter_topic=meta.get("dead_letter_topic"),
)
closers.append(close_fn)

return closers
Loading