diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 51131f42..30abe88a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -6,8 +6,7 @@ use std::io::{BufRead, BufReader, Read}; use std::path::Path; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::TryRecvError; -use std::sync::Arc; -use std::sync::{mpsc, Mutex}; +use std::sync::{mpsc, Arc}; use std::thread; use std::thread::sleep; use std::time::{Duration, Instant}; @@ -274,7 +273,7 @@ struct Args { #[derive(Debug)] enum ShardStatus { Ready, - Failed((usize, String)), + Failed((usize, Option)), } #[allow(clippy::too_many_arguments)] @@ -296,7 +295,7 @@ fn shard_manager( watermark_delta: Option, otlp_endpoint: Option, status_sender: mpsc::Sender, - shutdown: Arc>, + shutdown: Arc, _shutdown_sender: mpsc::Sender<()>, ) { // Get UDS path @@ -433,20 +432,20 @@ fn shard_manager( } } status_sender - .send(ShardStatus::Failed((rank, err.to_string()))) + .send(ShardStatus::Failed((rank, Some(err.to_string())))) .unwrap(); return; } }; // Redirect STDOUT to the console - let shard_stdout = p.stdout.take().unwrap(); + let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); + let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); thread::spawn(move || { // Enter shard-manager tracing span - let stdout = BufReader::new(shard_stdout); let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); - for line in stdout.lines() { + for line in shard_stdout_reader.lines() { // Parse loguru logs if let Ok(log) = serde_json::from_str::(&line.unwrap()) { log.trace(); @@ -460,8 +459,22 @@ fn shard_manager( loop { // Process exited if let Some(exit_status) = p.poll() { - let mut err = String::new(); - p.stderr.take().unwrap().read_to_string(&mut err).unwrap(); + // We read stderr in another thread as it seems that `read_to_string` can block + // indefinitely in some cases + let (err_sender, err_receiver) = mpsc::channel(); + thread::spawn(move || { + let mut err = String::new(); + shard_stderr_reader.read_to_string(&mut err).unwrap(); + err_sender.send(err).unwrap_or(()); + }); + + let err = err_receiver + .recv_timeout(Duration::from_millis(100)) + .map_err(|err| { + tracing::error!("Unable to read shard {rank} error from stderr"); + err + }) + .ok(); if let ExitStatus::Signaled(signal) = exit_status { tracing::error!("Shard process was signaled to shutdown with signal {signal}"); @@ -474,7 +487,7 @@ fn shard_manager( } // We received a shutdown signal - if *shutdown.lock().unwrap() { + if shutdown.load(Ordering::SeqCst) { p.kill().unwrap(); let _ = p.wait_timeout(Duration::from_secs(90)); tracing::info!("Shard {rank} terminated"); @@ -494,14 +507,11 @@ fn shard_manager( } } -fn shutdown_shards(shutdown: Arc>, shutdown_receiver: &mpsc::Receiver<()>) { +fn shutdown_shards(shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>) { tracing::info!("Shutting down shards"); // Update shutdown value to true // This will be picked up by the shard manager - { - let mut shutdown = shutdown.lock().unwrap(); - *shutdown = true; - } + shutdown.store(true, Ordering::SeqCst); // Wait for shards to shutdown // This will block till all shutdown_sender are dropped @@ -743,7 +753,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L fn spawn_shards( num_shard: usize, args: &Args, - shutdown: Arc>, + shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, shutdown_sender: mpsc::Sender<()>, status_receiver: &mpsc::Receiver, @@ -819,7 +829,10 @@ fn spawn_shards( sleep(Duration::from_millis(100)); } Ok(ShardStatus::Failed((rank, err))) => { - tracing::error!("Shard {} failed to start:\n{}", rank, err); + tracing::error!("Shard {rank} failed to start"); + if let Some(err) = err { + tracing::error!("{err}"); + } shutdown_shards(shutdown, shutdown_receiver); return Err(LauncherError::ShardCannotStart); } @@ -835,7 +848,7 @@ fn spawn_shards( fn spawn_webserver( args: Args, - shutdown: Arc>, + shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, ) -> Result { // All shard started @@ -1002,7 +1015,7 @@ fn main() -> Result<(), LauncherError> { download_convert_model(&args, running.clone())?; // Shared shutdown bool - let shutdown = Arc::new(Mutex::new(false)); + let shutdown = Arc::new(AtomicBool::new(false)); // Shared shutdown channel // When shutting down, the main thread will wait for all senders to be dropped let (shutdown_sender, shutdown_receiver) = mpsc::channel(); @@ -1034,7 +1047,10 @@ fn main() -> Result<(), LauncherError> { while running.load(Ordering::SeqCst) { if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { - tracing::error!("Shard {rank} failed:\n{err}"); + tracing::error!("Shard {rank} failed to start"); + if let Some(err) = err { + tracing::error!("{err}"); + } exit_code = Err(LauncherError::ShardFailed); break; };