feat(launcher): add arg validation and drop subprocess (#595)
This commit is contained in:
parent
3628559516
commit
b7327205a6
|
@ -538,7 +538,7 @@ dependencies = [
|
||||||
"autocfg",
|
"autocfg",
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"crossbeam-utils",
|
"crossbeam-utils",
|
||||||
"memoffset",
|
"memoffset 0.9.0",
|
||||||
"scopeguard",
|
"scopeguard",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1412,6 +1412,15 @@ version = "2.5.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
|
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "memoffset"
|
||||||
|
version = "0.7.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "memoffset"
|
name = "memoffset"
|
||||||
version = "0.9.0"
|
version = "0.9.0"
|
||||||
|
@ -1632,6 +1641,8 @@ dependencies = [
|
||||||
"bitflags 1.3.2",
|
"bitflags 1.3.2",
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"libc",
|
"libc",
|
||||||
|
"memoffset 0.7.1",
|
||||||
|
"pin-utils",
|
||||||
"static_assertions",
|
"static_assertions",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -2739,16 +2750,6 @@ version = "0.10.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "subprocess"
|
|
||||||
version = "0.2.9"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086"
|
|
||||||
dependencies = [
|
|
||||||
"libc",
|
|
||||||
"winapi",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "subtle"
|
name = "subtle"
|
||||||
version = "2.5.0"
|
version = "2.5.0"
|
||||||
|
@ -2889,10 +2890,10 @@ dependencies = [
|
||||||
"clap",
|
"clap",
|
||||||
"ctrlc",
|
"ctrlc",
|
||||||
"float_eq",
|
"float_eq",
|
||||||
|
"nix",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"subprocess",
|
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
"vergen",
|
"vergen",
|
||||||
|
|
|
@ -9,9 +9,9 @@ homepage.workspace = true
|
||||||
[dependencies]
|
[dependencies]
|
||||||
clap = { version = "4.1.4", features = ["derive", "env"] }
|
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||||
ctrlc = { version = "3.2.5", features = ["termination"] }
|
ctrlc = { version = "3.2.5", features = ["termination"] }
|
||||||
|
nix = "0.26.2"
|
||||||
serde = { version = "1.0.152", features = ["derive"] }
|
serde = { version = "1.0.152", features = ["derive"] }
|
||||||
serde_json = "1.0.93"
|
serde_json = "1.0.93"
|
||||||
subprocess = "0.2.9"
|
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,13 @@
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
|
use nix::sys::signal::{self, Signal};
|
||||||
|
use nix::unistd::Pid;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::ffi::OsString;
|
use std::ffi::OsString;
|
||||||
use std::io::{BufRead, BufReader, Read};
|
use std::io::{BufRead, BufReader, Read};
|
||||||
|
use std::os::unix::process::{CommandExt, ExitStatusExt};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::process::{Child, Command, Stdio};
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::mpsc::TryRecvError;
|
use std::sync::mpsc::TryRecvError;
|
||||||
use std::sync::{mpsc, Arc};
|
use std::sync::{mpsc, Arc};
|
||||||
|
@ -11,7 +15,6 @@ use std::thread;
|
||||||
use std::thread::sleep;
|
use std::thread::sleep;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use std::{fs, io};
|
use std::{fs, io};
|
||||||
use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
|
|
||||||
|
|
||||||
mod env_runtime;
|
mod env_runtime;
|
||||||
|
|
||||||
|
@ -72,6 +75,11 @@ struct Args {
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
|
||||||
|
/// The number of tokenizer workers used for payload validation and truncation inside the
|
||||||
|
/// router.
|
||||||
|
#[clap(default_value = "2", long, env)]
|
||||||
|
validation_workers: usize,
|
||||||
|
|
||||||
/// Whether to shard the model across multiple GPUs
|
/// Whether to shard the model across multiple GPUs
|
||||||
/// By default text-generation-inference will use all available GPUs to run
|
/// By default text-generation-inference will use all available GPUs to run
|
||||||
/// the model. Setting it to `false` deactivates `num_shard`.
|
/// the model. Setting it to `false` deactivates `num_shard`.
|
||||||
|
@ -306,11 +314,12 @@ fn shard_manager(
|
||||||
let uds_string = format!("{uds_path}-{rank}");
|
let uds_string = format!("{uds_path}-{rank}");
|
||||||
let uds = Path::new(&uds_string);
|
let uds = Path::new(&uds_string);
|
||||||
// Clean previous runs
|
// Clean previous runs
|
||||||
fs::remove_file(uds).unwrap_or_default();
|
if uds.exists() {
|
||||||
|
fs::remove_file(uds).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
// Process args
|
// Process args
|
||||||
let mut shard_argv = vec![
|
let mut shard_argv = vec![
|
||||||
"text-generation-server".to_string(),
|
|
||||||
"serve".to_string(),
|
"serve".to_string(),
|
||||||
model_id,
|
model_id,
|
||||||
"--uds-path".to_string(),
|
"--uds-path".to_string(),
|
||||||
|
@ -415,26 +424,23 @@ fn shard_manager(
|
||||||
|
|
||||||
// Start process
|
// Start process
|
||||||
tracing::info!("Starting shard {rank}");
|
tracing::info!("Starting shard {rank}");
|
||||||
let mut p = match Popen::create(
|
let mut p = match Command::new("text-generation-server")
|
||||||
&shard_argv,
|
.args(shard_argv)
|
||||||
PopenConfig {
|
.envs(env)
|
||||||
stdout: Redirection::Pipe,
|
.stdout(Stdio::piped())
|
||||||
stderr: Redirection::Pipe,
|
.stderr(Stdio::piped())
|
||||||
// Needed for the shutdown procedure
|
.process_group(0)
|
||||||
setpgid: true,
|
.spawn()
|
||||||
// NCCL env vars
|
{
|
||||||
env: Some(env),
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
if let PopenError::IoError(ref err) = err {
|
if err.kind() == io::ErrorKind::NotFound {
|
||||||
if err.kind() == io::ErrorKind::NotFound {
|
tracing::error!("text-generation-server not found in PATH");
|
||||||
tracing::error!("text-generation-server not found in PATH");
|
tracing::error!("Please install it with `make install-server`")
|
||||||
tracing::error!("Please install it with `make install-server`")
|
} else {
|
||||||
}
|
tracing::error!("{}", err);
|
||||||
}
|
}
|
||||||
|
|
||||||
status_sender
|
status_sender
|
||||||
.send(ShardStatus::Failed((rank, Some(err.to_string()))))
|
.send(ShardStatus::Failed((rank, Some(err.to_string()))))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -462,7 +468,7 @@ fn shard_manager(
|
||||||
let mut wait_time = Instant::now();
|
let mut wait_time = Instant::now();
|
||||||
loop {
|
loop {
|
||||||
// Process exited
|
// Process exited
|
||||||
if let Some(exit_status) = p.poll() {
|
if let Some(exit_status) = p.try_wait().unwrap() {
|
||||||
// We read stderr in another thread as it seems that `read_to_string` can block
|
// We read stderr in another thread as it seems that `read_to_string` can block
|
||||||
// indefinitely in some cases
|
// indefinitely in some cases
|
||||||
let (err_sender, err_receiver) = mpsc::channel();
|
let (err_sender, err_receiver) = mpsc::channel();
|
||||||
|
@ -480,7 +486,7 @@ fn shard_manager(
|
||||||
})
|
})
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
if let ExitStatus::Signaled(signal) = exit_status {
|
if let Some(signal) = exit_status.signal() {
|
||||||
tracing::error!("Shard process was signaled to shutdown with signal {signal}");
|
tracing::error!("Shard process was signaled to shutdown with signal {signal}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -493,7 +499,7 @@ fn shard_manager(
|
||||||
// We received a shutdown signal
|
// We received a shutdown signal
|
||||||
if shutdown.load(Ordering::SeqCst) {
|
if shutdown.load(Ordering::SeqCst) {
|
||||||
p.kill().unwrap();
|
p.kill().unwrap();
|
||||||
let _ = p.wait_timeout(Duration::from_secs(90));
|
let _ = p.wait();
|
||||||
tracing::info!("Shard {rank} terminated");
|
tracing::info!("Shard {rank} terminated");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -573,7 +579,10 @@ impl PythonLogMessage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
|
fn find_num_shards(
|
||||||
|
sharded: Option<bool>,
|
||||||
|
num_shard: Option<usize>,
|
||||||
|
) -> Result<usize, LauncherError> {
|
||||||
// get the number of shards given `sharded` and `num_shard`
|
// get the number of shards given `sharded` and `num_shard`
|
||||||
let num_shard = match (sharded, num_shard) {
|
let num_shard = match (sharded, num_shard) {
|
||||||
(Some(true), None) => {
|
(Some(true), None) => {
|
||||||
|
@ -582,14 +591,18 @@ fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
|
||||||
let n_devices = num_cuda_devices()
|
let n_devices = num_cuda_devices()
|
||||||
.expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
|
.expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
|
||||||
if n_devices <= 1 {
|
if n_devices <= 1 {
|
||||||
panic!("`sharded` is true but only found {n_devices} CUDA devices");
|
return Err(LauncherError::NotEnoughCUDADevices(format!(
|
||||||
|
"`sharded` is true but only found {n_devices} CUDA devices"
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
n_devices
|
n_devices
|
||||||
}
|
}
|
||||||
(Some(true), Some(num_shard)) => {
|
(Some(true), Some(num_shard)) => {
|
||||||
// we can't have only one shard while sharded
|
// we can't have only one shard while sharded
|
||||||
if num_shard <= 1 {
|
if num_shard <= 1 {
|
||||||
panic!("`sharded` is true but `num_shard` <= 1");
|
return Err(LauncherError::ArgumentValidation(
|
||||||
|
"`sharded` is true but `num_shard` <= 1".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
num_shard
|
num_shard
|
||||||
}
|
}
|
||||||
|
@ -599,13 +612,17 @@ fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
|
||||||
(None, Some(num_shard)) => num_shard,
|
(None, Some(num_shard)) => num_shard,
|
||||||
};
|
};
|
||||||
if num_shard < 1 {
|
if num_shard < 1 {
|
||||||
panic!("`num_shard` cannot be < 1");
|
return Err(LauncherError::ArgumentValidation(
|
||||||
|
"`num_shard` cannot be < 1".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
num_shard
|
Ok(num_shard)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum LauncherError {
|
enum LauncherError {
|
||||||
|
ArgumentValidation(String),
|
||||||
|
NotEnoughCUDADevices(String),
|
||||||
DownloadError,
|
DownloadError,
|
||||||
ShardCannotStart,
|
ShardCannotStart,
|
||||||
ShardDisconnected,
|
ShardDisconnected,
|
||||||
|
@ -616,7 +633,6 @@ enum LauncherError {
|
||||||
|
|
||||||
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
|
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
|
||||||
let mut download_argv = vec![
|
let mut download_argv = vec![
|
||||||
"text-generation-server".to_string(),
|
|
||||||
"download-weights".to_string(),
|
"download-weights".to_string(),
|
||||||
args.model_id.to_string(),
|
args.model_id.to_string(),
|
||||||
"--extension".to_string(),
|
"--extension".to_string(),
|
||||||
|
@ -664,25 +680,21 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||||
|
|
||||||
// Start process
|
// Start process
|
||||||
tracing::info!("Starting download process.");
|
tracing::info!("Starting download process.");
|
||||||
let mut download_process = match Popen::create(
|
let mut download_process = match Command::new("text-generation-server")
|
||||||
&download_argv,
|
.args(download_argv)
|
||||||
PopenConfig {
|
.envs(env)
|
||||||
stdout: Redirection::Pipe,
|
.stdout(Stdio::piped())
|
||||||
stderr: Redirection::Pipe,
|
.stderr(Stdio::piped())
|
||||||
// Needed for the shutdown procedure
|
.process_group(0)
|
||||||
setpgid: true,
|
.spawn()
|
||||||
env: Some(env),
|
{
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
if let PopenError::IoError(ref err) = err {
|
if err.kind() == io::ErrorKind::NotFound {
|
||||||
if err.kind() == io::ErrorKind::NotFound {
|
tracing::error!("text-generation-server not found in PATH");
|
||||||
tracing::error!("text-generation-server not found in PATH");
|
tracing::error!("Please install it with `make install-server`")
|
||||||
tracing::error!("Please install it with `make install-server`")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Err(LauncherError::DownloadError);
|
return Err(LauncherError::DownloadError);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -702,49 +714,33 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||||
});
|
});
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
if let Some(status) = download_process.poll() {
|
if let Some(status) = download_process.try_wait().unwrap() {
|
||||||
match status {
|
if status.success() {
|
||||||
ExitStatus::Exited(exit_code) => {
|
tracing::info!("Successfully downloaded weights.");
|
||||||
if exit_code == 0 {
|
break;
|
||||||
tracing::info!("Successfully downloaded weights.");
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
let mut err = String::new();
|
|
||||||
download_process
|
|
||||||
.stderr
|
|
||||||
.take()
|
|
||||||
.unwrap()
|
|
||||||
.read_to_string(&mut err)
|
|
||||||
.unwrap();
|
|
||||||
tracing::error!("Download encountered an error: {err}");
|
|
||||||
return Err(LauncherError::DownloadError);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ExitStatus::Signaled(signal) => {
|
|
||||||
let mut err = String::new();
|
|
||||||
download_process
|
|
||||||
.stderr
|
|
||||||
.take()
|
|
||||||
.unwrap()
|
|
||||||
.read_to_string(&mut err)
|
|
||||||
.unwrap();
|
|
||||||
tracing::error!(
|
|
||||||
"Download process was signaled to shutdown with signal {signal}: {err}"
|
|
||||||
);
|
|
||||||
return Err(LauncherError::DownloadError);
|
|
||||||
}
|
|
||||||
e => {
|
|
||||||
tracing::error!("Download process exited with an unknown status.: {e:?}");
|
|
||||||
return Err(LauncherError::DownloadError);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut err = String::new();
|
||||||
|
download_process
|
||||||
|
.stderr
|
||||||
|
.take()
|
||||||
|
.unwrap()
|
||||||
|
.read_to_string(&mut err)
|
||||||
|
.unwrap();
|
||||||
|
if let Some(signal) = status.signal() {
|
||||||
|
tracing::error!(
|
||||||
|
"Download process was signaled to shutdown with signal {signal}: {err}"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
tracing::error!("Download encountered an error: {err}");
|
||||||
|
}
|
||||||
|
|
||||||
|
return Err(LauncherError::DownloadError);
|
||||||
}
|
}
|
||||||
if !running.load(Ordering::SeqCst) {
|
if !running.load(Ordering::SeqCst) {
|
||||||
download_process.terminate().unwrap();
|
signal::kill(Pid::from_raw(download_process.id() as i32), Signal::SIGTERM).unwrap();
|
||||||
tracing::info!("Waiting for download process to gracefully shutdown");
|
tracing::info!("Waiting for download process to gracefully shutdown");
|
||||||
download_process
|
download_process.wait().unwrap();
|
||||||
.wait_timeout(Duration::from_secs(90))
|
|
||||||
.unwrap();
|
|
||||||
tracing::info!("Download process terminated");
|
tracing::info!("Download process terminated");
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
@ -854,12 +850,11 @@ fn spawn_webserver(
|
||||||
args: Args,
|
args: Args,
|
||||||
shutdown: Arc<AtomicBool>,
|
shutdown: Arc<AtomicBool>,
|
||||||
shutdown_receiver: &mpsc::Receiver<()>,
|
shutdown_receiver: &mpsc::Receiver<()>,
|
||||||
) -> Result<Popen, LauncherError> {
|
) -> Result<Child, LauncherError> {
|
||||||
// All shard started
|
// All shard started
|
||||||
// Start webserver
|
// Start webserver
|
||||||
tracing::info!("Starting Webserver");
|
tracing::info!("Starting Webserver");
|
||||||
let mut argv = vec![
|
let mut argv = vec![
|
||||||
"text-generation-router".to_string(),
|
|
||||||
"--max-concurrent-requests".to_string(),
|
"--max-concurrent-requests".to_string(),
|
||||||
args.max_concurrent_requests.to_string(),
|
args.max_concurrent_requests.to_string(),
|
||||||
"--max-best-of".to_string(),
|
"--max-best-of".to_string(),
|
||||||
|
@ -878,6 +873,8 @@ fn spawn_webserver(
|
||||||
args.waiting_served_ratio.to_string(),
|
args.waiting_served_ratio.to_string(),
|
||||||
"--max-waiting-tokens".to_string(),
|
"--max-waiting-tokens".to_string(),
|
||||||
args.max_waiting_tokens.to_string(),
|
args.max_waiting_tokens.to_string(),
|
||||||
|
"--validation-workers".to_string(),
|
||||||
|
args.validation_workers.to_string(),
|
||||||
"--hostname".to_string(),
|
"--hostname".to_string(),
|
||||||
args.hostname.to_string(),
|
args.hostname.to_string(),
|
||||||
"--port".to_string(),
|
"--port".to_string(),
|
||||||
|
@ -942,25 +939,20 @@ fn spawn_webserver(
|
||||||
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut webserver = match Popen::create(
|
let mut webserver = match Command::new("text-generation-router")
|
||||||
&argv,
|
.args(argv)
|
||||||
PopenConfig {
|
.envs(env)
|
||||||
stdout: Redirection::Pipe,
|
.stdout(Stdio::piped())
|
||||||
stderr: Redirection::Pipe,
|
.stderr(Stdio::piped())
|
||||||
// Needed for the shutdown procedure
|
.process_group(0)
|
||||||
setpgid: true,
|
.spawn()
|
||||||
env: Some(env),
|
{
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
tracing::error!("Failed to start webserver: {}", err);
|
tracing::error!("Failed to start webserver: {}", err);
|
||||||
if let PopenError::IoError(err) = err {
|
if err.kind() == io::ErrorKind::NotFound {
|
||||||
if err.kind() == io::ErrorKind::NotFound {
|
tracing::error!("text-generation-router not found in PATH");
|
||||||
tracing::error!("text-generation-router not found in PATH");
|
tracing::error!("Please install it with `make install-router`")
|
||||||
tracing::error!("Please install it with `make install-router`")
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
tracing::error!("{}", err);
|
tracing::error!("{}", err);
|
||||||
}
|
}
|
||||||
|
@ -1004,7 +996,37 @@ fn main() -> Result<(), LauncherError> {
|
||||||
|
|
||||||
tracing::info!("{:?}", args);
|
tracing::info!("{:?}", args);
|
||||||
|
|
||||||
let num_shard = find_num_shards(args.sharded, args.num_shard);
|
// Validate args
|
||||||
|
if args.max_input_length >= args.max_total_tokens {
|
||||||
|
return Err(LauncherError::ArgumentValidation(
|
||||||
|
"`max_input_length` must be < `max_total_tokens`".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if args.max_input_length as u32 > args.max_batch_prefill_tokens {
|
||||||
|
return Err(LauncherError::ArgumentValidation(format!(
|
||||||
|
"`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
|
||||||
|
args.max_batch_prefill_tokens, args.max_input_length
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
|
||||||
|
return Err(LauncherError::ArgumentValidation(format!(
|
||||||
|
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||||
|
args.max_batch_prefill_tokens, args.max_batch_total_tokens
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
if args.max_total_tokens as u32 > args.max_batch_total_tokens {
|
||||||
|
return Err(LauncherError::ArgumentValidation(format!(
|
||||||
|
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||||
|
args.max_total_tokens, args.max_batch_total_tokens
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
if args.validation_workers == 0 {
|
||||||
|
return Err(LauncherError::ArgumentValidation(
|
||||||
|
"`validation_workers` must be > 0".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
|
||||||
if num_shard > 1 {
|
if num_shard > 1 {
|
||||||
tracing::info!("Sharding model on {num_shard} processes");
|
tracing::info!("Sharding model on {num_shard} processes");
|
||||||
}
|
}
|
||||||
|
@ -1065,7 +1087,7 @@ fn main() -> Result<(), LauncherError> {
|
||||||
break;
|
break;
|
||||||
};
|
};
|
||||||
|
|
||||||
match webserver.poll() {
|
match webserver.try_wait().unwrap() {
|
||||||
Some(_) => {
|
Some(_) => {
|
||||||
tracing::error!("Webserver Crashed");
|
tracing::error!("Webserver Crashed");
|
||||||
shutdown_shards(shutdown, &shutdown_receiver);
|
shutdown_shards(shutdown, &shutdown_receiver);
|
||||||
|
@ -1078,9 +1100,9 @@ fn main() -> Result<(), LauncherError> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Graceful termination
|
// Graceful termination
|
||||||
webserver.terminate().unwrap();
|
signal::kill(Pid::from_raw(webserver.id() as i32), Signal::SIGTERM).unwrap();
|
||||||
tracing::info!("Waiting for webserver to gracefully shutdown");
|
tracing::info!("Waiting for webserver to gracefully shutdown");
|
||||||
webserver.wait_timeout(Duration::from_secs(90)).unwrap();
|
webserver.wait().unwrap();
|
||||||
tracing::info!("Webserver terminated");
|
tracing::info!("Webserver terminated");
|
||||||
shutdown_shards(shutdown, &shutdown_receiver);
|
shutdown_shards(shutdown, &shutdown_receiver);
|
||||||
|
|
||||||
|
|
|
@ -98,7 +98,7 @@ impl Client {
|
||||||
/// Warmup on a max size batch
|
/// Warmup on a max size batch
|
||||||
///
|
///
|
||||||
/// Returns the maximum amount of tokens supported by the hardware
|
/// Returns the maximum amount of tokens supported by the hardware
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip_all)]
|
||||||
pub async fn warmup(
|
pub async fn warmup(
|
||||||
&mut self,
|
&mut self,
|
||||||
max_input_length: u32,
|
max_input_length: u32,
|
||||||
|
|
|
@ -102,17 +102,24 @@ fn main() -> Result<(), RouterError> {
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Validate args
|
// Validate args
|
||||||
|
if max_input_length >= max_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`max_input_length` must be < `max_total_tokens`".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
if max_input_length as u32 > max_batch_prefill_tokens {
|
if max_input_length as u32 > max_batch_prefill_tokens {
|
||||||
panic!("{}", format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"));
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
|
||||||
}
|
}
|
||||||
if max_batch_prefill_tokens > max_batch_total_tokens {
|
if max_batch_prefill_tokens > max_batch_total_tokens {
|
||||||
panic!("{}", format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"));
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||||
}
|
}
|
||||||
if max_total_tokens as u32 > max_batch_total_tokens {
|
if max_total_tokens as u32 > max_batch_total_tokens {
|
||||||
panic!("{}", format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"));
|
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||||
}
|
}
|
||||||
if validation_workers == 0 {
|
if validation_workers == 0 {
|
||||||
panic!("`validation_workers` must be > 0");
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`validation_workers` must be > 0".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// CORS allowed origins
|
// CORS allowed origins
|
||||||
|
@ -331,6 +338,8 @@ pub async fn get_model_info(
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
enum RouterError {
|
enum RouterError {
|
||||||
|
#[error("Argument validation error: {0}")]
|
||||||
|
ArgumentValidation(String),
|
||||||
#[error("Unable to connect to the Python model shards: {0}")]
|
#[error("Unable to connect to the Python model shards: {0}")]
|
||||||
Connection(ClientError),
|
Connection(ClientError),
|
||||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
|
||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
|
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||||
use crate::{GenerateParameters, GenerateRequest};
|
use crate::{GenerateParameters, GenerateRequest};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||||
|
@ -30,10 +30,6 @@ impl Validation {
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
if max_input_length >= max_total_tokens {
|
|
||||||
panic!("`max_input_length` must be < `max_total_tokens`");
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we have a fast tokenizer
|
// If we have a fast tokenizer
|
||||||
let sender = if let Some(tokenizer) = tokenizer {
|
let sender = if let Some(tokenizer) = tokenizer {
|
||||||
// Create channel
|
// Create channel
|
||||||
|
|
Loading…
Reference in New Issue