Skip to content

Commit 2b53d71

Browse files
fix(launcher): fix issue where launcher does not properly report shard failures (#522)
1 parent ecf6dc3 commit 2b53d71

File tree

1 file changed

+37
-21
lines changed

1 file changed

+37
-21
lines changed

launcher/src/main.rs

Lines changed: 37 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,7 @@ use std::io::{BufRead, BufReader, Read};
66
use std::path::Path;
77
use std::sync::atomic::{AtomicBool, Ordering};
88
use std::sync::mpsc::TryRecvError;
9-
use std::sync::Arc;
10-
use std::sync::{mpsc, Mutex};
9+
use std::sync::{mpsc, Arc};
1110
use std::thread;
1211
use std::thread::sleep;
1312
use std::time::{Duration, Instant};
@@ -274,7 +273,7 @@ struct Args {
274273
#[derive(Debug)]
275274
enum ShardStatus {
276275
Ready,
277-
Failed((usize, String)),
276+
Failed((usize, Option<String>)),
278277
}
279278

280279
#[allow(clippy::too_many_arguments)]
@@ -296,7 +295,7 @@ fn shard_manager(
296295
watermark_delta: Option<f32>,
297296
otlp_endpoint: Option<String>,
298297
status_sender: mpsc::Sender<ShardStatus>,
299-
shutdown: Arc<Mutex<bool>>,
298+
shutdown: Arc<AtomicBool>,
300299
_shutdown_sender: mpsc::Sender<()>,
301300
) {
302301
// Get UDS path
@@ -433,20 +432,20 @@ fn shard_manager(
433432
}
434433
}
435434
status_sender
436-
.send(ShardStatus::Failed((rank, err.to_string())))
435+
.send(ShardStatus::Failed((rank, Some(err.to_string()))))
437436
.unwrap();
438437
return;
439438
}
440439
};
441440

442441
// Redirect STDOUT to the console
443-
let shard_stdout = p.stdout.take().unwrap();
442+
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
443+
let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
444444

445445
thread::spawn(move || {
446446
// Enter shard-manager tracing span
447-
let stdout = BufReader::new(shard_stdout);
448447
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
449-
for line in stdout.lines() {
448+
for line in shard_stdout_reader.lines() {
450449
// Parse loguru logs
451450
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
452451
log.trace();
@@ -460,8 +459,22 @@ fn shard_manager(
460459
loop {
461460
// Process exited
462461
if let Some(exit_status) = p.poll() {
463-
let mut err = String::new();
464-
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
462+
// We read stderr in another thread as it seems that `read_to_string` can block
463+
// indefinitely in some cases
464+
let (err_sender, err_receiver) = mpsc::channel();
465+
thread::spawn(move || {
466+
let mut err = String::new();
467+
shard_stderr_reader.read_to_string(&mut err).unwrap();
468+
err_sender.send(err).unwrap_or(());
469+
});
470+
471+
let err = err_receiver
472+
.recv_timeout(Duration::from_millis(100))
473+
.map_err(|err| {
474+
tracing::error!("Unable to read shard {rank} error from stderr");
475+
err
476+
})
477+
.ok();
465478

466479
if let ExitStatus::Signaled(signal) = exit_status {
467480
tracing::error!("Shard process was signaled to shutdown with signal {signal}");
@@ -474,7 +487,7 @@ fn shard_manager(
474487
}
475488

476489
// We received a shutdown signal
477-
if *shutdown.lock().unwrap() {
490+
if shutdown.load(Ordering::SeqCst) {
478491
p.kill().unwrap();
479492
let _ = p.wait_timeout(Duration::from_secs(90));
480493
tracing::info!("Shard {rank} terminated");
@@ -494,14 +507,11 @@ fn shard_manager(
494507
}
495508
}
496509

497-
fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
510+
fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
498511
tracing::info!("Shutting down shards");
499512
// Update shutdown value to true
500513
// This will be picked up by the shard manager
501-
{
502-
let mut shutdown = shutdown.lock().unwrap();
503-
*shutdown = true;
504-
}
514+
shutdown.store(true, Ordering::SeqCst);
505515

506516
// Wait for shards to shutdown
507517
// This will block till all shutdown_sender are dropped
@@ -743,7 +753,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
743753
fn spawn_shards(
744754
num_shard: usize,
745755
args: &Args,
746-
shutdown: Arc<Mutex<bool>>,
756+
shutdown: Arc<AtomicBool>,
747757
shutdown_receiver: &mpsc::Receiver<()>,
748758
shutdown_sender: mpsc::Sender<()>,
749759
status_receiver: &mpsc::Receiver<ShardStatus>,
@@ -819,7 +829,10 @@ fn spawn_shards(
819829
sleep(Duration::from_millis(100));
820830
}
821831
Ok(ShardStatus::Failed((rank, err))) => {
822-
tracing::error!("Shard {} failed to start:\n{}", rank, err);
832+
tracing::error!("Shard {rank} failed to start");
833+
if let Some(err) = err {
834+
tracing::error!("{err}");
835+
}
823836
shutdown_shards(shutdown, shutdown_receiver);
824837
return Err(LauncherError::ShardCannotStart);
825838
}
@@ -835,7 +848,7 @@ fn spawn_shards(
835848

836849
fn spawn_webserver(
837850
args: Args,
838-
shutdown: Arc<Mutex<bool>>,
851+
shutdown: Arc<AtomicBool>,
839852
shutdown_receiver: &mpsc::Receiver<()>,
840853
) -> Result<Popen, LauncherError> {
841854
// All shard started
@@ -1002,7 +1015,7 @@ fn main() -> Result<(), LauncherError> {
10021015
download_convert_model(&args, running.clone())?;
10031016

10041017
// Shared shutdown bool
1005-
let shutdown = Arc::new(Mutex::new(false));
1018+
let shutdown = Arc::new(AtomicBool::new(false));
10061019
// Shared shutdown channel
10071020
// When shutting down, the main thread will wait for all senders to be dropped
10081021
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
@@ -1034,7 +1047,10 @@ fn main() -> Result<(), LauncherError> {
10341047

10351048
while running.load(Ordering::SeqCst) {
10361049
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
1037-
tracing::error!("Shard {rank} failed:\n{err}");
1050+
tracing::error!("Shard {rank} failed to start");
1051+
if let Some(err) = err {
1052+
tracing::error!("{err}");
1053+
}
10381054
exit_code = Err(LauncherError::ShardFailed);
10391055
break;
10401056
};

0 commit comments

Comments
 (0)