@@ -78,16 +78,15 @@ async def __call__(self, context: Contexts):
7878 cache [self .target ] = fut = asyncio .Future ()
7979 try :
8080 res = await self .sub .handle (context .copy (), inner = True )
81- except Exception as e : # pragma: no cover
82- fut .set_exception (e )
83- fut .cancel ()
84- raise
8581 except BaseException as e : # pragma: no cover
8682 fut .set_exception (e )
8783 fut .cancel ()
88- cache .pop (self .target , None )
84+ if isinstance (e , Exception ):
85+ cache .pop (self .target , None )
8986 raise
9087 else :
88+ if isinstance (res , _ExitException ):
89+ raise
9190 fut .set_result (res )
9291 return res
9392
@@ -115,6 +114,8 @@ class CompileParam:
115114 __slots__ = ("name" , "annotation" , "default" , "providers" , "depend" , "record" )
116115
117116 async def solve (self , context : Contexts | dict [str , Any ]):
117+ if self .depend :
118+ return await self .depend (context ) # type: ignore
118119 if self .name in context :
119120 return context [self .name ]
120121 if self .record and (res := await self .record (context )) is not None : # type: ignore
@@ -219,6 +220,7 @@ def __init__(self, callable_target: Callable[..., R], *, priority: int = 16, pro
219220 self ._propagates : list [Subscriber ] = []
220221 self ._propagator_cache : WeakSet [Propagator ] = WeakSet ()
221222 self ._cursor = 0
223+ self ._after_propagates = 0
222224 self ._listen = _listen
223225
224226 if hasattr (callable_target , "__providers__" ):
@@ -304,31 +306,21 @@ async def handle(self, context: Contexts, inner=False):
304306 context [STACK ] = AsyncExitStack ()
305307 try :
306308 if self ._cursor :
307- _res = await self ._run_propagate (context , self ._propagates [: self ._cursor ])
308- if _res is STOP or _res is BLOCK :
309- return _res
310- arguments : Contexts = {} # type: ignore
309+ await self ._run_propagate (context , self ._propagates [: self ._cursor ])
310+ arguments = {} # type: ignore
311311 for param in self .params :
312- if param .depend :
313- dep_res = await param .depend (context )
314- if dep_res is STOP or dep_res is BLOCK :
315- return dep_res
316- arguments [param .name ] = dep_res
317- else :
318- arguments [param .name ] = await param .solve (context )
312+ arguments [param .name ] = await param .solve (context )
319313 if self .is_cm :
320314 stack : AsyncExitStack = context [STACK ]
321315 result = await stack .enter_async_context (self ._callable_target (** arguments ))
322316 elif self .is_agen :
323317 result = self ._callable_target (** arguments )
324318 else :
325319 result = await self ._callable_target (** arguments )
326- if self ._propagates :
320+ if self ._after_propagates :
327321 context [RESULT ] = result
328322 propagate_result = await self ._run_propagate (context , self ._propagates [self ._cursor :])
329323 result = result if propagate_result is None else propagate_result
330- if result is STOP or result is BLOCK :
331- return result
332324 except InnerHandlerException as e :
333325 if inner :
334326 raise
@@ -369,9 +361,7 @@ async def _run_propagate(self, context: Contexts, propagates: list[Subscriber]):
369361 else :
370362 raise
371363 else :
372- if isinstance (result , ExitState ):
373- return result
374- if isinstance (result , _ExitException ): # pragma: no cover
364+ if isinstance (result , _ExitException ):
375365 raise result
376366 if isinstance (result , dict ):
377367 context .update (result )
@@ -384,7 +374,7 @@ async def _run_propagate(self, context: Contexts, propagates: list[Subscriber]):
384374 await self ._run_propagate (context , [x [0 ] for x in pending .pop (key )])
385375 if pending :
386376 key , (slot , * _ ) = pending .popitem ()
387- raise ExceptionHandler .call (slot [1 ], slot [0 ].callable_target ,context , inner = True )
377+ raise ExceptionHandler .call (slot [1 ], slot [0 ].callable_target , context , inner = True )
388378 return context .get (RESULT )
389379
390380 @overload
@@ -419,36 +409,27 @@ def propagate(self, func: TTarget[Any] | Propagator | None = None, *, prepend: b
419409 self ._propagator_cache .add (func )
420410 return lambda : ([dispose () for dispose in disposes ] or self ._propagator_cache .discard (func ))
421411
422- def _dispose (x : Subscriber ):
423- self ._propagates .remove (x )
424- self ._cursor -= 1
425-
426412 def wrapper (callable_target : TTarget [Any ], / ):
427413 if isinstance (callable_target , Subscriber ):
428414 raise ValueError ("Subscriber can't be propagated" )
429415 _providers = [* (providers or []), * self .providers ]
430416 if prepend :
431- sub = Subscriber (
432- callable_target ,
433- priority = priority ,
434- providers = _providers ,
435- dispose = _dispose ,
436- once = once ,
437- _listen = self ._listen ,
438- )
417+ def _dispose (x : Subscriber ):
418+ self ._propagates .remove (x )
419+ self ._cursor -= 1
420+
421+ sub = Subscriber (callable_target , priority = priority , providers = _providers , dispose = _dispose , once = once , _listen = self ._listen )
439422 self ._propagates .insert (self ._cursor , sub )
440423 self ._cursor += 1
441424 else :
425+ def _dispose (x : Subscriber ):
426+ self ._propagates .remove (x )
427+ self ._after_propagates -= 1
428+
442429 _providers .append (ResultProvider ())
443- sub = Subscriber (
444- callable_target ,
445- priority = priority ,
446- providers = _providers ,
447- dispose = lambda x : self ._propagates .remove (x ),
448- once = once ,
449- _listen = self ._listen ,
450- )
430+ sub = Subscriber (callable_target , priority = priority , providers = _providers , dispose = _dispose , once = once , _listen = self ._listen )
451431 self ._propagates .append (sub )
432+ self ._after_propagates += 1
452433 return sub .dispose
453434
454435 if func :
0 commit comments