Skip to content

Commit 470135a

Browse files
committed
Update condition.py,test_condition.py
1 parent 678f8b3 commit 470135a

File tree

2 files changed

+207
-80
lines changed

2 files changed

+207
-80
lines changed

distributed/condition.py

Lines changed: 207 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,30 @@
1414

1515

1616
class ConditionExtension:
17-
"""Scheduler extension managing Condition lock and notifications"""
17+
"""Scheduler extension managing Condition lock and notifications
18+
19+
State managed:
20+
- _locks: Which client holds which condition's lock
21+
- _acquire_queue: Clients waiting to acquire lock (FIFO)
22+
- _waiters: Clients in wait() (released lock, awaiting notify)
23+
- _client_conditions: Reverse index for cleanup on disconnect
24+
"""
1825

1926
def __init__(self, scheduler):
2027
self.scheduler = scheduler
28+
2129
# {condition_name: client_id} - who holds each lock
22-
self._lock_holders = {}
23-
# {condition_name: deque of (client_id, future)} - waiting to acquire
24-
self._acquire_waiters = defaultdict(deque)
25-
# {condition_name: {waiter_id: (client_id, Event)}} - waiting for notify
26-
self._notify_waiters = defaultdict(dict)
30+
self._locks = {}
31+
32+
# {condition_name: deque[(client_id, future)]} - waiting to acquire
33+
self._acquire_queue = defaultdict(deque)
34+
35+
# {condition_name: {waiter_id: (client_id, event, reacquire_future)}}
36+
# - clients in wait(), will need to reacquire after notify
37+
self._waiters = defaultdict(dict)
38+
39+
# {client_id: set(condition_names)} - for cleanup on disconnect
40+
self._client_conditions = defaultdict(set)
2741

2842
self.scheduler.handlers.update(
2943
{
@@ -35,115 +49,227 @@ def __init__(self, scheduler):
3549
}
3650
)
3751

52+
# Register cleanup on client disconnect
53+
self.scheduler.extensions["conditions"] = self
54+
55+
def _track_client(self, name, client_id):
56+
"""Track that a client is using this condition"""
57+
self._client_conditions[client_id].add(name)
58+
59+
def _untrack_client(self, name, client_id):
60+
"""Stop tracking client for this condition"""
61+
if client_id in self._client_conditions:
62+
self._client_conditions[client_id].discard(name)
63+
if not self._client_conditions[client_id]:
64+
del self._client_conditions[client_id]
65+
3866
@log_errors
3967
async def acquire(self, name=None, client_id=None):
4068
"""Acquire lock - blocks until available"""
41-
if name not in self._lock_holders:
69+
self._track_client(name, client_id)
70+
71+
if name not in self._locks:
4272
# Lock is free
43-
self._lock_holders[name] = client_id
73+
self._locks[name] = client_id
4474
return True
4575

46-
if self._lock_holders[name] == client_id:
47-
# Already hold it (shouldn't happen in normal use)
76+
if self._locks[name] == client_id:
77+
# Re-entrant acquire (from same client)
4878
return True
4979

50-
# Lock is held by someone else - wait our turn
80+
# Lock is held - queue up and wait
5181
future = asyncio.Future()
52-
self._acquire_waiters[name].append((client_id, future))
53-
await future
54-
return True
82+
self._acquire_queue[name].append((client_id, future))
83+
84+
try:
85+
await future
86+
return True
87+
except asyncio.CancelledError:
88+
# Remove from queue if cancelled
89+
queue = self._acquire_queue.get(name, deque())
90+
try:
91+
queue.remove((client_id, future))
92+
except ValueError:
93+
pass # Already removed
94+
raise
95+
96+
def _wake_next_acquirer(self, name):
97+
"""Wake the next client waiting to acquire this lock"""
98+
queue = self._acquire_queue.get(name, deque())
99+
100+
while queue:
101+
client_id, future = queue.popleft()
102+
if not future.done():
103+
self._locks[name] = client_id
104+
future.set_result(True)
105+
return True
106+
107+
# No waiters left
108+
if name in self._acquire_queue:
109+
del self._acquire_queue[name]
110+
return False
55111

56112
@log_errors
57113
async def release(self, name=None, client_id=None):
58114
"""Release lock"""
59-
if name not in self._lock_holders:
115+
if name not in self._locks:
60116
raise RuntimeError("Released too often")
61117

62-
if self._lock_holders[name] != client_id:
118+
if self._locks[name] != client_id:
63119
raise RuntimeError("Cannot release lock held by another client")
64120

65-
del self._lock_holders[name]
121+
del self._locks[name]
66122

67-
# Wake next waiter if any
68-
waiters = self._acquire_waiters.get(name, deque())
69-
while waiters:
70-
next_client_id, future = waiters.popleft()
71-
if not future.done():
72-
self._lock_holders[name] = next_client_id
73-
future.set_result(True)
74-
break
123+
# Wake next waiter trying to acquire
124+
if not self._wake_next_acquirer(name):
125+
# No acquire waiters - cleanup if no notify waiters either
126+
if name not in self._waiters:
127+
self._untrack_client(name, client_id)
75128

76129
@log_errors
77130
async def wait(self, name=None, waiter_id=None, client_id=None, timeout=None):
78-
"""Release lock, wait for notify, reacquire lock"""
131+
"""Release lock, wait for notify, reacquire lock
132+
133+
Critical: Register for notify BEFORE releasing lock to prevent lost wakeup
134+
"""
79135
# Verify caller holds lock
80-
if self._lock_holders.get(name) != client_id:
136+
if self._locks.get(name) != client_id:
81137
raise RuntimeError("wait() called without holding the lock")
82138

83-
# Release lock (waking next acquire waiter if any)
84-
await self.release(name=name, client_id=client_id)
139+
# 1. Register for notification FIRST (prevents lost wakeup)
140+
notify_event = asyncio.Event()
141+
reacquire_future = asyncio.Future()
142+
self._waiters[name][waiter_id] = (client_id, notify_event, reacquire_future)
85143

86-
# Register as notify waiter
87-
event = asyncio.Event()
88-
self._notify_waiters[name][waiter_id] = (client_id, event)
144+
# 2. Release lock (allows notifier to proceed)
145+
await self.release(name=name, client_id=client_id)
89146

90-
# Wait for notification
91-
future = event.wait()
147+
# 3. Wait for notification
148+
wait_future = notify_event.wait()
92149
if timeout is not None:
93-
future = wait_for(future, timeout)
150+
wait_future = wait_for(wait_future, timeout)
94151

152+
notified = False
95153
try:
96-
await future
97-
result = True
154+
await wait_future
155+
notified = True
98156
except TimeoutError:
99-
result = False
157+
notified = False
158+
except asyncio.CancelledError:
159+
# Cancelled - cleanup and don't reacquire
160+
self._waiters[name].pop(waiter_id, None)
161+
if not self._waiters[name]:
162+
del self._waiters[name]
163+
raise
100164
finally:
101-
# Cleanup waiter
102-
self._notify_waiters[name].pop(waiter_id, None)
103-
if not self._notify_waiters[name]:
104-
del self._notify_waiters[name]
165+
# Cleanup waiter registration
166+
self._waiters[name].pop(waiter_id, None)
167+
if not self._waiters[name]:
168+
del self._waiters[name]
105169

106-
# Reacquire lock - blocks until available
107-
await self.acquire(name=name, client_id=client_id)
170+
# 4. Reacquire lock before returning
171+
# This might block if other clients are waiting
172+
await self.acquire(name=name, client_id=client_id)
108173

109-
return result
174+
return notified
110175

111176
@log_errors
112177
def notify(self, name=None, client_id=None, n=1):
113178
"""Wake up n waiters"""
114179
# Verify caller holds lock
115-
if self._lock_holders.get(name) != client_id:
180+
if self._locks.get(name) != client_id:
116181
raise RuntimeError("notify() called without holding the lock")
117182

118-
waiters = self._notify_waiters.get(name, {})
183+
waiters = self._waiters.get(name, {})
119184
count = 0
120-
for _, (_, event) in list(waiters.items())[:n]:
185+
186+
for waiter_id in list(waiters.keys())[:n]:
187+
_, event, _ = waiters[waiter_id]
121188
event.set()
122189
count += 1
190+
123191
return count
124192

125193
@log_errors
126194
def notify_all(self, name=None, client_id=None):
127195
"""Wake up all waiters"""
128196
# Verify caller holds lock
129-
if self._lock_holders.get(name) != client_id:
197+
if self._locks.get(name) != client_id:
130198
raise RuntimeError("notify_all() called without holding the lock")
131199

132-
waiters = self._notify_waiters.get(name, {})
133-
for _, event in waiters.values():
200+
waiters = self._waiters.get(name, {})
201+
202+
for _, event, _ in waiters.values():
134203
event.set()
204+
135205
return len(waiters)
136206

207+
async def remove_client(self, client):
208+
"""Cleanup when client disconnects"""
209+
conditions = self._client_conditions.pop(client, set())
210+
211+
for name in conditions:
212+
# Release any locks held by this client
213+
if self._locks.get(name) == client:
214+
try:
215+
await self.release(name=name, client_id=client)
216+
except Exception as e:
217+
logger.warning(f"Error releasing lock for {name}: {e}")
218+
219+
# Cancel acquire waiters from this client
220+
queue = self._acquire_queue.get(name, deque())
221+
to_remove = []
222+
for i, (cid, future) in enumerate(queue):
223+
if cid == client and not future.done():
224+
future.cancel()
225+
to_remove.append(i)
226+
for i in reversed(to_remove):
227+
try:
228+
del queue[i]
229+
except IndexError:
230+
pass
231+
232+
# Cancel notify waiters from this client
233+
waiters = self._waiters.get(name, {})
234+
to_remove = []
235+
for waiter_id, (cid, event, reacq) in waiters.items():
236+
if cid == client:
237+
event.set() # Wake them up so they can cleanup
238+
if not reacq.done():
239+
reacq.cancel()
240+
to_remove.append(waiter_id)
241+
for wid in to_remove:
242+
waiters.pop(wid, None)
243+
137244

138245
class Condition(SyncMethodMixin):
139-
"""Distributed Condition Variable"""
246+
"""Distributed Condition Variable
247+
248+
Provides wait/notify synchronization across distributed clients.
249+
Multiple Condition instances with the same name share state.
250+
251+
Examples
252+
--------
253+
>>> condition = Condition('data-ready')
254+
>>>
255+
>>> # Consumer
256+
>>> async with condition:
257+
... while not data_available():
258+
... await condition.wait()
259+
... process_data()
260+
>>>
261+
>>> # Producer
262+
>>> async with condition:
263+
... produce_data()
264+
... condition.notify_all()
265+
"""
140266

141267
def __init__(self, name=None, client=None):
142268
self.name = name or f"condition-{uuid.uuid4().hex}"
143269
self._waiter_id = uuid.uuid4().hex
144270
self._client_id = uuid.uuid4().hex
145271
self._client = client
146-
self._is_locked = False # Track local state
272+
self._is_locked = False
147273

148274
@property
149275
def client(self):
@@ -160,44 +286,64 @@ def loop(self):
160286

161287
def _verify_running(self):
162288
if not self.client:
163-
raise RuntimeError(f"{type(self)} object not properly initialized")
289+
raise RuntimeError(
290+
f"{type(self)} object not properly initialized. "
291+
"This can happen if the object is being deserialized "
292+
"outside of the context of a Client or Worker."
293+
)
164294

165295
async def acquire(self):
166-
"""Acquire lock"""
296+
"""Acquire the lock"""
167297
self._verify_running()
168298
await self.client.scheduler.condition_acquire(
169299
name=self.name, client_id=self._client_id
170300
)
171301
self._is_locked = True
172302

173303
async def release(self):
174-
"""Release lock"""
304+
"""Release the lock"""
175305
self._verify_running()
176306
await self.client.scheduler.condition_release(
177307
name=self.name, client_id=self._client_id
178308
)
179309
self._is_locked = False
180310

181311
async def wait(self, timeout=None):
182-
"""Wait for notification - atomically releases and reacquires lock"""
312+
"""Wait for notification
313+
314+
Must be called while holding the lock. Atomically releases lock,
315+
waits for notify(), then reacquires lock before returning.
316+
317+
Parameters
318+
----------
319+
timeout : float, optional
320+
Maximum time to wait in seconds
321+
322+
Returns
323+
-------
324+
bool
325+
True if notified, False if timeout
326+
"""
183327
if not self._is_locked:
184328
raise RuntimeError("wait() called without holding the lock")
185329

186330
self._verify_running()
187331
timeout = parse_timedelta(timeout)
188332

189-
# This handles release, wait, reacquire atomically on scheduler
333+
# Scheduler handles atomic release/wait/reacquire
190334
result = await self.client.scheduler.condition_wait(
191335
name=self.name,
192336
waiter_id=self._waiter_id,
193337
client_id=self._client_id,
194338
timeout=timeout,
195339
)
196-
# Lock is reacquired by the time this returns
340+
341+
# Lock is reacquired when this returns
342+
# _is_locked stays True
197343
return result
198344

199345
def notify(self, n=1):
200-
"""Wake up n waiters"""
346+
"""Wake up n waiters (default 1)"""
201347
if not self._is_locked:
202348
raise RuntimeError("notify() called without holding the lock")
203349
self._verify_running()
@@ -220,7 +366,7 @@ def notify_all(self):
220366
)
221367

222368
def locked(self):
223-
"""Return True if lock is held by this instance"""
369+
"""Return True if this instance holds the lock"""
224370
return self._is_locked
225371

226372
async def __aenter__(self):

0 commit comments

Comments
 (0)