diff --git a/Cargo.lock b/Cargo.lock index 1d030249..867503f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -538,7 +538,7 @@ dependencies = [ "autocfg", "cfg-if", "crossbeam-utils", - "memoffset", + "memoffset 0.9.0", "scopeguard", ] @@ -1412,6 +1412,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "memoffset" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" +dependencies = [ + "autocfg", +] + [[package]] name = "memoffset" version = "0.9.0" @@ -1632,6 +1641,8 @@ dependencies = [ "bitflags 1.3.2", "cfg-if", "libc", + "memoffset 0.7.1", + "pin-utils", "static_assertions", ] @@ -2739,16 +2750,6 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" -[[package]] -name = "subprocess" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "subtle" version = "2.5.0" @@ -2889,10 +2890,10 @@ dependencies = [ "clap", "ctrlc", "float_eq", + "nix", "reqwest", "serde", "serde_json", - "subprocess", "tracing", "tracing-subscriber", "vergen", diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index 6439e960..ae0694da 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -9,9 +9,9 @@ homepage.workspace = true [dependencies] clap = { version = "4.1.4", features = ["derive", "env"] } ctrlc = { version = "3.2.5", features = ["termination"] } +nix = "0.26.2" serde = { version = "1.0.152", features = ["derive"] } serde_json = "1.0.93" -subprocess = "0.2.9" tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["json"] } diff 
--git a/launcher/src/main.rs b/launcher/src/main.rs index a612eb6d..8b34dfe3 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1,9 +1,13 @@ use clap::{Parser, ValueEnum}; +use nix::sys::signal::{self, Signal}; +use nix::unistd::Pid; use serde::Deserialize; use std::env; use std::ffi::OsString; use std::io::{BufRead, BufReader, Read}; +use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::path::Path; +use std::process::{Child, Command, Stdio}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::TryRecvError; use std::sync::{mpsc, Arc}; @@ -11,7 +15,6 @@ use std::thread; use std::thread::sleep; use std::time::{Duration, Instant}; use std::{fs, io}; -use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection}; mod env_runtime; @@ -72,6 +75,11 @@ struct Args { #[clap(long, env)] revision: Option<String>, + /// The number of tokenizer workers used for payload validation and truncation inside the + /// router. + #[clap(default_value = "2", long, env)] + validation_workers: usize, + /// Whether to shard the model across multiple GPUs /// By default text-generation-inference will use all available GPUs to run /// the model. Setting it to `false` deactivates `num_shard`. 
@@ -306,11 +314,12 @@ fn shard_manager( let uds_string = format!("{uds_path}-{rank}"); let uds = Path::new(&uds_string); // Clean previous runs - fs::remove_file(uds).unwrap_or_default(); + if uds.exists() { + fs::remove_file(uds).unwrap(); + } // Process args let mut shard_argv = vec![ - "text-generation-server".to_string(), "serve".to_string(), model_id, "--uds-path".to_string(), @@ -415,26 +424,23 @@ fn shard_manager( // Start process tracing::info!("Starting shard {rank}"); - let mut p = match Popen::create( - &shard_argv, - PopenConfig { - stdout: Redirection::Pipe, - stderr: Redirection::Pipe, - // Needed for the shutdown procedure - setpgid: true, - // NCCL env vars - env: Some(env), - ..Default::default() - }, - ) { + let mut p = match Command::new("text-generation-server") + .args(shard_argv) + .envs(env) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .spawn() + { Ok(p) => p, Err(err) => { - if let PopenError::IoError(ref err) = err { - if err.kind() == io::ErrorKind::NotFound { - tracing::error!("text-generation-server not found in PATH"); - tracing::error!("Please install it with `make install-server`") - } + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-server not found in PATH"); + tracing::error!("Please install it with `make install-server`") + } else { + tracing::error!("{}", err); } + status_sender .send(ShardStatus::Failed((rank, Some(err.to_string())))) .unwrap(); @@ -462,7 +468,7 @@ fn shard_manager( let mut wait_time = Instant::now(); loop { // Process exited - if let Some(exit_status) = p.poll() { + if let Some(exit_status) = p.try_wait().unwrap() { // We read stderr in another thread as it seems that `read_to_string` can block // indefinitely in some cases let (err_sender, err_receiver) = mpsc::channel(); @@ -480,7 +486,7 @@ fn shard_manager( }) .ok(); - if let ExitStatus::Signaled(signal) = exit_status { + if let Some(signal) = exit_status.signal() { tracing::error!("Shard process 
was signaled to shutdown with signal {signal}"); } @@ -493,7 +499,7 @@ fn shard_manager( // We received a shutdown signal if shutdown.load(Ordering::SeqCst) { p.kill().unwrap(); - let _ = p.wait_timeout(Duration::from_secs(90)); + let _ = p.wait(); tracing::info!("Shard {rank} terminated"); return; } @@ -573,7 +579,10 @@ impl PythonLogMessage { } } -fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize { +fn find_num_shards( + sharded: Option<bool>, + num_shard: Option<usize>, +) -> Result<usize, LauncherError> { // get the number of shards given `sharded` and `num_shard` let num_shard = match (sharded, num_shard) { (Some(true), None) => { @@ -582,14 +591,18 @@ fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize { let n_devices = num_cuda_devices() .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"); if n_devices <= 1 { - panic!("`sharded` is true but only found {n_devices} CUDA devices"); + return Err(LauncherError::NotEnoughCUDADevices(format!( + "`sharded` is true but only found {n_devices} CUDA devices" + ))); } n_devices } (Some(true), Some(num_shard)) => { // we can't have only one shard while sharded if num_shard <= 1 { - panic!("`sharded` is true but `num_shard` <= 1"); + return Err(LauncherError::ArgumentValidation( + "`sharded` is true but `num_shard` <= 1".to_string(), + )); } num_shard } @@ -599,13 +612,17 @@ fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize { (None, Some(num_shard)) => num_shard, }; if num_shard < 1 { - panic!("`num_shard` cannot be < 1"); + return Err(LauncherError::ArgumentValidation( + "`num_shard` cannot be < 1".to_string(), + )); } - num_shard + Ok(num_shard) } #[derive(Debug)] enum LauncherError { + ArgumentValidation(String), + NotEnoughCUDADevices(String), DownloadError, ShardCannotStart, ShardDisconnected, @@ -616,7 +633,6 @@ enum LauncherError { fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> { let mut download_argv = vec![ - 
"text-generation-server".to_string(), "download-weights".to_string(), args.model_id.to_string(), "--extension".to_string(), @@ -664,25 +680,21 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // Start process tracing::info!("Starting download process."); - let mut download_process = match Popen::create( - &download_argv, - PopenConfig { - stdout: Redirection::Pipe, - stderr: Redirection::Pipe, - // Needed for the shutdown procedure - setpgid: true, - env: Some(env), - ..Default::default() - }, - ) { + let mut download_process = match Command::new("text-generation-server") + .args(download_argv) + .envs(env) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .spawn() + { Ok(p) => p, Err(err) => { - if let PopenError::IoError(ref err) = err { - if err.kind() == io::ErrorKind::NotFound { - tracing::error!("text-generation-server not found in PATH"); - tracing::error!("Please install it with `make install-server`") - } + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-server not found in PATH"); + tracing::error!("Please install it with `make install-server`") } + return Err(LauncherError::DownloadError); } }; @@ -702,49 +714,33 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L }); loop { - if let Some(status) = download_process.poll() { - match status { - ExitStatus::Exited(exit_code) => { - if exit_code == 0 { - tracing::info!("Successfully downloaded weights."); - break; - } else { - let mut err = String::new(); - download_process - .stderr - .take() - .unwrap() - .read_to_string(&mut err) - .unwrap(); - tracing::error!("Download encountered an error: {err}"); - return Err(LauncherError::DownloadError); - } - } - ExitStatus::Signaled(signal) => { - let mut err = String::new(); - download_process - .stderr - .take() - .unwrap() - .read_to_string(&mut err) - .unwrap(); - tracing::error!( - "Download process was signaled to shutdown with signal {signal}: {err}" - ); - 
return Err(LauncherError::DownloadError); - } - e => { - tracing::error!("Download process exited with an unknown status.: {e:?}"); - return Err(LauncherError::DownloadError); - } + if let Some(status) = download_process.try_wait().unwrap() { + if status.success() { + tracing::info!("Successfully downloaded weights."); + break; } + + let mut err = String::new(); + download_process + .stderr + .take() + .unwrap() + .read_to_string(&mut err) + .unwrap(); + if let Some(signal) = status.signal() { + tracing::error!( + "Download process was signaled to shutdown with signal {signal}: {err}" + ); + } else { + tracing::error!("Download encountered an error: {err}"); + } + + return Err(LauncherError::DownloadError); } if !running.load(Ordering::SeqCst) { - download_process.terminate().unwrap(); + signal::kill(Pid::from_raw(download_process.id() as i32), Signal::SIGTERM).unwrap(); tracing::info!("Waiting for download process to gracefully shutdown"); - download_process - .wait_timeout(Duration::from_secs(90)) - .unwrap(); + download_process.wait().unwrap(); tracing::info!("Download process terminated"); return Ok(()); } @@ -854,12 +850,11 @@ fn spawn_webserver( args: Args, shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, -) -> Result { +) -> Result { // All shard started // Start webserver tracing::info!("Starting Webserver"); let mut argv = vec![ - "text-generation-router".to_string(), "--max-concurrent-requests".to_string(), args.max_concurrent_requests.to_string(), "--max-best-of".to_string(), @@ -878,6 +873,8 @@ fn spawn_webserver( args.waiting_served_ratio.to_string(), "--max-waiting-tokens".to_string(), args.max_waiting_tokens.to_string(), + "--validation-workers".to_string(), + args.validation_workers.to_string(), "--hostname".to_string(), args.hostname.to_string(), "--port".to_string(), @@ -942,25 +939,20 @@ fn spawn_webserver( env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) }; - let mut webserver = match Popen::create( - &argv, - PopenConfig { - 
stdout: Redirection::Pipe, - stderr: Redirection::Pipe, - // Needed for the shutdown procedure - setpgid: true, - env: Some(env), - ..Default::default() - }, - ) { + let mut webserver = match Command::new("text-generation-router") + .args(argv) + .envs(env) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .spawn() + { Ok(p) => p, Err(err) => { tracing::error!("Failed to start webserver: {}", err); - if let PopenError::IoError(err) = err { - if err.kind() == io::ErrorKind::NotFound { - tracing::error!("text-generation-router not found in PATH"); - tracing::error!("Please install it with `make install-router`") - } + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-router not found in PATH"); + tracing::error!("Please install it with `make install-router`") } else { tracing::error!("{}", err); } @@ -1004,7 +996,37 @@ fn main() -> Result<(), LauncherError> { tracing::info!("{:?}", args); - let num_shard = find_num_shards(args.sharded, args.num_shard); + // Validate args + if args.max_input_length >= args.max_total_tokens { + return Err(LauncherError::ArgumentValidation( + "`max_input_length` must be < `max_total_tokens`".to_string(), + )); + } + if args.max_input_length as u32 > args.max_batch_prefill_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}", + args.max_batch_prefill_tokens, args.max_input_length + ))); + } + if args.max_batch_prefill_tokens > args.max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", + args.max_batch_prefill_tokens, args.max_batch_total_tokens + ))); + } + if args.max_total_tokens as u32 > args.max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {} and {}", + args.max_total_tokens, args.max_batch_total_tokens + ))); + } + if args.validation_workers == 0 { + return Err(LauncherError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + let num_shard = find_num_shards(args.sharded, args.num_shard)?; if num_shard > 1 { tracing::info!("Sharding model on {num_shard} processes"); } @@ -1065,7 +1087,7 @@ fn main() -> Result<(), LauncherError> { break; }; - match webserver.poll() { + match webserver.try_wait().unwrap() { Some(_) => { tracing::error!("Webserver Crashed"); shutdown_shards(shutdown, &shutdown_receiver); @@ -1078,9 +1100,9 @@ fn main() -> Result<(), LauncherError> { } // Graceful termination - webserver.terminate().unwrap(); + signal::kill(Pid::from_raw(webserver.id() as i32), Signal::SIGTERM).unwrap(); tracing::info!("Waiting for webserver to gracefully shutdown"); - webserver.wait_timeout(Duration::from_secs(90)).unwrap(); + webserver.wait().unwrap(); tracing::info!("Webserver terminated"); shutdown_shards(shutdown, &shutdown_receiver); diff --git a/router/client/src/client.rs b/router/client/src/client.rs index b5e0ccc0..b9607a5d 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -98,7 +98,7 @@ impl Client { /// Warmup on a max size batch /// /// Returns the maximum amount of tokens supported by the hardware - #[instrument(skip(self))] + #[instrument(skip_all)] pub async fn warmup( &mut self, max_input_length: u32, diff --git a/router/src/main.rs b/router/src/main.rs index e482d834..57ddd5ba 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -102,17 +102,24 @@ fn main() -> Result<(), RouterError> { } = args; // Validate args + if max_input_length >= max_total_tokens { + return Err(RouterError::ArgumentValidation( + "`max_input_length` must be < `max_total_tokens`".to_string(), + )); + } if max_input_length as u32 > max_batch_prefill_tokens { - panic!("{}", format!("`max_batch_prefill_tokens` must be >= `max_input_length`. 
Given: {max_batch_prefill_tokens} and {max_input_length}")); + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); } if max_batch_prefill_tokens > max_batch_total_tokens { - panic!("{}", format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")); + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); } if max_total_tokens as u32 > max_batch_total_tokens { - panic!("{}", format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")); + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); } if validation_workers == 0 { - panic!("`validation_workers` must be > 0"); + return Err(RouterError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); } // CORS allowed origins @@ -331,6 +338,8 @@ pub async fn get_model_info( #[derive(Debug, Error)] enum RouterError { + #[error("Argument validation error: {0}")] + ArgumentValidation(String), #[error("Unable to connect to the Python model shards: {0}")] Connection(ClientError), #[error("Unable to clear the Python model shards cache: {0}")] diff --git a/router/src/validation.rs b/router/src/validation.rs index 8843c6a8..be835bf0 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,5 +1,5 @@ -use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; /// Payload validation logic +use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::{GenerateParameters, GenerateRequest}; use rand::{thread_rng, Rng}; use 
text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; @@ -30,10 +30,6 @@ impl Validation { max_input_length: usize, max_total_tokens: usize, ) -> Self { - if max_input_length >= max_total_tokens { - panic!("`max_input_length` must be < `max_total_tokens`"); - } - // If we have a fast tokenizer let sender = if let Some(tokenizer) = tokenizer { // Create channel