33import asyncio
44import logging
55import uuid
6- from collections import defaultdict
6+ from collections import defaultdict , deque
77
88from dask .utils import parse_timedelta
99
10- from distributed .lock import Lock
1110from distributed .utils import SyncMethodMixin , TimeoutError , log_errors , wait_for
1211from distributed .worker import get_client
1312
1413logger = logging .getLogger (__name__ )
1514
1615
1716class ConditionExtension :
18- """Scheduler extension for managing Condition variable notifications
19-
20- Coordinates wait/notify between distributed clients.
21- The lock itself is managed by LockExtension.
22- """
17+ """Scheduler extension managing Condition lock and notifications"""
2318
2419 def __init__ (self , scheduler ):
2520 self .scheduler = scheduler
26- # {condition_name: {waiter_id: asyncio.Event}}
27- self ._waiters = defaultdict (dict )
21+ # {condition_name: client_id} - who holds each lock
22+ self ._lock_holders = {}
23+ # {condition_name: deque of (client_id, future)} - waiting to acquire
24+ self ._acquire_waiters = defaultdict (deque )
25+ # {condition_name: {waiter_id: (client_id, Event)}} - waiting for notify
26+ self ._notify_waiters = defaultdict (dict )
2827
2928 self .scheduler .handlers .update (
3029 {
30+ "condition_acquire" : self .acquire ,
31+ "condition_release" : self .release ,
3132 "condition_wait" : self .wait ,
3233 "condition_notify" : self .notify ,
3334 "condition_notify_all" : self .notify_all ,
3435 }
3536 )
3637
3738 @log_errors
38- async def wait (self , name = None , id = None , timeout = None ):
39- """Register waiter and block until notified
39+ async def acquire (self , name = None , client_id = None ):
40+ """Acquire lock - blocks until available"""
41+ if name not in self ._lock_holders :
42+ # Lock is free
43+ self ._lock_holders [name ] = client_id
44+ return True
45+
46+ if self ._lock_holders [name ] == client_id :
47+ # Already hold it (shouldn't happen in normal use)
48+ return True
49+
50+ # Lock is held by someone else - wait our turn
51+ future = asyncio .Future ()
52+ self ._acquire_waiters [name ].append ((client_id , future ))
53+ await future
54+ return True
55+
56+ @log_errors
57+ async def release (self , name = None , client_id = None ):
58+ """Release lock"""
59+ if name not in self ._lock_holders :
60+ raise RuntimeError ("Released too often" )
61+
62+ if self ._lock_holders [name ] != client_id :
63+ raise RuntimeError ("Cannot release lock held by another client" )
4064
41- Caller must have released the lock before calling this.
42- """
65+ del self ._lock_holders [name ]
66+
67+ # Wake next waiter if any
68+ waiters = self ._acquire_waiters .get (name , deque ())
69+ while waiters :
70+ next_client_id , future = waiters .popleft ()
71+ if not future .done ():
72+ self ._lock_holders [name ] = next_client_id
73+ future .set_result (True )
74+ break
75+
76+ @log_errors
77+ async def wait (self , name = None , waiter_id = None , client_id = None , timeout = None ):
78+ """Release lock, wait for notify, reacquire lock"""
79+ # Verify caller holds lock
80+ if self ._lock_holders .get (name ) != client_id :
81+ raise RuntimeError ("wait() called without holding the lock" )
82+
83+ # Release lock (waking next acquire waiter if any)
84+ await self .release (name = name , client_id = client_id )
85+
86+ # Register as notify waiter
4387 event = asyncio .Event ()
44- self ._waiters [name ][id ] = event
88+ self ._notify_waiters [name ][waiter_id ] = ( client_id , event )
4589
90+ # Wait for notification
4691 future = event .wait ()
4792 if timeout is not None :
4893 future = wait_for (future , timeout )
@@ -53,65 +98,52 @@ async def wait(self, name=None, id=None, timeout=None):
5398 except TimeoutError :
5499 result = False
55100 finally :
56- self ._waiters [name ].pop (id , None )
57- if not self ._waiters [name ]:
58- del self ._waiters [name ]
101+ # Cleanup waiter
102+ self ._notify_waiters [name ].pop (waiter_id , None )
103+ if not self ._notify_waiters [name ]:
104+ del self ._notify_waiters [name ]
105+
106+ # Reacquire lock - blocks until available
107+ await self .acquire (name = name , client_id = client_id )
59108
60109 return result
61110
62111 @log_errors
63- def notify (self , name = None , n = 1 ):
112+ def notify (self , name = None , client_id = None , n = 1 ):
64113 """Wake up n waiters"""
65- waiters = self ._waiters .get (name , {})
114+ # Verify caller holds lock
115+ if self ._lock_holders .get (name ) != client_id :
116+ raise RuntimeError ("notify() called without holding the lock" )
117+
118+ waiters = self ._notify_waiters .get (name , {})
66119 count = 0
67- for event in list (waiters .values ())[:n ]:
120+ for _ , ( _ , event ) in list (waiters .items ())[:n ]:
68121 event .set ()
69122 count += 1
70123 return count
71124
72125 @log_errors
73- def notify_all (self , name = None ):
126+ def notify_all (self , name = None , client_id = None ):
74127 """Wake up all waiters"""
75- waiters = self ._waiters .get (name , {})
76- for event in waiters .values ():
128+ # Verify caller holds lock
129+ if self ._lock_holders .get (name ) != client_id :
130+ raise RuntimeError ("notify_all() called without holding the lock" )
131+
132+ waiters = self ._notify_waiters .get (name , {})
133+ for _ , event in waiters .values ():
77134 event .set ()
78135 return len (waiters )
79136
80137
81138class Condition (SyncMethodMixin ):
82- """Distributed Condition Variable
83-
84- Combines a Lock with wait/notify coordination across the cluster.
85-
86- Parameters
87- ----------
88- name : str, optional
89- Name of the condition. Conditions with the same name share state.
90- client : Client, optional
91- Client for scheduler communication.
92-
93- Examples
94- --------
95- Producer-consumer pattern:
96-
97- >>> condition = Condition('data-ready')
98- >>> # Consumer
99- >>> async with condition:
100- ... while not data_available():
101- ... await condition.wait()
102- ... process_data()
103-
104- >>> # Producer
105- >>> async with condition:
106- ... produce_data()
107- ... condition.notify_all()
108- """
139+ """Distributed Condition Variable"""
109140
110141 def __init__ (self , name = None , client = None ):
111142 self .name = name or f"condition-{ uuid .uuid4 ().hex } "
112- self .id = uuid .uuid4 ().hex
113- self ._lock = Lock ( name = f" { self . name } -lock" )
143+ self ._waiter_id = uuid .uuid4 ().hex
144+ self ._client_id = uuid . uuid4 (). hex
114145 self ._client = client
146+ self ._is_locked = False # Track local state
115147
116148 @property
117149 def client (self ):
@@ -124,104 +156,72 @@ def client(self):
124156
125157 @property
126158 def loop (self ):
127- return self ._lock .loop
159+ return self .client .loop
128160
129161 def _verify_running (self ):
130162 if not self .client :
131- raise RuntimeError (
132- f"{ type (self )} object not properly initialized. This can happen"
133- " if the object is being deserialized outside of the context of"
134- " a Client or Worker."
135- )
163+ raise RuntimeError (f"{ type (self )} object not properly initialized" )
136164
137165 async def acquire (self ):
138- """Acquire the underlying lock"""
139- return await self ._lock .acquire ()
166+ """Acquire lock"""
167+ self ._verify_running ()
168+ await self .client .scheduler .condition_acquire (
169+ name = self .name , client_id = self ._client_id
170+ )
171+ self ._is_locked = True
140172
141173 async def release (self ):
142- """Release the underlying lock"""
143- await self ._lock .release ()
174+ """Release lock"""
175+ self ._verify_running ()
176+ await self .client .scheduler .condition_release (
177+ name = self .name , client_id = self ._client_id
178+ )
179+ self ._is_locked = False
144180
145181 async def wait (self , timeout = None ):
146- """Wait until notified
147-
148- Must be called while lock is held. Atomically releases lock,
149- waits for notify(), then reacquires lock before returning.
150-
151- Parameters
152- ----------
153- timeout : number or string or timedelta, optional
154- Maximum time to wait for notification.
155-
156- Returns
157- -------
158- bool
159- True if notified, False if timeout occurred
160-
161- Raises
162- ------
163- RuntimeError
164- If called without holding the lock
165- """
166- if not self ._lock .locked ():
182+ """Wait for notification - atomically releases and reacquires lock"""
183+ if not self ._is_locked :
167184 raise RuntimeError ("wait() called without holding the lock" )
168185
169186 self ._verify_running ()
170187 timeout = parse_timedelta (timeout )
171188
172- # Atomically: release lock, wait for notify, reacquire lock
173- await self ._lock .release ()
174- try :
175- result = await self .client .scheduler .condition_wait (
176- name = self .name , id = self .id , timeout = timeout
177- )
178- finally :
179- await self ._lock .acquire ()
180-
189+ # This handles release, wait, reacquire atomically on scheduler
190+ result = await self .client .scheduler .condition_wait (
191+ name = self .name ,
192+ waiter_id = self ._waiter_id ,
193+ client_id = self ._client_id ,
194+ timeout = timeout ,
195+ )
196+ # Lock is reacquired by the time this returns
181197 return result
182198
183199 def notify (self , n = 1 ):
184- """Wake up one or more waiters
185-
186- Must be called while holding the lock.
187-
188- Parameters
189- ----------
190- n : int, optional
191- Number of waiters to wake. Default is 1.
192-
193- Returns
194- -------
195- int
196- Number of waiters actually notified
197- """
198- if not self ._lock .locked ():
200+ """Wake up n waiters"""
201+ if not self ._is_locked :
199202 raise RuntimeError ("notify() called without holding the lock" )
200203 self ._verify_running ()
201204 return self .client .sync (
202- self .client .scheduler .condition_notify , name = self .name , n = n
205+ self .client .scheduler .condition_notify ,
206+ name = self .name ,
207+ client_id = self ._client_id ,
208+ n = n ,
203209 )
204210
205211 def notify_all (self ):
206- """Wake up all waiters
207-
208- Must be called while holding the lock.
209-
210- Returns
211- -------
212- int
213- Number of waiters notified
214- """
215- if not self ._lock .locked ():
212+ """Wake up all waiters"""
213+ if not self ._is_locked :
216214 raise RuntimeError ("notify_all() called without holding the lock" )
217215 self ._verify_running ()
218216 return self .client .sync (
219- self .client .scheduler .condition_notify_all , name = self .name
217+ self .client .scheduler .condition_notify_all ,
218+ name = self .name ,
219+ client_id = self ._client_id ,
220220 )
221221
222222 def locked (self ):
223- """Return True if the lock is currently held"""
224- return self ._lock . locked ()
223+ """Return True if lock is held by this instance """
224+ return self ._is_locked
225225
226226 async def __aenter__ (self ):
227227 await self .acquire ()
0 commit comments