1414
1515
1616class ConditionExtension :
17- """Scheduler extension managing Condition lock and notifications"""
17+ """Scheduler extension managing Condition lock and notifications
18+
19+ State managed:
20+ - _locks: Which client holds which condition's lock
21+ - _acquire_queue: Clients waiting to acquire lock (FIFO)
22+ - _waiters: Clients in wait() (released lock, awaiting notify)
23+ - _client_conditions: Reverse index for cleanup on disconnect
24+ """
1825
1926 def __init__ (self , scheduler ):
2027 self .scheduler = scheduler
28+
2129 # {condition_name: client_id} - who holds each lock
22- self ._lock_holders = {}
23- # {condition_name: deque of (client_id, future)} - waiting to acquire
24- self ._acquire_waiters = defaultdict (deque )
25- # {condition_name: {waiter_id: (client_id, Event)}} - waiting for notify
26- self ._notify_waiters = defaultdict (dict )
30+ self ._locks = {}
31+
32+ # {condition_name: deque[(client_id, future)]} - waiting to acquire
33+ self ._acquire_queue = defaultdict (deque )
34+
35+ # {condition_name: {waiter_id: (client_id, event, reacquire_future)}}
36+ # - clients in wait(), will need to reacquire after notify
37+ self ._waiters = defaultdict (dict )
38+
39+ # {client_id: set(condition_names)} - for cleanup on disconnect
40+ self ._client_conditions = defaultdict (set )
2741
2842 self .scheduler .handlers .update (
2943 {
@@ -35,115 +49,227 @@ def __init__(self, scheduler):
3549 }
3650 )
3751
52+ # Register cleanup on client disconnect
53+ self .scheduler .extensions ["conditions" ] = self
54+
def _track_client(self, name, client_id):
    """Record that ``client_id`` is involved with condition ``name``.

    The entry is stored in the ``_client_conditions`` reverse index, which
    maps each client id to the set of condition names it has touched.
    """
    tracked = self._client_conditions[client_id]
    tracked.add(name)
58+
def _untrack_client(self, name, client_id):
    """Forget that ``client_id`` is involved with condition ``name``.

    Unknown clients are ignored. When the client's set of condition names
    becomes empty, the client's entry is dropped from the reverse index.
    """
    names = self._client_conditions.get(client_id)
    if names is None:
        # Client was never tracked (or already cleaned up): nothing to do.
        return

    names.discard(name)
    if not names:
        del self._client_conditions[client_id]
65+
@log_errors
async def acquire(self, name=None, client_id=None):
    """Acquire the lock of condition ``name`` for ``client_id``.

    Returns ``True`` immediately when the lock is free or already held by
    the same client; otherwise the caller is appended to the FIFO acquire
    queue and suspended until ``_wake_next_acquirer`` hands it the lock.

    Raises
    ------
    asyncio.CancelledError
        Re-raised when the wait is cancelled. The caller's queue entry is
        removed, and a lock that had already been handed to this caller is
        passed on to the next queued client instead of being leaked.
    """
    self._track_client(name, client_id)

    if name not in self._locks:
        # Lock is free
        self._locks[name] = client_id
        return True

    if self._locks[name] == client_id:
        # Re-entrant acquire (from same client)
        return True

    # Lock is held - queue up and wait
    future = asyncio.Future()
    entry = (client_id, future)
    self._acquire_queue[name].append(entry)

    try:
        await future
    except asyncio.CancelledError:
        # Drop our queue entry if it is still there.
        try:
            self._acquire_queue.get(name, deque()).remove(entry)
        except ValueError:
            pass  # Already removed

        # The lock may have been granted (future resolved) before the
        # cancellation was delivered. Nobody would ever release it on
        # behalf of this cancelled caller, so hand it on right away.
        granted = future.done() and not future.cancelled()
        if granted and name in self._locks and self._locks[name] == client_id:
            del self._locks[name]
            self._wake_next_acquirer(name)
        raise

    return True
95+
def _wake_next_acquirer(self, name):
    """Hand the lock of condition ``name`` to the next queued acquirer.

    Queue entries whose future is already done are discarded. The first
    entry with a pending future becomes the lock holder and its future is
    resolved with ``True``.

    Returns ``True`` if a client was woken. Returns ``False`` when the
    queue ran empty, in which case the queue entry for ``name`` is removed.
    """
    pending = self._acquire_queue.get(name)
    if pending is None:
        pending = deque()

    while pending:
        next_client, fut = pending.popleft()
        if fut.done():
            # Stale entry (e.g. cancelled while queued) - skip it.
            continue
        self._locks[name] = next_client
        fut.set_result(True)
        return True

    # No waiters left
    self._acquire_queue.pop(name, None)
    return False
55111
@log_errors
async def release(self, name=None, client_id=None):
    """Release the lock of condition ``name`` held by ``client_id``.

    After the lock is dropped, the next queued acquirer (if any) receives
    it. The releasing client is removed from the ``_client_conditions``
    reverse index unless it is itself registered as a waiter on this
    condition (as is the case when ``wait()`` releases the lock).

    Raises
    ------
    RuntimeError
        If the lock is not held at all, or is held by a different client.
    """
    if name not in self._locks:
        raise RuntimeError("Released too often")

    if self._locks[name] != client_id:
        raise RuntimeError("Cannot release lock held by another client")

    del self._locks[name]

    # Wake next waiter trying to acquire
    self._wake_next_acquirer(name)

    # Decide tracking per client rather than per condition: other clients
    # queued or waiting on this condition are no reason to keep the
    # releasing client's reverse-index entry alive.
    registered = self._waiters.get(name, {})
    still_waiting = any(entry[0] == client_id for entry in registered.values())
    if not still_waiting:
        self._untrack_client(name, client_id)
75128
@log_errors
async def wait(self, name=None, waiter_id=None, client_id=None, timeout=None):
    """Release the lock, wait for a notification, then reacquire the lock.

    The waiter is registered *before* the lock is released so that a
    notification sent right after the release cannot be lost.

    Parameters
    ----------
    name
        Condition name.
    waiter_id
        Key under which this waiter is registered for the condition.
    client_id
        Client that must currently hold the lock.
    timeout
        Passed to ``wait_for`` when not ``None``.

    Returns
    -------
    bool
        ``True`` if notified, ``False`` on timeout. Also ``False`` when
        the waiter's reacquire future was cancelled while waiting (which is
        what ``remove_client`` does); the lock is not reacquired then.

    Raises
    ------
    RuntimeError
        If ``client_id`` does not hold the lock.
    asyncio.CancelledError
        Re-raised on cancellation; the waiter registration is removed and
        the lock is not reacquired.
    """
    # Verify caller holds lock
    if self._locks.get(name) != client_id:
        raise RuntimeError("wait() called without holding the lock")

    # 1. Register for notification FIRST (prevents lost wakeup)
    notify_event = asyncio.Event()
    reacquire_future = asyncio.Future()
    self._waiters[name][waiter_id] = (client_id, notify_event, reacquire_future)

    try:
        # 2. Release lock (allows notifier to proceed)
        await self.release(name=name, client_id=client_id)

        # 3. Wait for notification
        wait_future = notify_event.wait()
        if timeout is not None:
            wait_future = wait_for(wait_future, timeout)

        try:
            await wait_future
            notified = True
        except (TimeoutError, asyncio.TimeoutError):
            # Distinct classes before Python 3.11, aliases afterwards.
            notified = False
    finally:
        # Single cleanup path for success, timeout, error and cancellation.
        registered = self._waiters.get(name)
        if registered is not None:
            registered.pop(waiter_id, None)
            if not registered:
                del self._waiters[name]

    # A cancelled reacquire future means the waiter was torn down while it
    # was waiting; taking the lock now would leave it held with nobody
    # left to release it.
    if reacquire_future.cancelled():
        return False

    # 4. Reacquire lock before returning
    # This might block if other clients are waiting
    await self.acquire(name=name, client_id=client_id)

    return notified
110175
@log_errors
def notify(self, name=None, client_id=None, n=1):
    """Wake up to ``n`` waiters of condition ``name``.

    Waiters are visited in registration order. A waiter whose event is
    already set (notified earlier but not yet deregistered) is skipped, so
    successive calls wake distinct waiters and the returned count reflects
    only newly woken ones. ``n <= 0`` wakes nobody.

    Returns
    -------
    int
        Number of waiters woken by this call.

    Raises
    ------
    RuntimeError
        If ``client_id`` does not hold the lock.
    """
    # Verify caller holds lock
    if self._locks.get(name) != client_id:
        raise RuntimeError("notify() called without holding the lock")

    woken = 0
    for _, event, _ in self._waiters.get(name, {}).values():
        if woken >= n:
            break
        if event.is_set():
            # Already notified; it just has not run its cleanup yet.
            continue
        event.set()
        woken += 1

    return woken
124192
@log_errors
def notify_all(self, name=None, client_id=None):
    """Set the notification event of every waiter registered for ``name``.

    Returns the number of registered waiters.

    Raises
    ------
    RuntimeError
        If ``client_id`` does not hold the lock.
    """
    holder = self._locks.get(name)
    if holder != client_id:
        raise RuntimeError("notify_all() called without holding the lock")

    registered = self._waiters.get(name, {})
    for _, wake_event, _ in registered.values():
        wake_event.set()

    return len(registered)
136206
async def remove_client(self, client):
    """Clean up all condition state belonging to ``client``.

    For every condition recorded for the client in ``_client_conditions``:

    1. its entries are removed from the acquire queue (pending futures
       are cancelled),
    2. its waiter registrations are removed; each waiter's reacquire
       future is cancelled and its event is set so the waiter wakes up,
    3. the lock is released if the client holds it.

    Queue entries are purged before the release so the lock can only be
    handed to other clients. Containers left empty are deleted so that
    membership tests on ``_waiters`` / ``_acquire_queue`` stay meaningful.
    """
    conditions = self._client_conditions.pop(client, set())

    for name in conditions:
        # 1. Drop this client's acquire-queue entries in one pass,
        #    mutating the deque in place.
        queue = self._acquire_queue.get(name)
        if queue is not None:
            kept = []
            for cid, future in queue:
                if cid != client:
                    kept.append((cid, future))
                elif not future.done():
                    future.cancel()
            queue.clear()
            queue.extend(kept)
            if not queue:
                del self._acquire_queue[name]

        # 2. Drop this client's notify waiters.
        waiters = self._waiters.get(name)
        if waiters is not None:
            owned = [wid for wid, entry in waiters.items() if entry[0] == client]
            for wid in owned:
                _, event, reacq = waiters.pop(wid)
                if not reacq.done():
                    reacq.cancel()
                event.set()  # Wake them up so they can cleanup
            if not waiters:
                del self._waiters[name]

        # 3. Release any lock held by this client
        if self._locks.get(name) == client:
            try:
                await self.release(name=name, client_id=client)
            except Exception as e:
                logger.warning(f"Error releasing lock for {name}: {e}")
243+
137244
138245class Condition (SyncMethodMixin ):
139- """Distributed Condition Variable"""
246+ """Distributed Condition Variable
247+
248+ Provides wait/notify synchronization across distributed clients.
249+ Multiple Condition instances with the same name share state.
250+
251+ Examples
252+ --------
253+ >>> condition = Condition('data-ready')
254+ >>>
255+ >>> # Consumer
256+ >>> async with condition:
257+ ... while not data_available():
258+ ... await condition.wait()
259+ ... process_data()
260+ >>>
261+ >>> # Producer
262+ >>> async with condition:
263+ ... produce_data()
264+ ... condition.notify_all()
265+ """
140266
def __init__(self, name=None, client=None):
    """Create a handle for the condition called ``name``.

    When ``name`` is falsy a unique ``condition-<hex>`` name is generated.
    Each instance gets its own random waiter id and client id, and starts
    out not holding the lock.
    """
    self._client = client
    self._is_locked = False
    self._client_id = uuid.uuid4().hex
    self._waiter_id = uuid.uuid4().hex
    self.name = name if name else f"condition-{uuid.uuid4().hex}"
147273
148274 @property
149275 def client (self ):
@@ -160,44 +286,64 @@ def loop(self):
160286
def _verify_running(self):
    """Raise ``RuntimeError`` unless ``self.client`` is truthy."""
    if self.client:
        return

    raise RuntimeError(
        f"{type(self)} object not properly initialized. "
        "This can happen if the object is being deserialized "
        "outside of the context of a Client or Worker."
    )
164294
async def acquire(self):
    """Acquire the lock.

    Sends ``condition_acquire`` to the scheduler with this instance's
    condition name and client id, then marks the instance as locked.
    """
    self._verify_running()
    scheduler = self.client.scheduler
    await scheduler.condition_acquire(name=self.name, client_id=self._client_id)
    self._is_locked = True
172302
async def release(self):
    """Release the lock.

    Sends ``condition_release`` to the scheduler with this instance's
    condition name and client id, then marks the instance as unlocked.
    """
    self._verify_running()
    scheduler = self.client.scheduler
    await scheduler.condition_release(name=self.name, client_id=self._client_id)
    self._is_locked = False
180310
async def wait(self, timeout=None):
    """Wait for a notification.

    Must be called while this instance holds the lock. The release, wait
    and reacquire steps are delegated to the scheduler's
    ``condition_wait`` handler in a single call.

    Parameters
    ----------
    timeout : optional
        Maximum time to wait; converted with ``parse_timedelta`` before
        being sent to the scheduler.

    Returns
    -------
    The value returned by the scheduler's ``condition_wait`` handler.

    Raises
    ------
    RuntimeError
        If this instance does not hold the lock.
    """
    if not self._is_locked:
        raise RuntimeError("wait() called without holding the lock")

    self._verify_running()
    seconds = parse_timedelta(timeout)

    # The local _is_locked flag is left untouched across the call.
    return await self.client.scheduler.condition_wait(
        name=self.name,
        waiter_id=self._waiter_id,
        client_id=self._client_id,
        timeout=seconds,
    )
198344
199345 def notify (self , n = 1 ):
200- """Wake up n waiters"""
346+ """Wake up n waiters (default 1) """
201347 if not self ._is_locked :
202348 raise RuntimeError ("notify() called without holding the lock" )
203349 self ._verify_running ()
@@ -220,7 +366,7 @@ def notify_all(self):
220366 )
221367
def locked(self):
    """Report whether this instance currently holds the lock.

    This reflects the instance's local ``_is_locked`` flag only.
    """
    held = self._is_locked
    return held
225371
226372 async def __aenter__ (self ):
0 commit comments