|
8 | 8 | from collections import defaultdict |
9 | 9 | from contextlib import AsyncExitStack, asynccontextmanager |
10 | 10 | from dataclasses import dataclass |
11 | | -from typing import Annotated, Any, Generic, TypeVar, cast, final, overload |
| 11 | +from typing import Annotated, Any, Generic, TypeVar, final, overload |
12 | 12 | from collections.abc import Callable, Generator, AsyncGenerator, Awaitable |
13 | 13 | from typing_extensions import Self |
14 | 14 | from typing import get_args, get_origin |
|
41 | 41 |
|
42 | 42 | R = TypeVar("R") |
43 | 43 | T = TypeVar("T") |
44 | | -RESULT: CtxItem[Any] = cast(CtxItem, "$result") |
45 | | -STACK: CtxItem[AsyncExitStack] = cast(CtxItem, "$exit_stack") |
46 | | -SUBSCRIBER: CtxItem[Subscriber] = cast(CtxItem, "$subscriber") |
| 44 | +RESULT: CtxItem[Any] = CtxItem.make("$result") |
| 45 | +STACK: CtxItem[AsyncExitStack] = CtxItem.make("$exit_stack") |
| 46 | +SUBSCRIBER: CtxItem[Subscriber] = CtxItem.make("$subscriber") |
47 | 47 | current_subscriber: ContextVar[Subscriber] = ContextVar("_current_subscriber") |
48 | 48 |
|
49 | 49 |
|
@@ -339,7 +339,8 @@ async def handle(self, context: Contexts, inner=False): |
339 | 339 | current_subscriber.reset(token) # type: ignore |
340 | 340 | if not inner: |
341 | 341 | if "$exit_stack" in context: # pragma: no cover |
342 | | - await context[STACK].__aexit__(*sys.exc_info()) |
| 342 | + typ, e, tb = sys.exc_info() |
| 343 | + await context[STACK].__aexit__(typ, e, tb) |
343 | 344 | context.clear() |
344 | 345 | if self.once: |
345 | 346 | self.dispose() |
@@ -480,8 +481,16 @@ def defer(func: Callable[..., Any], ctx: Contexts | TTarget | None = None): |
480 | 481 | return sub.propagate(func, once=True) |
481 | 482 |
|
482 | 483 |
|
483 | | -def get_params(ctx: Contexts): |
484 | | - sub = ctx[SUBSCRIBER] |
| 484 | +def get_params(ctx: Contexts | Subscriber | None = None): # pragma: no cover |
| 485 | + if isinstance(ctx, dict): |
| 486 | + sub = ctx[SUBSCRIBER] |
| 487 | + elif isinstance(ctx, Subscriber): |
| 488 | + sub = ctx |
| 489 | + else: |
| 490 | + try: |
| 491 | + sub = current_subscriber.get() |
| 492 | + except LookupError: |
| 493 | + raise TypeError(f"Unsupported type {type(ctx)}") from None |
485 | 494 | return sub.params |
486 | 495 |
|
487 | 496 |
|
|
0 commit comments