Skip to content

Commit 06c6569

Browse files
committed
🍻 simplify logic
1 parent 3db96b7 commit 06c6569

4 files changed

Lines changed: 41 additions & 58 deletions

File tree

arclet/letoderea/core.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import defaultdict
77
from dataclasses import dataclass
88
from operator import attrgetter
9-
from itertools import chain, groupby
9+
from itertools import chain
1010
from types import AsyncGeneratorType
1111
from typing import Any, TypeVar, overload, cast
1212
from collections.abc import Callable, Coroutine
@@ -46,10 +46,10 @@ async def gather(self, context: Contexts):
4646
exc_pub = define(ExceptionEvent, name="internal/exception")
4747

4848

49-
async def publish_exc_event(event: ExceptionEvent):
49+
def publish_exc_event(event: ExceptionEvent):
5050
scopes = [sp for sp in _scopes.values() if sp.available]
5151
subs = [slot for sp in scopes for slot in sp.subscribers if slot.publisher_id != "$backend"]
52-
await dispatch(event, slots=subs)
52+
return add_task(dispatch(event, slots=subs))
5353

5454

5555
async def compute(event: Any, scope: str | Scope | None = None, slots: Iterable[SubscriberSlot] | None = None, inherit_ctx: Contexts | None = None):
@@ -96,7 +96,7 @@ async def dispatch(event: Any, scope: str | Scope | None = None, slots: Iterable
9696
return
9797
if isinstance(event, ExceptionEvent): # pragma: no cover
9898
return
99-
await publish_exc_event(ExceptionEvent(event, subs[_i], result))
99+
publish_exc_event(ExceptionEvent(event, subs[_i], result))
100100
if isinstance(result, AsyncGeneratorType): # pragma: no cover
101101
async for res in result:
102102
if result is None or result is STOP:
@@ -135,9 +135,9 @@ async def serial_exec_concurrent(subs: list[Subscriber], ctx: Contexts):
135135

136136
async def serial(event: Any, scope: str | Scope | None = None, slots: Iterable[SubscriberSlot] | None = None, inherit_ctx: Contexts | None = None):
137137
grouped, context_map = await compute(event, scope, slots, inherit_ctx)
138-
for (priority, pub_id) in grouped:
139-
contexts = context_map[pub_id]
140-
gene = serial_exec_concurrent(grouped[(priority, pub_id)], contexts)
138+
for key, subs in grouped.items():
139+
contexts = context_map[key[1]]
140+
gene = serial_exec_concurrent(subs, contexts)
141141
async for subscriber, result in gene:
142142
if result is None or result is STOP:
143143
continue
@@ -148,7 +148,7 @@ async def serial(event: Any, scope: str | Scope | None = None, slots: Iterable[S
148148
return result.args[0]
149149
if isinstance(event, ExceptionEvent):
150150
return
151-
await publish_exc_event(ExceptionEvent(event, subscriber, result))
151+
publish_exc_event(ExceptionEvent(event, subscriber, result))
152152
elif isinstance(result, AsyncGeneratorType): # pragma: no cover
153153
async for res in result:
154154
if res is None or res is STOP:
@@ -164,9 +164,9 @@ async def serial(event: Any, scope: str | Scope | None = None, slots: Iterable[S
164164

165165
async def broadcast(event: Any, scope: str | Scope | None = None, slots: Iterable[SubscriberSlot] | None = None, inherit_ctx: Contexts | None = None, concurrent: bool = False): # pragma: no cover
166166
grouped, context_map = await compute(event, scope, slots, inherit_ctx)
167-
for (priority, pub_id) in grouped.keys():
168-
contexts = context_map[pub_id]
169-
async for subscriber, result in (serial_exec_concurrent(grouped[(priority, pub_id)], contexts) if concurrent else serial_exec(grouped[(priority, pub_id)], contexts)):
167+
for key, subs in grouped.items():
168+
contexts = context_map[key[1]]
169+
async for subscriber, result in (serial_exec_concurrent(subs, contexts) if concurrent else serial_exec(subs, contexts)):
170170
if result is None or result is STOP:
171171
continue
172172
if result is BLOCK:
@@ -179,7 +179,7 @@ async def broadcast(event: Any, scope: str | Scope | None = None, slots: Iterabl
179179
continue
180180
if isinstance(event, ExceptionEvent):
181181
return
182-
await publish_exc_event(ExceptionEvent(event, subscriber, result))
182+
publish_exc_event(ExceptionEvent(event, subscriber, result))
183183
elif isinstance(result, AsyncGeneratorType):
184184
async for res in result:
185185
if res is None or res is STOP:
@@ -200,7 +200,7 @@ async def _post(event: Any, scope: str | Scope | None = None, inherit_ctx: Conte
200200
res = await serial(event, scope, inherit_ctx=inherit_ctx)
201201
if res is None:
202202
return
203-
if res.__class__ is Force:
203+
if res.__class__ is Force: # pragma: no cover
204204
res = res.value
205205
if validate and hasattr(event, "check_result"):
206206
return cast(Resultable, event).check_result(res.value if isinstance(res, Result) else res)

arclet/letoderea/scope.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ scope_ctx: ContextModel[Scope]
2121
global_propagators: list[Propagator]
2222

2323

24-
@dataclass
24+
@dataclass(slots=True, frozen=True)
2525
class SubscriberSlot:
2626
subscriber: Subscriber
2727
publisher_id: str

arclet/letoderea/subscriber.py

Lines changed: 24 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,15 @@ async def __call__(self, context: Contexts):
7878
cache[self.target] = fut = asyncio.Future()
7979
try:
8080
res = await self.sub.handle(context.copy(), inner=True)
81-
except Exception as e: # pragma: no cover
82-
fut.set_exception(e)
83-
fut.cancel()
84-
raise
8581
except BaseException as e: # pragma: no cover
8682
fut.set_exception(e)
8783
fut.cancel()
88-
cache.pop(self.target, None)
84+
if isinstance(e, Exception):
85+
cache.pop(self.target, None)
8986
raise
9087
else:
88+
if isinstance(res, _ExitException):
89+
raise
9190
fut.set_result(res)
9291
return res
9392

@@ -115,6 +114,8 @@ class CompileParam:
115114
__slots__ = ("name", "annotation", "default", "providers", "depend", "record")
116115

117116
async def solve(self, context: Contexts | dict[str, Any]):
117+
if self.depend:
118+
return await self.depend(context) # type: ignore
118119
if self.name in context:
119120
return context[self.name]
120121
if self.record and (res := await self.record(context)) is not None: # type: ignore
@@ -219,6 +220,7 @@ def __init__(self, callable_target: Callable[..., R], *, priority: int = 16, pro
219220
self._propagates: list[Subscriber] = []
220221
self._propagator_cache: WeakSet[Propagator] = WeakSet()
221222
self._cursor = 0
223+
self._after_propagates = 0
222224
self._listen = _listen
223225

224226
if hasattr(callable_target, "__providers__"):
@@ -304,31 +306,21 @@ async def handle(self, context: Contexts, inner=False):
304306
context[STACK] = AsyncExitStack()
305307
try:
306308
if self._cursor:
307-
_res = await self._run_propagate(context, self._propagates[: self._cursor])
308-
if _res is STOP or _res is BLOCK:
309-
return _res
310-
arguments: Contexts = {} # type: ignore
309+
await self._run_propagate(context, self._propagates[: self._cursor])
310+
arguments = {} # type: ignore
311311
for param in self.params:
312-
if param.depend:
313-
dep_res = await param.depend(context)
314-
if dep_res is STOP or dep_res is BLOCK:
315-
return dep_res
316-
arguments[param.name] = dep_res
317-
else:
318-
arguments[param.name] = await param.solve(context)
312+
arguments[param.name] = await param.solve(context)
319313
if self.is_cm:
320314
stack: AsyncExitStack = context[STACK]
321315
result = await stack.enter_async_context(self._callable_target(**arguments))
322316
elif self.is_agen:
323317
result = self._callable_target(**arguments)
324318
else:
325319
result = await self._callable_target(**arguments)
326-
if self._propagates:
320+
if self._after_propagates:
327321
context[RESULT] = result
328322
propagate_result = await self._run_propagate(context, self._propagates[self._cursor :])
329323
result = result if propagate_result is None else propagate_result
330-
if result is STOP or result is BLOCK:
331-
return result
332324
except InnerHandlerException as e:
333325
if inner:
334326
raise
@@ -369,9 +361,7 @@ async def _run_propagate(self, context: Contexts, propagates: list[Subscriber]):
369361
else:
370362
raise
371363
else:
372-
if isinstance(result, ExitState):
373-
return result
374-
if isinstance(result, _ExitException): # pragma: no cover
364+
if isinstance(result, _ExitException):
375365
raise result
376366
if isinstance(result, dict):
377367
context.update(result)
@@ -384,7 +374,7 @@ async def _run_propagate(self, context: Contexts, propagates: list[Subscriber]):
384374
await self._run_propagate(context, [x[0] for x in pending.pop(key)])
385375
if pending:
386376
key, (slot, *_) = pending.popitem()
387-
raise ExceptionHandler.call(slot[1], slot[0].callable_target,context, inner=True)
377+
raise ExceptionHandler.call(slot[1], slot[0].callable_target, context, inner=True)
388378
return context.get(RESULT)
389379

390380
@overload
@@ -419,36 +409,27 @@ def propagate(self, func: TTarget[Any] | Propagator | None = None, *, prepend: b
419409
self._propagator_cache.add(func)
420410
return lambda: ([dispose() for dispose in disposes] or self._propagator_cache.discard(func))
421411

422-
def _dispose(x: Subscriber):
423-
self._propagates.remove(x)
424-
self._cursor -= 1
425-
426412
def wrapper(callable_target: TTarget[Any], /):
427413
if isinstance(callable_target, Subscriber):
428414
raise ValueError("Subscriber can't be propagated")
429415
_providers = [*(providers or []), *self.providers]
430416
if prepend:
431-
sub = Subscriber(
432-
callable_target,
433-
priority=priority,
434-
providers=_providers,
435-
dispose=_dispose,
436-
once=once,
437-
_listen=self._listen,
438-
)
417+
def _dispose(x: Subscriber):
418+
self._propagates.remove(x)
419+
self._cursor -= 1
420+
421+
sub = Subscriber(callable_target, priority=priority, providers=_providers, dispose=_dispose, once=once, _listen=self._listen)
439422
self._propagates.insert(self._cursor, sub)
440423
self._cursor += 1
441424
else:
425+
def _dispose(x: Subscriber):
426+
self._propagates.remove(x)
427+
self._after_propagates -= 1
428+
442429
_providers.append(ResultProvider())
443-
sub = Subscriber(
444-
callable_target,
445-
priority=priority,
446-
providers=_providers,
447-
dispose=lambda x: self._propagates.remove(x),
448-
once=once,
449-
_listen=self._listen,
450-
)
430+
sub = Subscriber(callable_target, priority=priority, providers=_providers, dispose=_dispose, once=once, _listen=self._listen)
451431
self._propagates.append(sub)
432+
self._after_propagates += 1
452433
return sub.dispose
453434

454435
if func:

tests/test_exception.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
import pytest
24

35
import arclet.letoderea as le
@@ -26,5 +28,5 @@ async def _(event: le.ExceptionEvent, origin, exc: Exception, subscriber):
2628
executed.append(1)
2729

2830
await le.publish(TestExcEvent("1"))
29-
31+
await asyncio.sleep(0)
3032
assert len(executed) == 3

0 commit comments

Comments
 (0)