fix: read stderr in download (#1486)

#1186
This commit is contained in:
OlivierDehaene 2024-01-25 18:16:03 +01:00 committed by GitHub
parent 7e2a7433d3
commit 9c320e260b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 30 additions and 18 deletions

View File

@ -4,7 +4,7 @@ use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines, Read};
use std::io::{BufRead, BufReader, Lines};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
@ -489,6 +489,9 @@ fn shard_manager(
// Safetensors load fast
envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// Disable progress bar
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
// Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
envs.push((
@ -573,6 +576,13 @@ fn shard_manager(
thread::spawn(move || {
log_lines(shard_stdout_reader.lines());
});
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in shard_stderr_reader.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});
let mut ready = false;
let start_time = Instant::now();
@ -580,13 +590,6 @@ fn shard_manager(
loop {
// Process exited
if let Some(exit_status) = p.try_wait().unwrap() {
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in shard_stderr_reader.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});
let mut err = String::new();
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
err = err + "\n" + &line;
@ -782,6 +785,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Disable progress bar
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
// If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
@ -832,12 +838,20 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
}
};
// Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap();
let stdout = BufReader::new(download_stdout);
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
thread::spawn(move || {
log_lines(stdout.lines());
log_lines(download_stdout.lines());
});
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in download_stderr.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});
loop {
@ -848,12 +862,10 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
}
let mut err = String::new();
download_process
.stderr
.take()
.unwrap()
.read_to_string(&mut err)
.unwrap();
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
err = err + "\n" + &line;
}
if let Some(signal) = status.signal() {
tracing::error!(
"Download process was signaled to shutdown with signal {signal}: {err}"