Skip to content

Commit 678f8b3

Browse files
committed
Update condition.py,test_condition.py
1 parent 2af5dbb commit 678f8b3

File tree

2 files changed

+127
-125
lines changed

2 files changed

+127
-125
lines changed

distributed/condition.py

Lines changed: 121 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,91 @@
33
import asyncio
44
import logging
55
import uuid
6-
from collections import defaultdict
6+
from collections import defaultdict, deque
77

88
from dask.utils import parse_timedelta
99

10-
from distributed.lock import Lock
1110
from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for
1211
from distributed.worker import get_client
1312

1413
logger = logging.getLogger(__name__)
1514

1615

1716
class ConditionExtension:
18-
"""Scheduler extension for managing Condition variable notifications
19-
20-
Coordinates wait/notify between distributed clients.
21-
The lock itself is managed by LockExtension.
22-
"""
17+
"""Scheduler extension managing Condition lock and notifications"""
2318

2419
def __init__(self, scheduler):
2520
self.scheduler = scheduler
26-
# {condition_name: {waiter_id: asyncio.Event}}
27-
self._waiters = defaultdict(dict)
21+
# {condition_name: client_id} - who holds each lock
22+
self._lock_holders = {}
23+
# {condition_name: deque of (client_id, future)} - waiting to acquire
24+
self._acquire_waiters = defaultdict(deque)
25+
# {condition_name: {waiter_id: (client_id, Event)}} - waiting for notify
26+
self._notify_waiters = defaultdict(dict)
2827

2928
self.scheduler.handlers.update(
3029
{
30+
"condition_acquire": self.acquire,
31+
"condition_release": self.release,
3132
"condition_wait": self.wait,
3233
"condition_notify": self.notify,
3334
"condition_notify_all": self.notify_all,
3435
}
3536
)
3637

3738
@log_errors
38-
async def wait(self, name=None, id=None, timeout=None):
39-
"""Register waiter and block until notified
39+
async def acquire(self, name=None, client_id=None):
40+
"""Acquire lock - blocks until available"""
41+
if name not in self._lock_holders:
42+
# Lock is free
43+
self._lock_holders[name] = client_id
44+
return True
45+
46+
if self._lock_holders[name] == client_id:
47+
# Already hold it (shouldn't happen in normal use)
48+
return True
49+
50+
# Lock is held by someone else - wait our turn
51+
future = asyncio.Future()
52+
self._acquire_waiters[name].append((client_id, future))
53+
await future
54+
return True
55+
56+
@log_errors
57+
async def release(self, name=None, client_id=None):
58+
"""Release lock"""
59+
if name not in self._lock_holders:
60+
raise RuntimeError("Released too often")
61+
62+
if self._lock_holders[name] != client_id:
63+
raise RuntimeError("Cannot release lock held by another client")
4064

41-
Caller must have released the lock before calling this.
42-
"""
65+
del self._lock_holders[name]
66+
67+
# Wake next waiter if any
68+
waiters = self._acquire_waiters.get(name, deque())
69+
while waiters:
70+
next_client_id, future = waiters.popleft()
71+
if not future.done():
72+
self._lock_holders[name] = next_client_id
73+
future.set_result(True)
74+
break
75+
76+
@log_errors
77+
async def wait(self, name=None, waiter_id=None, client_id=None, timeout=None):
78+
"""Release lock, wait for notify, reacquire lock"""
79+
# Verify caller holds lock
80+
if self._lock_holders.get(name) != client_id:
81+
raise RuntimeError("wait() called without holding the lock")
82+
83+
# Release lock (waking next acquire waiter if any)
84+
await self.release(name=name, client_id=client_id)
85+
86+
# Register as notify waiter
4387
event = asyncio.Event()
44-
self._waiters[name][id] = event
88+
self._notify_waiters[name][waiter_id] = (client_id, event)
4589

90+
# Wait for notification
4691
future = event.wait()
4792
if timeout is not None:
4893
future = wait_for(future, timeout)
@@ -53,65 +98,52 @@ async def wait(self, name=None, id=None, timeout=None):
5398
except TimeoutError:
5499
result = False
55100
finally:
56-
self._waiters[name].pop(id, None)
57-
if not self._waiters[name]:
58-
del self._waiters[name]
101+
# Cleanup waiter
102+
self._notify_waiters[name].pop(waiter_id, None)
103+
if not self._notify_waiters[name]:
104+
del self._notify_waiters[name]
105+
106+
# Reacquire lock - blocks until available
107+
await self.acquire(name=name, client_id=client_id)
59108

60109
return result
61110

62111
@log_errors
63-
def notify(self, name=None, n=1):
112+
def notify(self, name=None, client_id=None, n=1):
64113
"""Wake up n waiters"""
65-
waiters = self._waiters.get(name, {})
114+
# Verify caller holds lock
115+
if self._lock_holders.get(name) != client_id:
116+
raise RuntimeError("notify() called without holding the lock")
117+
118+
waiters = self._notify_waiters.get(name, {})
66119
count = 0
67-
for event in list(waiters.values())[:n]:
120+
for _, (_, event) in list(waiters.items())[:n]:
68121
event.set()
69122
count += 1
70123
return count
71124

72125
@log_errors
73-
def notify_all(self, name=None):
126+
def notify_all(self, name=None, client_id=None):
74127
"""Wake up all waiters"""
75-
waiters = self._waiters.get(name, {})
76-
for event in waiters.values():
128+
# Verify caller holds lock
129+
if self._lock_holders.get(name) != client_id:
130+
raise RuntimeError("notify_all() called without holding the lock")
131+
132+
waiters = self._notify_waiters.get(name, {})
133+
for _, event in waiters.values():
77134
event.set()
78135
return len(waiters)
79136

80137

81138
class Condition(SyncMethodMixin):
82-
"""Distributed Condition Variable
83-
84-
Combines a Lock with wait/notify coordination across the cluster.
85-
86-
Parameters
87-
----------
88-
name : str, optional
89-
Name of the condition. Conditions with the same name share state.
90-
client : Client, optional
91-
Client for scheduler communication.
92-
93-
Examples
94-
--------
95-
Producer-consumer pattern:
96-
97-
>>> condition = Condition('data-ready')
98-
>>> # Consumer
99-
>>> async with condition:
100-
... while not data_available():
101-
... await condition.wait()
102-
... process_data()
103-
104-
>>> # Producer
105-
>>> async with condition:
106-
... produce_data()
107-
... condition.notify_all()
108-
"""
139+
"""Distributed Condition Variable"""
109140

110141
def __init__(self, name=None, client=None):
111142
self.name = name or f"condition-{uuid.uuid4().hex}"
112-
self.id = uuid.uuid4().hex
113-
self._lock = Lock(name=f"{self.name}-lock")
143+
self._waiter_id = uuid.uuid4().hex
144+
self._client_id = uuid.uuid4().hex
114145
self._client = client
146+
self._is_locked = False # Track local state
115147

116148
@property
117149
def client(self):
@@ -124,104 +156,72 @@ def client(self):
124156

125157
@property
126158
def loop(self):
127-
return self._lock.loop
159+
return self.client.loop
128160

129161
def _verify_running(self):
130162
if not self.client:
131-
raise RuntimeError(
132-
f"{type(self)} object not properly initialized. This can happen"
133-
" if the object is being deserialized outside of the context of"
134-
" a Client or Worker."
135-
)
163+
raise RuntimeError(f"{type(self)} object not properly initialized")
136164

137165
async def acquire(self):
138-
"""Acquire the underlying lock"""
139-
return await self._lock.acquire()
166+
"""Acquire lock"""
167+
self._verify_running()
168+
await self.client.scheduler.condition_acquire(
169+
name=self.name, client_id=self._client_id
170+
)
171+
self._is_locked = True
140172

141173
async def release(self):
142-
"""Release the underlying lock"""
143-
await self._lock.release()
174+
"""Release lock"""
175+
self._verify_running()
176+
await self.client.scheduler.condition_release(
177+
name=self.name, client_id=self._client_id
178+
)
179+
self._is_locked = False
144180

145181
async def wait(self, timeout=None):
146-
"""Wait until notified
147-
148-
Must be called while lock is held. Atomically releases lock,
149-
waits for notify(), then reacquires lock before returning.
150-
151-
Parameters
152-
----------
153-
timeout : number or string or timedelta, optional
154-
Maximum time to wait for notification.
155-
156-
Returns
157-
-------
158-
bool
159-
True if notified, False if timeout occurred
160-
161-
Raises
162-
------
163-
RuntimeError
164-
If called without holding the lock
165-
"""
166-
if not self._lock.locked():
182+
"""Wait for notification - atomically releases and reacquires lock"""
183+
if not self._is_locked:
167184
raise RuntimeError("wait() called without holding the lock")
168185

169186
self._verify_running()
170187
timeout = parse_timedelta(timeout)
171188

172-
# Atomically: release lock, wait for notify, reacquire lock
173-
await self._lock.release()
174-
try:
175-
result = await self.client.scheduler.condition_wait(
176-
name=self.name, id=self.id, timeout=timeout
177-
)
178-
finally:
179-
await self._lock.acquire()
180-
189+
# This handles release, wait, reacquire atomically on scheduler
190+
result = await self.client.scheduler.condition_wait(
191+
name=self.name,
192+
waiter_id=self._waiter_id,
193+
client_id=self._client_id,
194+
timeout=timeout,
195+
)
196+
# Lock is reacquired by the time this returns
181197
return result
182198

183199
def notify(self, n=1):
184-
"""Wake up one or more waiters
185-
186-
Must be called while holding the lock.
187-
188-
Parameters
189-
----------
190-
n : int, optional
191-
Number of waiters to wake. Default is 1.
192-
193-
Returns
194-
-------
195-
int
196-
Number of waiters actually notified
197-
"""
198-
if not self._lock.locked():
200+
"""Wake up n waiters"""
201+
if not self._is_locked:
199202
raise RuntimeError("notify() called without holding the lock")
200203
self._verify_running()
201204
return self.client.sync(
202-
self.client.scheduler.condition_notify, name=self.name, n=n
205+
self.client.scheduler.condition_notify,
206+
name=self.name,
207+
client_id=self._client_id,
208+
n=n,
203209
)
204210

205211
def notify_all(self):
206-
"""Wake up all waiters
207-
208-
Must be called while holding the lock.
209-
210-
Returns
211-
-------
212-
int
213-
Number of waiters notified
214-
"""
215-
if not self._lock.locked():
212+
"""Wake up all waiters"""
213+
if not self._is_locked:
216214
raise RuntimeError("notify_all() called without holding the lock")
217215
self._verify_running()
218216
return self.client.sync(
219-
self.client.scheduler.condition_notify_all, name=self.name
217+
self.client.scheduler.condition_notify_all,
218+
name=self.name,
219+
client_id=self._client_id,
220220
)
221221

222222
def locked(self):
223-
"""Return True if the lock is currently held"""
224-
return self._lock.locked()
223+
"""Return True if lock is held by this instance"""
224+
return self._is_locked
225225

226226
async def __aenter__(self):
227227
await self.acquire()

distributed/tests/test_condition.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,16 +286,18 @@ async def test_condition_cleanup(c, s, a, b):
286286
"""Test that condition state is cleaned up after use"""
287287
condition = Condition("cleanup-test")
288288

289-
# Check initial state - only check waiters since locks are managed by LockExtension
290-
assert "cleanup-test" not in s.extensions["conditions"]._waiters
289+
# Check initial state
290+
assert "cleanup-test" not in s.extensions["conditions"]._lock_holders
291+
assert "cleanup-test" not in s.extensions["conditions"]._notify_waiters
291292

292293
# Use condition
293294
async with condition:
294295
condition.notify()
295296

296-
# Waiter state should be cleaned up
297+
# State should be cleaned up
297298
await asyncio.sleep(0.1)
298-
assert "cleanup-test" not in s.extensions["conditions"]._waiters
299+
assert "cleanup-test" not in s.extensions["conditions"]._lock_holders
300+
assert "cleanup-test" not in s.extensions["conditions"]._notify_waiters
299301

300302

301303
@gen_cluster(client=True)

0 commit comments

Comments
 (0)