feat(launcher): add arg validation and drop subprocess (#595)

This commit is contained in:
OlivierDehaene 2023-07-13 14:22:37 +02:00 committed by GitHub
parent 3628559516
commit b7327205a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 157 additions and 129 deletions

25
Cargo.lock generated
View File

@ -538,7 +538,7 @@ dependencies = [
"autocfg", "autocfg",
"cfg-if", "cfg-if",
"crossbeam-utils", "crossbeam-utils",
"memoffset", "memoffset 0.9.0",
"scopeguard", "scopeguard",
] ]
@ -1412,6 +1412,15 @@ version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "memoffset" name = "memoffset"
version = "0.9.0" version = "0.9.0"
@ -1632,6 +1641,8 @@ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
"cfg-if", "cfg-if",
"libc", "libc",
"memoffset 0.7.1",
"pin-utils",
"static_assertions", "static_assertions",
] ]
@ -2739,16 +2750,6 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "subprocess"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086"
dependencies = [
"libc",
"winapi",
]
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.5.0" version = "2.5.0"
@ -2889,10 +2890,10 @@ dependencies = [
"clap", "clap",
"ctrlc", "ctrlc",
"float_eq", "float_eq",
"nix",
"reqwest", "reqwest",
"serde", "serde",
"serde_json", "serde_json",
"subprocess",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"vergen", "vergen",

View File

@ -9,9 +9,9 @@ homepage.workspace = true
[dependencies] [dependencies]
clap = { version = "4.1.4", features = ["derive", "env"] } clap = { version = "4.1.4", features = ["derive", "env"] }
ctrlc = { version = "3.2.5", features = ["termination"] } ctrlc = { version = "3.2.5", features = ["termination"] }
nix = "0.26.2"
serde = { version = "1.0.152", features = ["derive"] } serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93" serde_json = "1.0.93"
subprocess = "0.2.9"
tracing = "0.1.37" tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json"] } tracing-subscriber = { version = "0.3.16", features = ["json"] }

View File

@ -1,9 +1,13 @@
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use serde::Deserialize; use serde::Deserialize;
use std::env; use std::env;
use std::ffi::OsString; use std::ffi::OsString;
use std::io::{BufRead, BufReader, Read}; use std::io::{BufRead, BufReader, Read};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path; use std::path::Path;
use std::process::{Child, Command, Stdio};
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError; use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc}; use std::sync::{mpsc, Arc};
@ -11,7 +15,6 @@ use std::thread;
use std::thread::sleep; use std::thread::sleep;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{fs, io}; use std::{fs, io};
use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
mod env_runtime; mod env_runtime;
@ -72,6 +75,11 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
revision: Option<String>, revision: Option<String>,
/// The number of tokenizer workers used for payload validation and truncation inside the
/// router.
#[clap(default_value = "2", long, env)]
validation_workers: usize,
/// Whether to shard the model across multiple GPUs /// Whether to shard the model across multiple GPUs
/// By default text-generation-inference will use all available GPUs to run /// By default text-generation-inference will use all available GPUs to run
/// the model. Setting it to `false` deactivates `num_shard`. /// the model. Setting it to `false` deactivates `num_shard`.
@ -306,11 +314,12 @@ fn shard_manager(
let uds_string = format!("{uds_path}-{rank}"); let uds_string = format!("{uds_path}-{rank}");
let uds = Path::new(&uds_string); let uds = Path::new(&uds_string);
// Clean previous runs // Clean previous runs
fs::remove_file(uds).unwrap_or_default(); if uds.exists() {
fs::remove_file(uds).unwrap();
}
// Process args // Process args
let mut shard_argv = vec![ let mut shard_argv = vec![
"text-generation-server".to_string(),
"serve".to_string(), "serve".to_string(),
model_id, model_id,
"--uds-path".to_string(), "--uds-path".to_string(),
@ -415,26 +424,23 @@ fn shard_manager(
// Start process // Start process
tracing::info!("Starting shard {rank}"); tracing::info!("Starting shard {rank}");
let mut p = match Popen::create( let mut p = match Command::new("text-generation-server")
&shard_argv, .args(shard_argv)
PopenConfig { .envs(env)
stdout: Redirection::Pipe, .stdout(Stdio::piped())
stderr: Redirection::Pipe, .stderr(Stdio::piped())
// Needed for the shutdown procedure .process_group(0)
setpgid: true, .spawn()
// NCCL env vars {
env: Some(env),
..Default::default()
},
) {
Ok(p) => p, Ok(p) => p,
Err(err) => { Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound { if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH"); tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`") tracing::error!("Please install it with `make install-server`")
} else {
tracing::error!("{}", err);
} }
}
status_sender status_sender
.send(ShardStatus::Failed((rank, Some(err.to_string())))) .send(ShardStatus::Failed((rank, Some(err.to_string()))))
.unwrap(); .unwrap();
@ -462,7 +468,7 @@ fn shard_manager(
let mut wait_time = Instant::now(); let mut wait_time = Instant::now();
loop { loop {
// Process exited // Process exited
if let Some(exit_status) = p.poll() { if let Some(exit_status) = p.try_wait().unwrap() {
// We read stderr in another thread as it seems that `read_to_string` can block // We read stderr in another thread as it seems that `read_to_string` can block
// indefinitely in some cases // indefinitely in some cases
let (err_sender, err_receiver) = mpsc::channel(); let (err_sender, err_receiver) = mpsc::channel();
@ -480,7 +486,7 @@ fn shard_manager(
}) })
.ok(); .ok();
if let ExitStatus::Signaled(signal) = exit_status { if let Some(signal) = exit_status.signal() {
tracing::error!("Shard process was signaled to shutdown with signal {signal}"); tracing::error!("Shard process was signaled to shutdown with signal {signal}");
} }
@ -493,7 +499,7 @@ fn shard_manager(
// We received a shutdown signal // We received a shutdown signal
if shutdown.load(Ordering::SeqCst) { if shutdown.load(Ordering::SeqCst) {
p.kill().unwrap(); p.kill().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90)); let _ = p.wait();
tracing::info!("Shard {rank} terminated"); tracing::info!("Shard {rank} terminated");
return; return;
} }
@ -573,7 +579,10 @@ impl PythonLogMessage {
} }
} }
fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize { fn find_num_shards(
sharded: Option<bool>,
num_shard: Option<usize>,
) -> Result<usize, LauncherError> {
// get the number of shards given `sharded` and `num_shard` // get the number of shards given `sharded` and `num_shard`
let num_shard = match (sharded, num_shard) { let num_shard = match (sharded, num_shard) {
(Some(true), None) => { (Some(true), None) => {
@ -582,14 +591,18 @@ fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
let n_devices = num_cuda_devices() let n_devices = num_cuda_devices()
.expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"); .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
if n_devices <= 1 { if n_devices <= 1 {
panic!("`sharded` is true but only found {n_devices} CUDA devices"); return Err(LauncherError::NotEnoughCUDADevices(format!(
"`sharded` is true but only found {n_devices} CUDA devices"
)));
} }
n_devices n_devices
} }
(Some(true), Some(num_shard)) => { (Some(true), Some(num_shard)) => {
// we can't have only one shard while sharded // we can't have only one shard while sharded
if num_shard <= 1 { if num_shard <= 1 {
panic!("`sharded` is true but `num_shard` <= 1"); return Err(LauncherError::ArgumentValidation(
"`sharded` is true but `num_shard` <= 1".to_string(),
));
} }
num_shard num_shard
} }
@ -599,13 +612,17 @@ fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
(None, Some(num_shard)) => num_shard, (None, Some(num_shard)) => num_shard,
}; };
if num_shard < 1 { if num_shard < 1 {
panic!("`num_shard` cannot be < 1"); return Err(LauncherError::ArgumentValidation(
"`num_shard` cannot be < 1".to_string(),
));
} }
num_shard Ok(num_shard)
} }
#[derive(Debug)] #[derive(Debug)]
enum LauncherError { enum LauncherError {
ArgumentValidation(String),
NotEnoughCUDADevices(String),
DownloadError, DownloadError,
ShardCannotStart, ShardCannotStart,
ShardDisconnected, ShardDisconnected,
@ -616,7 +633,6 @@ enum LauncherError {
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> { fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
let mut download_argv = vec![ let mut download_argv = vec![
"text-generation-server".to_string(),
"download-weights".to_string(), "download-weights".to_string(),
args.model_id.to_string(), args.model_id.to_string(),
"--extension".to_string(), "--extension".to_string(),
@ -664,25 +680,21 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Start process // Start process
tracing::info!("Starting download process."); tracing::info!("Starting download process.");
let mut download_process = match Popen::create( let mut download_process = match Command::new("text-generation-server")
&download_argv, .args(download_argv)
PopenConfig { .envs(env)
stdout: Redirection::Pipe, .stdout(Stdio::piped())
stderr: Redirection::Pipe, .stderr(Stdio::piped())
// Needed for the shutdown procedure .process_group(0)
setpgid: true, .spawn()
env: Some(env), {
..Default::default()
},
) {
Ok(p) => p, Ok(p) => p,
Err(err) => { Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound { if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH"); tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`") tracing::error!("Please install it with `make install-server`")
} }
}
return Err(LauncherError::DownloadError); return Err(LauncherError::DownloadError);
} }
}; };
@ -702,25 +714,12 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
}); });
loop { loop {
if let Some(status) = download_process.poll() { if let Some(status) = download_process.try_wait().unwrap() {
match status { if status.success() {
ExitStatus::Exited(exit_code) => {
if exit_code == 0 {
tracing::info!("Successfully downloaded weights."); tracing::info!("Successfully downloaded weights.");
break; break;
} else { }
let mut err = String::new();
download_process
.stderr
.take()
.unwrap()
.read_to_string(&mut err)
.unwrap();
tracing::error!("Download encountered an error: {err}");
return Err(LauncherError::DownloadError);
}
}
ExitStatus::Signaled(signal) => {
let mut err = String::new(); let mut err = String::new();
download_process download_process
.stderr .stderr
@ -728,23 +727,20 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
.unwrap() .unwrap()
.read_to_string(&mut err) .read_to_string(&mut err)
.unwrap(); .unwrap();
if let Some(signal) = status.signal() {
tracing::error!( tracing::error!(
"Download process was signaled to shutdown with signal {signal}: {err}" "Download process was signaled to shutdown with signal {signal}: {err}"
); );
} else {
tracing::error!("Download encountered an error: {err}");
}
return Err(LauncherError::DownloadError); return Err(LauncherError::DownloadError);
} }
e => {
tracing::error!("Download process exited with an unknown status.: {e:?}");
return Err(LauncherError::DownloadError);
}
}
}
if !running.load(Ordering::SeqCst) { if !running.load(Ordering::SeqCst) {
download_process.terminate().unwrap(); signal::kill(Pid::from_raw(download_process.id() as i32), Signal::SIGTERM).unwrap();
tracing::info!("Waiting for download process to gracefully shutdown"); tracing::info!("Waiting for download process to gracefully shutdown");
download_process download_process.wait().unwrap();
.wait_timeout(Duration::from_secs(90))
.unwrap();
tracing::info!("Download process terminated"); tracing::info!("Download process terminated");
return Ok(()); return Ok(());
} }
@ -854,12 +850,11 @@ fn spawn_webserver(
args: Args, args: Args,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Popen, LauncherError> { ) -> Result<Child, LauncherError> {
// All shard started // All shard started
// Start webserver // Start webserver
tracing::info!("Starting Webserver"); tracing::info!("Starting Webserver");
let mut argv = vec![ let mut argv = vec![
"text-generation-router".to_string(),
"--max-concurrent-requests".to_string(), "--max-concurrent-requests".to_string(),
args.max_concurrent_requests.to_string(), args.max_concurrent_requests.to_string(),
"--max-best-of".to_string(), "--max-best-of".to_string(),
@ -878,6 +873,8 @@ fn spawn_webserver(
args.waiting_served_ratio.to_string(), args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(), "--max-waiting-tokens".to_string(),
args.max_waiting_tokens.to_string(), args.max_waiting_tokens.to_string(),
"--validation-workers".to_string(),
args.validation_workers.to_string(),
"--hostname".to_string(), "--hostname".to_string(),
args.hostname.to_string(), args.hostname.to_string(),
"--port".to_string(), "--port".to_string(),
@ -942,25 +939,20 @@ fn spawn_webserver(
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
}; };
let mut webserver = match Popen::create( let mut webserver = match Command::new("text-generation-router")
&argv, .args(argv)
PopenConfig { .envs(env)
stdout: Redirection::Pipe, .stdout(Stdio::piped())
stderr: Redirection::Pipe, .stderr(Stdio::piped())
// Needed for the shutdown procedure .process_group(0)
setpgid: true, .spawn()
env: Some(env), {
..Default::default()
},
) {
Ok(p) => p, Ok(p) => p,
Err(err) => { Err(err) => {
tracing::error!("Failed to start webserver: {}", err); tracing::error!("Failed to start webserver: {}", err);
if let PopenError::IoError(err) = err {
if err.kind() == io::ErrorKind::NotFound { if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-router not found in PATH"); tracing::error!("text-generation-router not found in PATH");
tracing::error!("Please install it with `make install-router`") tracing::error!("Please install it with `make install-router`")
}
} else { } else {
tracing::error!("{}", err); tracing::error!("{}", err);
} }
@ -1004,7 +996,37 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:?}", args); tracing::info!("{:?}", args);
let num_shard = find_num_shards(args.sharded, args.num_shard); // Validate args
if args.max_input_length >= args.max_total_tokens {
return Err(LauncherError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(),
));
}
if args.max_input_length as u32 > args.max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
args.max_batch_prefill_tokens, args.max_input_length
)));
}
if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, args.max_batch_total_tokens
)));
}
if args.max_total_tokens as u32 > args.max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, args.max_batch_total_tokens
)));
}
if args.validation_workers == 0 {
return Err(LauncherError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 { if num_shard > 1 {
tracing::info!("Sharding model on {num_shard} processes"); tracing::info!("Sharding model on {num_shard} processes");
} }
@ -1065,7 +1087,7 @@ fn main() -> Result<(), LauncherError> {
break; break;
}; };
match webserver.poll() { match webserver.try_wait().unwrap() {
Some(_) => { Some(_) => {
tracing::error!("Webserver Crashed"); tracing::error!("Webserver Crashed");
shutdown_shards(shutdown, &shutdown_receiver); shutdown_shards(shutdown, &shutdown_receiver);
@ -1078,9 +1100,9 @@ fn main() -> Result<(), LauncherError> {
} }
// Graceful termination // Graceful termination
webserver.terminate().unwrap(); signal::kill(Pid::from_raw(webserver.id() as i32), Signal::SIGTERM).unwrap();
tracing::info!("Waiting for webserver to gracefully shutdown"); tracing::info!("Waiting for webserver to gracefully shutdown");
webserver.wait_timeout(Duration::from_secs(90)).unwrap(); webserver.wait().unwrap();
tracing::info!("Webserver terminated"); tracing::info!("Webserver terminated");
shutdown_shards(shutdown, &shutdown_receiver); shutdown_shards(shutdown, &shutdown_receiver);

View File

@ -98,7 +98,7 @@ impl Client {
/// Warmup on a max size batch /// Warmup on a max size batch
/// ///
/// Returns the maximum amount of tokens supported by the hardware /// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))] #[instrument(skip_all)]
pub async fn warmup( pub async fn warmup(
&mut self, &mut self,
max_input_length: u32, max_input_length: u32,

View File

@ -102,17 +102,24 @@ fn main() -> Result<(), RouterError> {
} = args; } = args;
// Validate args // Validate args
if max_input_length >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(),
));
}
if max_input_length as u32 > max_batch_prefill_tokens { if max_input_length as u32 > max_batch_prefill_tokens {
panic!("{}", format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")); return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
} }
if max_batch_prefill_tokens > max_batch_total_tokens { if max_batch_prefill_tokens > max_batch_total_tokens {
panic!("{}", format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")); return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
} }
if max_total_tokens as u32 > max_batch_total_tokens { if max_total_tokens as u32 > max_batch_total_tokens {
panic!("{}", format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")); return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
} }
if validation_workers == 0 { if validation_workers == 0 {
panic!("`validation_workers` must be > 0"); return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
} }
// CORS allowed origins // CORS allowed origins
@ -331,6 +338,8 @@ pub async fn get_model_info(
#[derive(Debug, Error)] #[derive(Debug, Error)]
enum RouterError { enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("Unable to connect to the Python model shards: {0}")] #[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError), Connection(ClientError),
#[error("Unable to clear the Python model shards cache: {0}")] #[error("Unable to clear the Python model shards cache: {0}")]

View File

@ -1,5 +1,5 @@
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
/// Payload validation logic /// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest}; use crate::{GenerateParameters, GenerateRequest};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
@ -30,10 +30,6 @@ impl Validation {
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize, max_total_tokens: usize,
) -> Self { ) -> Self {
if max_input_length >= max_total_tokens {
panic!("`max_input_length` must be < `max_total_tokens`");
}
// If we have a fast tokenizer // If we have a fast tokenizer
let sender = if let Some(tokenizer) = tokenizer { let sender = if let Some(tokenizer) = tokenizer {
// Create channel // Create channel