Skip to content

Commit 288e137

Browse files
committed
🐛 fix generator result handle
1 parent e864ba3 commit 288e137

5 files changed

Lines changed: 44 additions & 32 deletions

File tree

arclet/letoderea/core.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ async def dispatch(event: Any, scope: str | Scope | None = None, slots: Iterable
108108
continue
109109
if res is BLOCK:
110110
return
111-
if isinstance(result, _ExitException): # type: ignore
112-
if result.args[1]:
111+
if isinstance(res, _ExitException):
112+
if res.args[1]:
113113
return
114114
continue
115115

@@ -146,15 +146,15 @@ async def serial(event: Any, scope: str | Scope | None = None, slots: Iterable[t
146146
for t in tasks: t.cancel()
147147
if res is BLOCK:
148148
return
149-
if isinstance(result, _ExitException): # type: ignore
150-
return result.args[0]
149+
if isinstance(res, _ExitException):
150+
return res.args[0]
151151
if res.__class__ is Force:
152-
return res.value # type: ignore
152+
return res.value
153153
else:
154154
return res
155155
elif result.__class__ is Force: # pragma: no cover
156156
for t in tasks: t.cancel()
157-
return result.value # type: ignore
157+
return result.value
158158
else:
159159
for t in tasks: t.cancel()
160160
return result
@@ -192,17 +192,17 @@ async def broadcast(event: Any, scope: str | Scope | None = None, slots: Iterabl
192192
continue
193193
if res is BLOCK:
194194
return
195-
if isinstance(result, _ExitException):
196-
yield result.args[0]
197-
if result.args[1]:
195+
if isinstance(res, _ExitException):
196+
yield res.args[0]
197+
if res.args[1]:
198198
return
199199
continue
200200
if res.__class__ is Force:
201-
yield res.value # type: ignore
201+
yield res.value
202202
else:
203203
yield res
204204
elif result.__class__ is Force:
205-
yield result.value # type: ignore
205+
yield result.value
206206
else:
207207
yield result
208208

@@ -267,7 +267,7 @@ async def _loop_fetch(publisher: Publisher):
267267

268268
def publish(event: Any, scope: str | Scope | None = None, inherit_ctx: Contexts | None = None) -> asyncio.Task[None]:
269269
"""发布事件,并行处理所有响应"""
270-
return add_task(dispatch(event, scope, inherit_ctx=inherit_ctx)) # type: ignore
270+
return add_task(dispatch(event, scope, inherit_ctx=inherit_ctx))
271271

272272

273273
@overload

arclet/letoderea/exceptions.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
import sys
77
import traceback
88
from enum import Enum
9-
from types import CodeType, TracebackType
9+
from types import CodeType, TracebackType, FunctionType
1010
from typing import Any, Final, cast
11-
from collections.abc import Callable
1211

1312
from .typing import Contexts
1413

@@ -55,7 +54,7 @@ def __init__(self, exc_type: type[BaseException], exc_value: BaseException, exc_
5554

5655

5756
@functools.lru_cache(maxsize=None)
58-
def get_caller_info(func: Callable, name: str | None = None) -> tuple[str, int, int, str, int, int]: # pragma: no cover
57+
def get_caller_info(func: FunctionType, name: str | None = None) -> tuple[str, int, int, str, int, int]: # pragma: no cover
5958
"""Get the caller information of a function or method.
6059
6160
Returns:
@@ -66,7 +65,7 @@ def get_caller_info(func: Callable, name: str | None = None) -> tuple[str, int,
6665
- param_lineno: The line number where the parameter is defined.
6766
- param_offset: The character offset in the line where the parameter is defined.
6867
"""
69-
code: CodeType = func.__code__ # type: ignore
68+
code: CodeType = func.__code__
7069
lines = inspect.getsourcelines(func)
7170
lineno = code.co_firstlineno
7271
line = ""
@@ -116,7 +115,7 @@ def print_trace(te: Trace): # pragma: no cover
116115
print(line, file=sys.stderr, end="")
117116

118117
@staticmethod
119-
def call(e: Exception, callable_target: Callable, contexts: Contexts, inner: bool = False):
118+
def call(e: Exception, callable_target: FunctionType, contexts: Contexts, inner: bool = False):
120119
if isinstance(e, UnresolvedRequirement) and not isinstance(e, SyntaxError):
121120
name, anno, _, pds = e.args
122121
param = f"{name}: {inspect.formatannotation(anno)}" if anno is not None else name

arclet/letoderea/publisher.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
_publisher_cache: dict[type, list[str]] = {}
1919
_publisher_cache_ignore = set()
2020

21+
2122
async def _supplier(event: Any, context: Contexts):
2223
if isinstance(event, dict):
2324
return context.update(event)
@@ -95,7 +96,7 @@ def check_subscriber(self, sub: Subscriber) -> bool:
9596
return True
9697

9798

98-
def filter_publisher(target: type[T]) -> Publisher[T] | None:
99+
def filter_publisher(target: type[T1]) -> Publisher[T1] | None:
99100
if (label := getattr(target, "__publisher__", f"$event:{target.__module__}.{target.__name__}")) in _publishers:
100101
return _publishers[label]
101102
return next((pub for pub in _publishers.values() if pub.target == target), None)
@@ -112,19 +113,19 @@ def get_publishers(event: Any) -> dict[str, Publisher]:
112113

113114

114115
def define(
115-
target: type[T],
116-
supplier: Callable[[T, Contexts], Awaitable[Contexts | None]] | None = None,
116+
target: type[T1],
117+
supplier: Callable[[T1, Contexts], Awaitable[Contexts | None]] | None = None,
117118
name: str | None = None,
118-
) -> Publisher[T]:
119+
) -> Publisher[T1]:
119120
if name and name in _publishers:
120121
return _publishers[name]
121122
if (_id := getattr(target, "__publisher__", f"$event:{target.__module__}.{target.__name__}")) in _publishers:
122123
return _publishers[_id]
123124
return Publisher(target, name, supplier)
124125

125126

126-
def gather(func: Callable[[T, Contexts], Awaitable[Contexts | None]]):
127-
target: type[T] = next(iter(get_type_hints(func).values())) # type: ignore
127+
def gather(func: Callable[[Any, Contexts], Awaitable[Contexts | None]]):
128+
target: type[T1] = next(iter(get_type_hints(func).values())) # type: ignore
128129
pub = filter_publisher(target) or define(target)
129130
pub.supplier = func
130131
return func

arclet/letoderea/subscriber.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections import defaultdict
99
from contextlib import AsyncExitStack, asynccontextmanager
1010
from dataclasses import dataclass
11-
from typing import Annotated, Any, Generic, TypeVar, cast, final, overload
11+
from typing import Annotated, Any, Generic, TypeVar, final, overload
1212
from collections.abc import Callable, Generator, AsyncGenerator, Awaitable
1313
from typing_extensions import Self
1414
from typing import get_args, get_origin
@@ -41,9 +41,9 @@
4141

4242
R = TypeVar("R")
4343
T = TypeVar("T")
44-
RESULT: CtxItem[Any] = cast(CtxItem, "$result")
45-
STACK: CtxItem[AsyncExitStack] = cast(CtxItem, "$exit_stack")
46-
SUBSCRIBER: CtxItem[Subscriber] = cast(CtxItem, "$subscriber")
44+
RESULT: CtxItem[Any] = CtxItem.make("$result")
45+
STACK: CtxItem[AsyncExitStack] = CtxItem.make("$exit_stack")
46+
SUBSCRIBER: CtxItem[Subscriber] = CtxItem.make("$subscriber")
4747
current_subscriber: ContextVar[Subscriber] = ContextVar("_current_subscriber")
4848

4949

@@ -339,7 +339,8 @@ async def handle(self, context: Contexts, inner=False):
339339
current_subscriber.reset(token) # type: ignore
340340
if not inner:
341341
if "$exit_stack" in context: # pragma: no cover
342-
await context[STACK].__aexit__(*sys.exc_info())
342+
typ, e, tb = sys.exc_info()
343+
await context[STACK].__aexit__(typ, e, tb)
343344
context.clear()
344345
if self.once:
345346
self.dispose()
@@ -480,8 +481,16 @@ def defer(func: Callable[..., Any], ctx: Contexts | TTarget | None = None):
480481
return sub.propagate(func, once=True)
481482

482483

483-
def get_params(ctx: Contexts):
484-
sub = ctx[SUBSCRIBER]
484+
def get_params(ctx: Contexts | Subscriber | None = None): # pragma: no cover
485+
if isinstance(ctx, dict):
486+
sub = ctx[SUBSCRIBER]
487+
elif isinstance(ctx, Subscriber):
488+
sub = ctx
489+
else:
490+
try:
491+
sub = current_subscriber.get()
492+
except LookupError:
493+
raise TypeError(f"Unsupported type {type(ctx)}") from None
485494
return sub.params
486495

487496

arclet/letoderea/typing.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
T1 = TypeVar("T1")
1515

1616

17-
class CtxItem(Generic[T]): ...
17+
class CtxItem(Generic[T]):
18+
@classmethod
19+
def make(cls, name: str) -> CtxItem[T]:
20+
return cast(CtxItem[T], cast(object, name))
1821

1922

2023
class Contexts(dict[str, Any]):
@@ -50,7 +53,7 @@ def get(self, __key: str | CtxItem[T1], __default: Any = ...) -> Any: ... # typ
5053
...
5154

5255

53-
EVENT: CtxItem[Any] = cast(CtxItem, "$event")
56+
EVENT: CtxItem[Any] = CtxItem.make("$event")
5457

5558

5659
async def generate_contexts(

0 commit comments

Comments
 (0)