fix(launcher): fix issue where launcher does not properly report shard failures (#522)

This commit is contained in:
OlivierDehaene 2023-06-30 23:09:20 +02:00 committed by GitHub
parent ecf6dc3a5a
commit 2b53d71991
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 37 additions and 21 deletions

View File

@ -6,8 +6,7 @@ use std::io::{BufRead, BufReader, Read};
use std::path::Path; use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError; use std::sync::mpsc::TryRecvError;
use std::sync::Arc; use std::sync::{mpsc, Arc};
use std::sync::{mpsc, Mutex};
use std::thread; use std::thread;
use std::thread::sleep; use std::thread::sleep;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@ -274,7 +273,7 @@ struct Args {
#[derive(Debug)] #[derive(Debug)]
enum ShardStatus { enum ShardStatus {
Ready, Ready,
Failed((usize, String)), Failed((usize, Option<String>)),
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
@ -296,7 +295,7 @@ fn shard_manager(
watermark_delta: Option<f32>, watermark_delta: Option<f32>,
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>, status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>, shutdown: Arc<AtomicBool>,
_shutdown_sender: mpsc::Sender<()>, _shutdown_sender: mpsc::Sender<()>,
) { ) {
// Get UDS path // Get UDS path
@ -433,20 +432,20 @@ fn shard_manager(
} }
} }
status_sender status_sender
.send(ShardStatus::Failed((rank, err.to_string()))) .send(ShardStatus::Failed((rank, Some(err.to_string()))))
.unwrap(); .unwrap();
return; return;
} }
}; };
// Redirect STDOUT to the console // Redirect STDOUT to the console
let shard_stdout = p.stdout.take().unwrap(); let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
thread::spawn(move || { thread::spawn(move || {
// Enter shard-manager tracing span // Enter shard-manager tracing span
let stdout = BufReader::new(shard_stdout);
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
for line in stdout.lines() { for line in shard_stdout_reader.lines() {
// Parse loguru logs // Parse loguru logs
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) { if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace(); log.trace();
@ -460,8 +459,22 @@ fn shard_manager(
loop { loop {
// Process exited // Process exited
if let Some(exit_status) = p.poll() { if let Some(exit_status) = p.poll() {
// We read stderr in another thread as it seems that `read_to_string` can block
// indefinitely in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
let mut err = String::new(); let mut err = String::new();
p.stderr.take().unwrap().read_to_string(&mut err).unwrap(); shard_stderr_reader.read_to_string(&mut err).unwrap();
err_sender.send(err).unwrap_or(());
});
let err = err_receiver
.recv_timeout(Duration::from_millis(100))
.map_err(|err| {
tracing::error!("Unable to read shard {rank} error from stderr");
err
})
.ok();
if let ExitStatus::Signaled(signal) = exit_status { if let ExitStatus::Signaled(signal) = exit_status {
tracing::error!("Shard process was signaled to shutdown with signal {signal}"); tracing::error!("Shard process was signaled to shutdown with signal {signal}");
@ -474,7 +487,7 @@ fn shard_manager(
} }
// We received a shutdown signal // We received a shutdown signal
if *shutdown.lock().unwrap() { if shutdown.load(Ordering::SeqCst) {
p.kill().unwrap(); p.kill().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90)); let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {rank} terminated"); tracing::info!("Shard {rank} terminated");
@ -494,14 +507,11 @@ fn shard_manager(
} }
} }
fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) { fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
tracing::info!("Shutting down shards"); tracing::info!("Shutting down shards");
// Update shutdown value to true // Update shutdown value to true
// This will be picked up by the shard manager // This will be picked up by the shard manager
{ shutdown.store(true, Ordering::SeqCst);
let mut shutdown = shutdown.lock().unwrap();
*shutdown = true;
}
// Wait for shards to shutdown // Wait for shards to shutdown
// This will block till all shutdown_sender are dropped // This will block till all shutdown_sender are dropped
@ -743,7 +753,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
fn spawn_shards( fn spawn_shards(
num_shard: usize, num_shard: usize,
args: &Args, args: &Args,
shutdown: Arc<Mutex<bool>>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
shutdown_sender: mpsc::Sender<()>, shutdown_sender: mpsc::Sender<()>,
status_receiver: &mpsc::Receiver<ShardStatus>, status_receiver: &mpsc::Receiver<ShardStatus>,
@ -819,7 +829,10 @@ fn spawn_shards(
sleep(Duration::from_millis(100)); sleep(Duration::from_millis(100));
} }
Ok(ShardStatus::Failed((rank, err))) => { Ok(ShardStatus::Failed((rank, err))) => {
tracing::error!("Shard {} failed to start:\n{}", rank, err); tracing::error!("Shard {rank} failed to start");
if let Some(err) = err {
tracing::error!("{err}");
}
shutdown_shards(shutdown, shutdown_receiver); shutdown_shards(shutdown, shutdown_receiver);
return Err(LauncherError::ShardCannotStart); return Err(LauncherError::ShardCannotStart);
} }
@ -835,7 +848,7 @@ fn spawn_shards(
fn spawn_webserver( fn spawn_webserver(
args: Args, args: Args,
shutdown: Arc<Mutex<bool>>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Popen, LauncherError> { ) -> Result<Popen, LauncherError> {
// All shard started // All shard started
@ -1002,7 +1015,7 @@ fn main() -> Result<(), LauncherError> {
download_convert_model(&args, running.clone())?; download_convert_model(&args, running.clone())?;
// Shared shutdown bool // Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false)); let shutdown = Arc::new(AtomicBool::new(false));
// Shared shutdown channel // Shared shutdown channel
// When shutting down, the main thread will wait for all senders to be dropped // When shutting down, the main thread will wait for all senders to be dropped
let (shutdown_sender, shutdown_receiver) = mpsc::channel(); let (shutdown_sender, shutdown_receiver) = mpsc::channel();
@ -1034,7 +1047,10 @@ fn main() -> Result<(), LauncherError> {
while running.load(Ordering::SeqCst) { while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
tracing::error!("Shard {rank} failed:\n{err}"); tracing::error!("Shard {rank} failed to start");
if let Some(err) = err {
tracing::error!("{err}");
}
exit_code = Err(LauncherError::ShardFailed); exit_code = Err(LauncherError::ShardFailed);
break; break;
}; };