 
 use std::collections::VecDeque;
 use std::pin::Pin;
+use std::sync::atomic::AtomicBool;
+use std::sync::atomic::Ordering;
 use std::sync::Arc;

 use futures::pin_mut;
@@ -21,7 +23,6 @@ use futures::task::AtomicWaker;
 use futures::task::Context;
 use futures::task::Poll;
 use log::debug;
-use log::info;
 use log::warn;
 use pin_project::pin_project;
 use tokio::sync::mpsc;
@@ -93,12 +94,16 @@ async fn run_tso(
     // more requests from the bounded channel. This waker is used to wake up the sending future
     // if the queue containing pending requests is no longer full.
     let sending_future_waker = Arc::new(AtomicWaker::new());
+    // This flag indicates that the sender stream could not acquire the `pending_requests` lock
+    // in poll and needs an explicit wake from the response path.
+    let sender_waiting_on_lock = Arc::new(AtomicBool::new(false));

     let request_stream = TsoRequestStream {
         cluster_id,
         request_rx,
         pending_requests: pending_requests.clone(),
         self_waker: sending_future_waker.clone(),
+        sender_waiting_on_lock: sender_waiting_on_lock.clone(),
     };

     // let send_requests = rpc_sender.send_all(&mut request_stream);
@@ -112,15 +117,24 @@ async fn run_tso(
             allocate_timestamps(&resp, &mut pending_requests)?;
             was_full && pending_requests.len() < MAX_PENDING_COUNT
         };
+        let sender_blocked_by_lock = sender_waiting_on_lock.swap(false, Ordering::AcqRel);

-        // Only wake sender when a previously full queue gains capacity.
-        if should_wake_sender {
+        // Wake the sender when:
+        // 1. a previously full queue gains capacity, or
+        // 2. the sender was blocked on `pending_requests` mutex contention.
+        if should_wake_sender || sender_blocked_by_lock {
             sending_future_waker.wake();
         }
     }
-    // TODO: distinguish between unexpected stream termination and expected end of test
-    info!("TSO stream terminated");
-    Ok(())
+    let pending_count = pending_requests.lock().await.len();
+    if pending_count == 0 {
+        Ok(())
+    } else {
+        Err(internal_err!(
+            "TSO stream terminated with {} pending requests",
+            pending_count
+        ))
+    }
 }

 struct RequestGroup {
@@ -135,6 +149,7 @@ struct TsoRequestStream {
     request_rx: mpsc::Receiver<oneshot::Sender<Timestamp>>,
     pending_requests: Arc<Mutex<VecDeque<RequestGroup>>>,
     self_waker: Arc<AtomicWaker>,
+    sender_waiting_on_lock: Arc<AtomicBool>,
 }

 impl Stream for TsoRequestStream {
@@ -147,8 +162,18 @@ impl Stream for TsoRequestStream {
         pin_mut!(pending_requests);
         let mut pending_requests = if let Poll::Ready(pending_requests) = pending_requests.poll(cx)
         {
+            this.sender_waiting_on_lock.store(false, Ordering::Release);
             pending_requests
         } else {
+            // The lock future is dropped at the end of this poll, so record the lock-wait state
+            // and rely on the response path to issue a wake after it releases the lock.
+            this.sender_waiting_on_lock.store(true, Ordering::Release);
+            this.self_waker.register(cx.waker());
+            // If the response path consumed the wait flag before registration, its wake may have
+            // been lost. Trigger one local wake to guarantee another poll.
+            if !this.sender_waiting_on_lock.load(Ordering::Acquire) {
+                cx.waker().wake_by_ref();
+            }
             return Poll::Pending;
         };
         let mut requests = Vec::new();
@@ -230,3 +255,214 @@ fn allocate_timestamps(
     };
     Ok(())
 }
+
+#[cfg(test)]
+mod tests {
+    use std::sync::atomic::AtomicUsize;
+    use std::sync::Arc;
+
+    use futures::executor::block_on;
+    use futures::task::noop_waker_ref;
+    use futures::task::waker;
+    use futures::task::ArcWake;
+
+    use super::*;
+
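+    // Waker backend that counts how many times it is woken, so tests can assert
+    // whether a wake notification was actually delivered to the registered task.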
+    struct WakeCounter {
+        wakes: AtomicUsize,
+    }
+
+    impl ArcWake for WakeCounter {
+        fn wake_by_ref(arc_self: &Arc<Self>) {
+            arc_self.wakes.fetch_add(1, Ordering::SeqCst);
+        }
+    }
+
+    fn test_tso_request(count: u32) -> TsoRequest {
+        TsoRequest {
+            header: Some(RequestHeader {
+                cluster_id: 1,
+                sender_id: 0,
+            }),
+            count,
+            dc_location: String::new(),
+        }
+    }
+
+    fn test_tso_response(count: u32, logical: i64) -> TsoResponse {
+        TsoResponse {
+            header: None,
+            count,
+            timestamp: Some(Timestamp {
+                physical: 123,
+                logical,
+                suffix_bits: 0,
+            }),
+        }
+    }
+
+    type TestStreamContext = (
+        TsoRequestStream,
+        mpsc::Sender<TimestampRequest>,
+        Arc<Mutex<VecDeque<RequestGroup>>>,
+        Arc<AtomicWaker>,
+        Arc<AtomicBool>,
+    );
+
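+    // Builds a `TsoRequestStream` together with handles to its request channel and
+    // shared state, so tests can drive `poll_next` directly and inspect the queue
+    // and wake flags afterwards.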
+    fn new_test_stream() -> TestStreamContext {
+        let (request_tx, request_rx) = mpsc::channel(MAX_BATCH_SIZE);
+        let pending_requests = Arc::new(Mutex::new(VecDeque::new()));
+        let self_waker = Arc::new(AtomicWaker::new());
+        let sender_waiting_on_lock = Arc::new(AtomicBool::new(false));
+        let stream = TsoRequestStream {
+            cluster_id: 1,
+            request_rx,
+            pending_requests: pending_requests.clone(),
+            self_waker: self_waker.clone(),
+            sender_waiting_on_lock: sender_waiting_on_lock.clone(),
+        };
+        (
+            stream,
+            request_tx,
+            pending_requests,
+            self_waker,
+            sender_waiting_on_lock,
+        )
+    }
+
+    #[test]
+    fn allocate_timestamps_successfully_assigns_monotonic_timestamps() {
+        let (tx1, rx1) = oneshot::channel();
+        let (tx2, rx2) = oneshot::channel();
+        let (tx3, rx3) = oneshot::channel();
+        let mut pending_requests = VecDeque::new();
+        pending_requests.push_back(RequestGroup {
+            tso_request: test_tso_request(3),
+            requests: vec![tx1, tx2, tx3],
+        });
+
+        allocate_timestamps(&test_tso_response(3, 100), &mut pending_requests).unwrap();
+        assert!(pending_requests.is_empty());
+
+        let ts1 = block_on(rx1).unwrap();
+        let ts2 = block_on(rx2).unwrap();
+        let ts3 = block_on(rx3).unwrap();
+        assert_eq!(ts1.logical, 98);
+        assert_eq!(ts2.logical, 99);
+        assert_eq!(ts3.logical, 100);
+    }
+
+    #[test]
+    fn allocate_timestamps_errors_without_timestamp() {
+        let (tx, _rx) = oneshot::channel();
+        let mut pending_requests = VecDeque::new();
+        pending_requests.push_back(RequestGroup {
+            tso_request: test_tso_request(1),
+            requests: vec![tx],
+        });
+        let resp = TsoResponse {
+            header: None,
+            count: 1,
+            timestamp: None,
+        };
+
+        let err = allocate_timestamps(&resp, &mut pending_requests).unwrap_err();
+        assert!(format!("{err:?}").contains("No timestamp in TsoResponse"));
+    }
+
+    #[test]
+    fn allocate_timestamps_errors_when_count_mismatches() {
+        let (tx, _rx) = oneshot::channel();
+        let mut pending_requests = VecDeque::new();
+        pending_requests.push_back(RequestGroup {
+            tso_request: test_tso_request(2),
+            requests: vec![tx],
+        });
+
+        let err =
+            allocate_timestamps(&test_tso_response(1, 10), &mut pending_requests).unwrap_err();
+        assert!(format!("{err:?}").contains("different number of timestamps"));
+    }
+
+    #[test]
+    fn allocate_timestamps_errors_on_extra_response() {
+        let mut pending_requests = VecDeque::new();
+        let err =
+            allocate_timestamps(&test_tso_response(1, 10), &mut pending_requests).unwrap_err();
+        assert!(format!("{err:?}").contains("more TsoResponse than expected"));
+    }
+
+    #[test]
+    fn poll_next_emits_request_and_enqueues_request_group() {
+        let (stream, request_tx, pending_requests, _self_waker, sender_waiting_on_lock) =
+            new_test_stream();
+        let (tx, _rx) = oneshot::channel();
+        request_tx.try_send(tx).unwrap();
+
+        let mut stream = Box::pin(stream);
+        let mut cx = Context::from_waker(noop_waker_ref());
+        let polled = stream.as_mut().poll_next(&mut cx);
+        let req = match polled {
+            Poll::Ready(Some(req)) => req,
+            other => panic!("expected Poll::Ready(Some(_)), got {:?}", other),
+        };
+
+        assert_eq!(req.count, 1);
+        assert!(!sender_waiting_on_lock.load(Ordering::SeqCst));
+        let queued = block_on(async { pending_requests.lock().await.len() });
+        assert_eq!(queued, 1);
+    }
+
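+    // With the pending queue already at MAX_PENDING_COUNT, `poll_next` should return
+    // Pending and register the task with `self_waker`, so a later `self_waker.wake()`
+    // is delivered to the counting waker exactly once.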
+    #[test]
+    fn poll_next_registers_self_waker_when_pending_queue_is_full() {
+        let (stream, _request_tx, pending_requests, self_waker, _sender_waiting_on_lock) =
+            new_test_stream();
+        block_on(async {
+            let mut guard = pending_requests.lock().await;
+            for _ in 0..MAX_PENDING_COUNT {
+                guard.push_back(RequestGroup {
+                    tso_request: test_tso_request(0),
+                    requests: Vec::new(),
+                });
+            }
+        });
+
+        let wake_counter = Arc::new(WakeCounter {
+            wakes: AtomicUsize::new(0),
+        });
+        let test_waker = waker(wake_counter.clone());
+        let mut cx = Context::from_waker(&test_waker);
+        let mut stream = Box::pin(stream);
+
+        let polled = stream.as_mut().poll_next(&mut cx);
+        assert!(matches!(polled, Poll::Pending));
+        assert_eq!(wake_counter.wakes.load(Ordering::SeqCst), 0);
+
+        self_waker.wake();
+        assert_eq!(wake_counter.wakes.load(Ordering::SeqCst), 1);
+    }
+
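+    // Replays the response-path wake protocol by hand: while `pending_requests` is
+    // locked the stream must return Pending and set the waiting flag; releasing the
+    // lock, consuming the flag, and waking must reach the parked task.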
+    #[test]
+    fn poll_next_marks_waiting_flag_when_lock_is_contended() {
+        let (stream, _request_tx, pending_requests, self_waker, sender_waiting_on_lock) =
+            new_test_stream();
+        let lock_guard = block_on(pending_requests.lock());
+
+        let wake_counter = Arc::new(WakeCounter {
+            wakes: AtomicUsize::new(0),
+        });
+        let test_waker = waker(wake_counter.clone());
+        let mut cx = Context::from_waker(&test_waker);
+        let mut stream = Box::pin(stream);
+
+        let polled = stream.as_mut().poll_next(&mut cx);
+        assert!(matches!(polled, Poll::Pending));
+        assert!(sender_waiting_on_lock.load(Ordering::SeqCst));
+
+        drop(lock_guard);
+        if sender_waiting_on_lock.swap(false, Ordering::AcqRel) {
+            self_waker.wake();
+        }
+        assert!(wake_counter.wakes.load(Ordering::SeqCst) >= 1);
+    }
+}