feat(launcher): add arg validation and drop subprocess (#595)
This commit is contained in:
parent
3628559516
commit
b7327205a6
|
@ -538,7 +538,7 @@ dependencies = [
|
|||
"autocfg",
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
"memoffset",
|
||||
"memoffset 0.9.0",
|
||||
"scopeguard",
|
||||
]
|
||||
|
||||
|
@ -1412,6 +1412,15 @@ version = "2.5.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.9.0"
|
||||
|
@ -1632,6 +1641,8 @@ dependencies = [
|
|||
"bitflags 1.3.2",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"memoffset 0.7.1",
|
||||
"pin-utils",
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
|
@ -2739,16 +2750,6 @@ version = "0.10.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "subprocess"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.5.0"
|
||||
|
@ -2889,10 +2890,10 @@ dependencies = [
|
|||
"clap",
|
||||
"ctrlc",
|
||||
"float_eq",
|
||||
"nix",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"subprocess",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"vergen",
|
||||
|
|
|
@ -9,9 +9,9 @@ homepage.workspace = true
|
|||
[dependencies]
|
||||
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||
ctrlc = { version = "3.2.5", features = ["termination"] }
|
||||
nix = "0.26.2"
|
||||
serde = { version = "1.0.152", features = ["derive"] }
|
||||
serde_json = "1.0.93"
|
||||
subprocess = "0.2.9"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
||||
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
use clap::{Parser, ValueEnum};
|
||||
use nix::sys::signal::{self, Signal};
|
||||
use nix::unistd::Pid;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
use std::io::{BufRead, BufReader, Read};
|
||||
use std::os::unix::process::{CommandExt, ExitStatusExt};
|
||||
use std::path::Path;
|
||||
use std::process::{Child, Command, Stdio};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::mpsc::TryRecvError;
|
||||
use std::sync::{mpsc, Arc};
|
||||
|
@ -11,7 +15,6 @@ use std::thread;
|
|||
use std::thread::sleep;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{fs, io};
|
||||
use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
|
||||
|
||||
mod env_runtime;
|
||||
|
||||
|
@ -72,6 +75,11 @@ struct Args {
|
|||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// The number of tokenizer workers used for payload validation and truncation inside the
|
||||
/// router.
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
|
||||
/// Whether to shard the model across multiple GPUs
|
||||
/// By default text-generation-inference will use all available GPUs to run
|
||||
/// the model. Setting it to `false` deactivates `num_shard`.
|
||||
|
@ -306,11 +314,12 @@ fn shard_manager(
|
|||
let uds_string = format!("{uds_path}-{rank}");
|
||||
let uds = Path::new(&uds_string);
|
||||
// Clean previous runs
|
||||
fs::remove_file(uds).unwrap_or_default();
|
||||
if uds.exists() {
|
||||
fs::remove_file(uds).unwrap();
|
||||
}
|
||||
|
||||
// Process args
|
||||
let mut shard_argv = vec![
|
||||
"text-generation-server".to_string(),
|
||||
"serve".to_string(),
|
||||
model_id,
|
||||
"--uds-path".to_string(),
|
||||
|
@ -415,26 +424,23 @@ fn shard_manager(
|
|||
|
||||
// Start process
|
||||
tracing::info!("Starting shard {rank}");
|
||||
let mut p = match Popen::create(
|
||||
&shard_argv,
|
||||
PopenConfig {
|
||||
stdout: Redirection::Pipe,
|
||||
stderr: Redirection::Pipe,
|
||||
// Needed for the shutdown procedure
|
||||
setpgid: true,
|
||||
// NCCL env vars
|
||||
env: Some(env),
|
||||
..Default::default()
|
||||
},
|
||||
) {
|
||||
let mut p = match Command::new("text-generation-server")
|
||||
.args(shard_argv)
|
||||
.envs(env)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.process_group(0)
|
||||
.spawn()
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
if let PopenError::IoError(ref err) = err {
|
||||
if err.kind() == io::ErrorKind::NotFound {
|
||||
tracing::error!("text-generation-server not found in PATH");
|
||||
tracing::error!("Please install it with `make install-server`")
|
||||
} else {
|
||||
tracing::error!("{}", err);
|
||||
}
|
||||
}
|
||||
|
||||
status_sender
|
||||
.send(ShardStatus::Failed((rank, Some(err.to_string()))))
|
||||
.unwrap();
|
||||
|
@ -462,7 +468,7 @@ fn shard_manager(
|
|||
let mut wait_time = Instant::now();
|
||||
loop {
|
||||
// Process exited
|
||||
if let Some(exit_status) = p.poll() {
|
||||
if let Some(exit_status) = p.try_wait().unwrap() {
|
||||
// We read stderr in another thread as it seems that `read_to_string` can block
|
||||
// indefinitely in some cases
|
||||
let (err_sender, err_receiver) = mpsc::channel();
|
||||
|
@ -480,7 +486,7 @@ fn shard_manager(
|
|||
})
|
||||
.ok();
|
||||
|
||||
if let ExitStatus::Signaled(signal) = exit_status {
|
||||
if let Some(signal) = exit_status.signal() {
|
||||
tracing::error!("Shard process was signaled to shutdown with signal {signal}");
|
||||
}
|
||||
|
||||
|
@ -493,7 +499,7 @@ fn shard_manager(
|
|||
// We received a shutdown signal
|
||||
if shutdown.load(Ordering::SeqCst) {
|
||||
p.kill().unwrap();
|
||||
let _ = p.wait_timeout(Duration::from_secs(90));
|
||||
let _ = p.wait();
|
||||
tracing::info!("Shard {rank} terminated");
|
||||
return;
|
||||
}
|
||||
|
@ -573,7 +579,10 @@ impl PythonLogMessage {
|
|||
}
|
||||
}
|
||||
|
||||
fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
|
||||
fn find_num_shards(
|
||||
sharded: Option<bool>,
|
||||
num_shard: Option<usize>,
|
||||
) -> Result<usize, LauncherError> {
|
||||
// get the number of shards given `sharded` and `num_shard`
|
||||
let num_shard = match (sharded, num_shard) {
|
||||
(Some(true), None) => {
|
||||
|
@ -582,14 +591,18 @@ fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
|
|||
let n_devices = num_cuda_devices()
|
||||
.expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
|
||||
if n_devices <= 1 {
|
||||
panic!("`sharded` is true but only found {n_devices} CUDA devices");
|
||||
return Err(LauncherError::NotEnoughCUDADevices(format!(
|
||||
"`sharded` is true but only found {n_devices} CUDA devices"
|
||||
)));
|
||||
}
|
||||
n_devices
|
||||
}
|
||||
(Some(true), Some(num_shard)) => {
|
||||
// we can't have only one shard while sharded
|
||||
if num_shard <= 1 {
|
||||
panic!("`sharded` is true but `num_shard` <= 1");
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
"`sharded` is true but `num_shard` <= 1".to_string(),
|
||||
));
|
||||
}
|
||||
num_shard
|
||||
}
|
||||
|
@ -599,13 +612,17 @@ fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
|
|||
(None, Some(num_shard)) => num_shard,
|
||||
};
|
||||
if num_shard < 1 {
|
||||
panic!("`num_shard` cannot be < 1");
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
"`num_shard` cannot be < 1".to_string(),
|
||||
));
|
||||
}
|
||||
num_shard
|
||||
Ok(num_shard)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum LauncherError {
|
||||
ArgumentValidation(String),
|
||||
NotEnoughCUDADevices(String),
|
||||
DownloadError,
|
||||
ShardCannotStart,
|
||||
ShardDisconnected,
|
||||
|
@ -616,7 +633,6 @@ enum LauncherError {
|
|||
|
||||
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
|
||||
let mut download_argv = vec![
|
||||
"text-generation-server".to_string(),
|
||||
"download-weights".to_string(),
|
||||
args.model_id.to_string(),
|
||||
"--extension".to_string(),
|
||||
|
@ -664,25 +680,21 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
|||
|
||||
// Start process
|
||||
tracing::info!("Starting download process.");
|
||||
let mut download_process = match Popen::create(
|
||||
&download_argv,
|
||||
PopenConfig {
|
||||
stdout: Redirection::Pipe,
|
||||
stderr: Redirection::Pipe,
|
||||
// Needed for the shutdown procedure
|
||||
setpgid: true,
|
||||
env: Some(env),
|
||||
..Default::default()
|
||||
},
|
||||
) {
|
||||
let mut download_process = match Command::new("text-generation-server")
|
||||
.args(download_argv)
|
||||
.envs(env)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.process_group(0)
|
||||
.spawn()
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
if let PopenError::IoError(ref err) = err {
|
||||
if err.kind() == io::ErrorKind::NotFound {
|
||||
tracing::error!("text-generation-server not found in PATH");
|
||||
tracing::error!("Please install it with `make install-server`")
|
||||
}
|
||||
}
|
||||
|
||||
return Err(LauncherError::DownloadError);
|
||||
}
|
||||
};
|
||||
|
@ -702,25 +714,12 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
|||
});
|
||||
|
||||
loop {
|
||||
if let Some(status) = download_process.poll() {
|
||||
match status {
|
||||
ExitStatus::Exited(exit_code) => {
|
||||
if exit_code == 0 {
|
||||
if let Some(status) = download_process.try_wait().unwrap() {
|
||||
if status.success() {
|
||||
tracing::info!("Successfully downloaded weights.");
|
||||
break;
|
||||
} else {
|
||||
let mut err = String::new();
|
||||
download_process
|
||||
.stderr
|
||||
.take()
|
||||
.unwrap()
|
||||
.read_to_string(&mut err)
|
||||
.unwrap();
|
||||
tracing::error!("Download encountered an error: {err}");
|
||||
return Err(LauncherError::DownloadError);
|
||||
}
|
||||
}
|
||||
ExitStatus::Signaled(signal) => {
|
||||
}
|
||||
|
||||
let mut err = String::new();
|
||||
download_process
|
||||
.stderr
|
||||
|
@ -728,23 +727,20 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
|||
.unwrap()
|
||||
.read_to_string(&mut err)
|
||||
.unwrap();
|
||||
if let Some(signal) = status.signal() {
|
||||
tracing::error!(
|
||||
"Download process was signaled to shutdown with signal {signal}: {err}"
|
||||
);
|
||||
} else {
|
||||
tracing::error!("Download encountered an error: {err}");
|
||||
}
|
||||
|
||||
return Err(LauncherError::DownloadError);
|
||||
}
|
||||
e => {
|
||||
tracing::error!("Download process exited with an unknown status.: {e:?}");
|
||||
return Err(LauncherError::DownloadError);
|
||||
}
|
||||
}
|
||||
}
|
||||
if !running.load(Ordering::SeqCst) {
|
||||
download_process.terminate().unwrap();
|
||||
signal::kill(Pid::from_raw(download_process.id() as i32), Signal::SIGTERM).unwrap();
|
||||
tracing::info!("Waiting for download process to gracefully shutdown");
|
||||
download_process
|
||||
.wait_timeout(Duration::from_secs(90))
|
||||
.unwrap();
|
||||
download_process.wait().unwrap();
|
||||
tracing::info!("Download process terminated");
|
||||
return Ok(());
|
||||
}
|
||||
|
@ -854,12 +850,11 @@ fn spawn_webserver(
|
|||
args: Args,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
) -> Result<Popen, LauncherError> {
|
||||
) -> Result<Child, LauncherError> {
|
||||
// All shard started
|
||||
// Start webserver
|
||||
tracing::info!("Starting Webserver");
|
||||
let mut argv = vec![
|
||||
"text-generation-router".to_string(),
|
||||
"--max-concurrent-requests".to_string(),
|
||||
args.max_concurrent_requests.to_string(),
|
||||
"--max-best-of".to_string(),
|
||||
|
@ -878,6 +873,8 @@ fn spawn_webserver(
|
|||
args.waiting_served_ratio.to_string(),
|
||||
"--max-waiting-tokens".to_string(),
|
||||
args.max_waiting_tokens.to_string(),
|
||||
"--validation-workers".to_string(),
|
||||
args.validation_workers.to_string(),
|
||||
"--hostname".to_string(),
|
||||
args.hostname.to_string(),
|
||||
"--port".to_string(),
|
||||
|
@ -942,25 +939,20 @@ fn spawn_webserver(
|
|||
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||
};
|
||||
|
||||
let mut webserver = match Popen::create(
|
||||
&argv,
|
||||
PopenConfig {
|
||||
stdout: Redirection::Pipe,
|
||||
stderr: Redirection::Pipe,
|
||||
// Needed for the shutdown procedure
|
||||
setpgid: true,
|
||||
env: Some(env),
|
||||
..Default::default()
|
||||
},
|
||||
) {
|
||||
let mut webserver = match Command::new("text-generation-router")
|
||||
.args(argv)
|
||||
.envs(env)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.process_group(0)
|
||||
.spawn()
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
tracing::error!("Failed to start webserver: {}", err);
|
||||
if let PopenError::IoError(err) = err {
|
||||
if err.kind() == io::ErrorKind::NotFound {
|
||||
tracing::error!("text-generation-router not found in PATH");
|
||||
tracing::error!("Please install it with `make install-router`")
|
||||
}
|
||||
} else {
|
||||
tracing::error!("{}", err);
|
||||
}
|
||||
|
@ -1004,7 +996,37 @@ fn main() -> Result<(), LauncherError> {
|
|||
|
||||
tracing::info!("{:?}", args);
|
||||
|
||||
let num_shard = find_num_shards(args.sharded, args.num_shard);
|
||||
// Validate args
|
||||
if args.max_input_length >= args.max_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
"`max_input_length` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if args.max_input_length as u32 > args.max_batch_prefill_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
|
||||
args.max_batch_prefill_tokens, args.max_input_length
|
||||
)));
|
||||
}
|
||||
if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
args.max_batch_prefill_tokens, args.max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
if args.max_total_tokens as u32 > args.max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
args.max_total_tokens, args.max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
if args.validation_workers == 0 {
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
|
||||
if num_shard > 1 {
|
||||
tracing::info!("Sharding model on {num_shard} processes");
|
||||
}
|
||||
|
@ -1065,7 +1087,7 @@ fn main() -> Result<(), LauncherError> {
|
|||
break;
|
||||
};
|
||||
|
||||
match webserver.poll() {
|
||||
match webserver.try_wait().unwrap() {
|
||||
Some(_) => {
|
||||
tracing::error!("Webserver Crashed");
|
||||
shutdown_shards(shutdown, &shutdown_receiver);
|
||||
|
@ -1078,9 +1100,9 @@ fn main() -> Result<(), LauncherError> {
|
|||
}
|
||||
|
||||
// Graceful termination
|
||||
webserver.terminate().unwrap();
|
||||
signal::kill(Pid::from_raw(webserver.id() as i32), Signal::SIGTERM).unwrap();
|
||||
tracing::info!("Waiting for webserver to gracefully shutdown");
|
||||
webserver.wait_timeout(Duration::from_secs(90)).unwrap();
|
||||
webserver.wait().unwrap();
|
||||
tracing::info!("Webserver terminated");
|
||||
shutdown_shards(shutdown, &shutdown_receiver);
|
||||
|
||||
|
|
|
@ -98,7 +98,7 @@ impl Client {
|
|||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum amount of tokens supported by the hardware
|
||||
#[instrument(skip(self))]
|
||||
#[instrument(skip_all)]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
|
|
|
@ -102,17 +102,24 @@ fn main() -> Result<(), RouterError> {
|
|||
} = args;
|
||||
|
||||
// Validate args
|
||||
if max_input_length >= max_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_input_length` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_length as u32 > max_batch_prefill_tokens {
|
||||
panic!("{}", format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"));
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
|
||||
}
|
||||
if max_batch_prefill_tokens > max_batch_total_tokens {
|
||||
panic!("{}", format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"));
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > max_batch_total_tokens {
|
||||
panic!("{}", format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"));
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if validation_workers == 0 {
|
||||
panic!("`validation_workers` must be > 0");
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// CORS allowed origins
|
||||
|
@ -331,6 +338,8 @@ pub async fn get_model_info(
|
|||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
Connection(ClientError),
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
/// Payload validation logic
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
use rand::{thread_rng, Rng};
|
||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||
|
@ -30,10 +30,6 @@ impl Validation {
|
|||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
) -> Self {
|
||||
if max_input_length >= max_total_tokens {
|
||||
panic!("`max_input_length` must be < `max_total_tokens`");
|
||||
}
|
||||
|
||||
// If we have a fast tokenizer
|
||||
let sender = if let Some(tokenizer) = tokenizer {
|
||||
// Create channel
|
||||
|
|
Loading…
Reference in New Issue