@@ -6,8 +6,7 @@ use std::io::{BufRead, BufReader, Read};
66use std:: path:: Path ;
77use std:: sync:: atomic:: { AtomicBool , Ordering } ;
88use std:: sync:: mpsc:: TryRecvError ;
9- use std:: sync:: Arc ;
10- use std:: sync:: { mpsc, Mutex } ;
9+ use std:: sync:: { mpsc, Arc } ;
1110use std:: thread;
1211use std:: thread:: sleep;
1312use std:: time:: { Duration , Instant } ;
@@ -274,7 +273,7 @@ struct Args {
// Status message that a shard manager sends over its `status_sender` channel
// (`mpsc::Sender<ShardStatus>`) to whoever holds the matching receiver.
274273#[ derive( Debug ) ]
275274enum ShardStatus {
276275 Ready ,
// `Failed` carries the shard rank plus an error message. After this change the
// message is optional: receivers always log the rank and only log the message
// when it is `Some`.
// NOTE(review): `None` presumably means the shard's stderr could not be read
// before the read timed out — confirm against the full `shard_manager` body.
277- Failed ( ( usize , String ) ) ,
276+ Failed ( ( usize , Option < String > ) ) ,
278277}
279278
280279#[ allow( clippy:: too_many_arguments) ]
@@ -296,7 +295,7 @@ fn shard_manager(
296295 watermark_delta : Option < f32 > ,
297296 otlp_endpoint : Option < String > ,
298297 status_sender : mpsc:: Sender < ShardStatus > ,
299- shutdown : Arc < Mutex < bool > > ,
298+ shutdown : Arc < AtomicBool > ,
300299 _shutdown_sender : mpsc:: Sender < ( ) > ,
301300) {
302301 // Get UDS path
@@ -433,20 +432,20 @@ fn shard_manager(
433432 }
434433 }
435434 status_sender
436- . send ( ShardStatus :: Failed ( ( rank, err. to_string ( ) ) ) )
435+ . send ( ShardStatus :: Failed ( ( rank, Some ( err. to_string ( ) ) ) ) )
437436 . unwrap ( ) ;
438437 return ;
439438 }
440439 } ;
441440
442441 // Redirect STDOUT to the console
443- let shard_stdout = p. stdout . take ( ) . unwrap ( ) ;
442+ let shard_stdout_reader = BufReader :: new ( p. stdout . take ( ) . unwrap ( ) ) ;
443+ let mut shard_stderr_reader = BufReader :: new ( p. stderr . take ( ) . unwrap ( ) ) ;
444444
445445 thread:: spawn ( move || {
446446 // Enter shard-manager tracing span
447- let stdout = BufReader :: new ( shard_stdout) ;
448447 let _span = tracing:: span!( tracing:: Level :: INFO , "shard-manager" , rank = rank) . entered ( ) ;
449- for line in stdout . lines ( ) {
448+ for line in shard_stdout_reader . lines ( ) {
450449 // Parse loguru logs
451450 if let Ok ( log) = serde_json:: from_str :: < PythonLogMessage > ( & line. unwrap ( ) ) {
452451 log. trace ( ) ;
@@ -460,8 +459,22 @@ fn shard_manager(
460459 loop {
461460 // Process exited
462461 if let Some ( exit_status) = p. poll ( ) {
463- let mut err = String :: new ( ) ;
464- p. stderr . take ( ) . unwrap ( ) . read_to_string ( & mut err) . unwrap ( ) ;
462+ // We read stderr in another thread as it seems that `read_to_string` can block
463+ // indefinitely in some cases
464+ let ( err_sender, err_receiver) = mpsc:: channel ( ) ;
465+ thread:: spawn ( move || {
466+ let mut err = String :: new ( ) ;
467+ shard_stderr_reader. read_to_string ( & mut err) . unwrap ( ) ;
468+ err_sender. send ( err) . unwrap_or ( ( ) ) ;
469+ } ) ;
470+
471+ let err = err_receiver
472+ . recv_timeout ( Duration :: from_millis ( 100 ) )
473+ . map_err ( |err| {
474+ tracing:: error!( "Unable to read shard {rank} error from stderr" ) ;
475+ err
476+ } )
477+ . ok ( ) ;
465478
466479 if let ExitStatus :: Signaled ( signal) = exit_status {
467480 tracing:: error!( "Shard process was signaled to shutdown with signal {signal}" ) ;
@@ -474,7 +487,7 @@ fn shard_manager(
474487 }
475488
476489 // We received a shutdown signal
477- if * shutdown. lock ( ) . unwrap ( ) {
490+ if shutdown. load ( Ordering :: SeqCst ) {
478491 p. kill ( ) . unwrap ( ) ;
479492 let _ = p. wait_timeout ( Duration :: from_secs ( 90 ) ) ;
480493 tracing:: info!( "Shard {rank} terminated" ) ;
@@ -494,14 +507,11 @@ fn shard_manager(
494507 }
495508}
496509
497- fn shutdown_shards ( shutdown : Arc < Mutex < bool > > , shutdown_receiver : & mpsc:: Receiver < ( ) > ) {
510+ fn shutdown_shards ( shutdown : Arc < AtomicBool > , shutdown_receiver : & mpsc:: Receiver < ( ) > ) {
498511 tracing:: info!( "Shutting down shards" ) ;
499512 // Update shutdown value to true
500513 // This will be picked up by the shard manager
501- {
502- let mut shutdown = shutdown. lock ( ) . unwrap ( ) ;
503- * shutdown = true ;
504- }
514+ shutdown. store ( true , Ordering :: SeqCst ) ;
505515
506516 // Wait for shards to shutdown
507517 // This will block till all shutdown_sender are dropped
@@ -743,7 +753,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
743753fn spawn_shards (
744754 num_shard : usize ,
745755 args : & Args ,
746- shutdown : Arc < Mutex < bool > > ,
756+ shutdown : Arc < AtomicBool > ,
747757 shutdown_receiver : & mpsc:: Receiver < ( ) > ,
748758 shutdown_sender : mpsc:: Sender < ( ) > ,
749759 status_receiver : & mpsc:: Receiver < ShardStatus > ,
@@ -819,7 +829,10 @@ fn spawn_shards(
819829 sleep ( Duration :: from_millis ( 100 ) ) ;
820830 }
821831 Ok ( ShardStatus :: Failed ( ( rank, err) ) ) => {
822- tracing:: error!( "Shard {} failed to start:\n {}" , rank, err) ;
832+ tracing:: error!( "Shard {rank} failed to start" ) ;
833+ if let Some ( err) = err {
834+ tracing:: error!( "{err}" ) ;
835+ }
823836 shutdown_shards ( shutdown, shutdown_receiver) ;
824837 return Err ( LauncherError :: ShardCannotStart ) ;
825838 }
@@ -835,7 +848,7 @@ fn spawn_shards(
835848
836849fn spawn_webserver (
837850 args : Args ,
838- shutdown : Arc < Mutex < bool > > ,
851+ shutdown : Arc < AtomicBool > ,
839852 shutdown_receiver : & mpsc:: Receiver < ( ) > ,
840853) -> Result < Popen , LauncherError > {
841854 // All shard started
@@ -1002,7 +1015,7 @@ fn main() -> Result<(), LauncherError> {
10021015 download_convert_model ( & args, running. clone ( ) ) ?;
10031016
10041017 // Shared shutdown bool
1005- let shutdown = Arc :: new ( Mutex :: new ( false ) ) ;
1018+ let shutdown = Arc :: new ( AtomicBool :: new ( false ) ) ;
10061019 // Shared shutdown channel
10071020 // When shutting down, the main thread will wait for all senders to be dropped
10081021 let ( shutdown_sender, shutdown_receiver) = mpsc:: channel ( ) ;
@@ -1034,7 +1047,10 @@ fn main() -> Result<(), LauncherError> {
10341047
10351048 while running. load ( Ordering :: SeqCst ) {
10361049 if let Ok ( ShardStatus :: Failed ( ( rank, err) ) ) = status_receiver. try_recv ( ) {
1037- tracing:: error!( "Shard {rank} failed:\n {err}" ) ;
1050+ tracing:: error!( "Shard {rank} failed to start" ) ;
1051+ if let Some ( err) = err {
1052+ tracing:: error!( "{err}" ) ;
1053+ }
10381054 exit_code = Err ( LauncherError :: ShardFailed ) ;
10391055 break ;
10401056 } ;
0 commit comments