parent 7de8a377b0
commit 77758f603b

@@ -4,7 +4,6 @@ use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Read};
use std::path::Path;
use std::process::ExitCode;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::Arc;
@@ -73,445 +72,6 @@ struct Args {
    watermark_delta: Option<f32>,
}

fn main() -> ExitCode {
    // Pattern match configuration
    let args = Args::parse();

    if args.json_output {
        tracing_subscriber::fmt().json().init();
    } else {
        tracing_subscriber::fmt().compact().init();
    }

    tracing::info!("{:?}", args);

    let Args {
        model_id,
        revision,
        sharded,
        num_shard,
        quantize,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
        max_batch_size,
        max_batch_total_tokens,
        waiting_served_ratio,
        max_waiting_tokens,
        port,
        shard_uds_path,
        master_addr,
        master_port,
        huggingface_hub_cache,
        weights_cache_override,
        disable_custom_kernels,
        json_output,
        otlp_endpoint,
        cors_allow_origin,
        watermark_gamma,
        watermark_delta,
    } = args;

    // get the number of shards given `sharded` and `num_shard`
    let num_shard = if let Some(sharded) = sharded {
        // sharded is set
        match sharded {
            // sharded is set and true
            true => {
                match num_shard {
                    None => {
                        // try to default to the number of available GPUs
                        tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
                        let n_devices = num_cuda_devices()
                            .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
                        if n_devices <= 1 {
                            panic!("`sharded` is true but only found {n_devices} CUDA devices");
                        }
                        n_devices
                    }
                    Some(num_shard) => {
                        // we can't have only one shard while sharded
                        if num_shard <= 1 {
                            panic!("`sharded` is true but `num_shard` <= 1");
                        }
                        num_shard
                    }
                }
            }
            // sharded is set and false
            false => {
                let num_shard = num_shard.unwrap_or(1);
                // we can't have more than one shard while not sharded
                if num_shard != 1 {
                    panic!("`sharded` is false but `num_shard` != 1");
                }
                num_shard
            }
        }
    } else {
        match num_shard {
            // get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard
            None => num_cuda_devices().unwrap_or(1),
            Some(num_shard) => num_shard,
        }
    };
    if num_shard < 1 {
        panic!("`num_shard` cannot be < 1");
    }

    if num_shard > 1 {
        tracing::info!("Sharding model on {num_shard} processes");
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Check if model_id is a local model
    let local_path = Path::new(&model_id);
    let is_local_model = local_path.exists() && local_path.is_dir();

    // Download weights for sharded models
    if !is_local_model && weights_cache_override.is_none() && num_shard > 1 {
        let mut download_argv = vec![
            "text-generation-server".to_string(),
            "download-weights".to_string(),
            model_id.clone(),
            "--extension".to_string(),
            ".safetensors".to_string(),
            "--logger-level".to_string(),
            "INFO".to_string(),
            "--json-output".to_string(),
        ];

        // Model optional revision
        if let Some(ref revision) = revision {
            download_argv.push("--revision".to_string());
            download_argv.push(revision.to_string())
        }

        // Copy current process env
        let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

        // If huggingface_hub_cache is set, pass it to the shard
        // Useful when running inside a docker container
        if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
            env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
        };

        // Enable hf transfer for insane download speeds
        let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
        env.push((
            "HF_HUB_ENABLE_HF_TRANSFER".into(),
            enable_hf_transfer.into(),
        ));

        // Parse Inference API token
        if let Ok(api_token) = env::var("HF_API_TOKEN") {
            env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
        };

        // Start process
        tracing::info!("Starting download process.");
        let mut download_process = match Popen::create(
            &download_argv,
            PopenConfig {
                stdout: Redirection::Pipe,
                stderr: Redirection::Pipe,
                // Needed for the shutdown procedure
                setpgid: true,
                env: Some(env),
                ..Default::default()
            },
        ) {
            Ok(p) => p,
            Err(err) => {
                if let PopenError::IoError(ref err) = err {
                    if err.kind() == io::ErrorKind::NotFound {
                        tracing::error!("text-generation-server not found in PATH");
                        tracing::error!("Please install it with `make install-server`")
                    }
                }
                return ExitCode::FAILURE;
            }
        };

        // Redirect STDOUT to the console
        let download_stdout = download_process.stdout.take().unwrap();
        thread::spawn(move || {
            // Enter download tracing span
            let stdout = BufReader::new(download_stdout);
            let _span = tracing::span!(tracing::Level::INFO, "download").entered();
            for line in stdout.lines() {
                // Parse loguru logs
                if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
                    log.trace();
                }
            }
        });

        loop {
            if let Some(status) = download_process.poll() {
                match status {
                    ExitStatus::Exited(exit_code) => {
                        if exit_code == 0 {
                            tracing::info!("Successfully downloaded weights.");
                            break;
                        } else {
                            let mut err = String::new();
                            download_process
                                .stderr
                                .take()
                                .unwrap()
                                .read_to_string(&mut err)
                                .unwrap();
                            tracing::error!("Download encountered an error: {err}");
                            return ExitCode::FAILURE;
                        }
                    }
                    _ => {
                        tracing::error!("Download process exited with an unknown status.");
                        return ExitCode::FAILURE;
                    }
                }
            }
            if !running.load(Ordering::SeqCst) {
                download_process.terminate().unwrap();
                tracing::info!("Waiting for download process to gracefully shutdown");
                download_process
                    .wait_timeout(Duration::from_secs(90))
                    .unwrap();
                tracing::info!("Download process terminated");
                return ExitCode::SUCCESS;
            }
            sleep(Duration::from_millis(100));
        }
    }

    // Shared shutdown bool
    let shutdown = Arc::new(Mutex::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

    // Start shard processes
    for rank in 0..num_shard {
        let model_id = model_id.clone();
        let revision = revision.clone();
        let uds_path = shard_uds_path.clone();
        let master_addr = master_addr.clone();
        let huggingface_hub_cache = huggingface_hub_cache.clone();
        let weights_cache_override = weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = otlp_endpoint.clone();
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                otlp_endpoint,
                status_sender,
                shutdown,
                shutdown_sender,
            )
        });
    }
    drop(shutdown_sender);

    // Wait for shard to start
    let mut shard_ready = 0;
    while running.load(Ordering::SeqCst) {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                shard_ready += 1;
                if shard_ready == num_shard {
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed((rank, err))) => {
                tracing::error!("Shard {} failed to start:\n{}", rank, err);
                shutdown_shards(shutdown, &shutdown_receiver);
                return ExitCode::FAILURE;
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, &shutdown_receiver);
                return ExitCode::FAILURE;
            }
        }
    }

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return ExitCode::SUCCESS;
    }

    // All shard started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut argv = vec![
        "text-generation-router".to_string(),
        "--max-concurrent-requests".to_string(),
        max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        max_stop_sequences.to_string(),
        "--max-input-length".to_string(),
        max_input_length.to_string(),
        "--max-total-tokens".to_string(),
        max_total_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        max_waiting_tokens.to_string(),
        "--port".to_string(),
        port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{shard_uds_path}-0"),
        "--tokenizer-name".to_string(),
        model_id,
    ];

    // Deprecate max_batch_size
    if let Some(max_batch_size) = max_batch_size {
        argv.push("--max-batch-size".to_string());
        argv.push(max_batch_size.to_string())
    } else {
        argv.push("--max-batch-total-tokens".to_string());
        argv.push(max_batch_total_tokens.to_string())
    }

    // Model optional revision
    if let Some(ref revision) = revision {
        argv.push("--revision".to_string());
        argv.push(revision.to_string())
    }

    if json_output {
        argv.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = otlp_endpoint {
        argv.push("--otlp-endpoint".to_string());
        argv.push(otlp_endpoint);
    }

    // CORS origins
    for origin in cors_allow_origin.into_iter() {
        argv.push("--cors-allow-origin".to_string());
        argv.push(origin);
    }

    // Copy current process env
    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    let mut webserver = match Popen::create(
        &argv,
        PopenConfig {
            stdout: Redirection::Pipe,
            stderr: Redirection::Pipe,
            // Needed for the shutdown procedure
            setpgid: true,
            env: Some(env),
            ..Default::default()
        },
    ) {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if let PopenError::IoError(err) = err {
                if err.kind() == io::ErrorKind::NotFound {
                    tracing::error!("text-generation-router not found in PATH");
                    tracing::error!("Please install it with `make install-router`")
                }
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, &shutdown_receiver);
            return ExitCode::FAILURE;
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });

    // Default exit code
    let mut exit_code = ExitCode::SUCCESS;

    while running.load(Ordering::SeqCst) {
        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} failed:\n{err}");
            exit_code = ExitCode::FAILURE;
            break;
        };

        match webserver.poll() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return ExitCode::FAILURE;
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    webserver.terminate().unwrap();
    tracing::info!("Waiting for webserver to gracefully shutdown");
    webserver.wait_timeout(Duration::from_secs(90)).unwrap();
    tracing::info!("Webserver terminated");
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}

#[derive(Debug)]
enum ShardStatus {
    Ready,
@@ -774,3 +334,450 @@ impl PythonLogMessage {
        }
    }
}

fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
    // get the number of shards given `sharded` and `num_shard`
    let num_shard = match (sharded, num_shard) {
        (Some(true), None) => {
            // try to default to the number of available GPUs
            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
            let n_devices =
                num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
            if n_devices <= 1 {
                panic!("`sharded` is true but only found {n_devices} CUDA devices");
            }
            n_devices
        }
        (Some(true), Some(num_shard)) => {
            // we can't have only one shard while sharded
            if num_shard <= 1 {
                panic!("`sharded` is true but `num_shard` <= 1");
            }
            num_shard
        }
        (Some(false), Some(num_shard)) => num_shard,
        (Some(false), None) => 1,
        (None, None) => num_cuda_devices().unwrap_or(1),
        (None, Some(num_shard)) => num_shard,
    };
    if num_shard < 1 {
        panic!("`num_shard` cannot be < 1");
    }
    num_shard
}

#[derive(Debug)]
enum LauncherError {
    DownloadError,
    ShardCannotStart,
    ShardDisconnected,
    ShardFailed,
    WebserverFailed,
    WebserverCannotStart,
}

fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
    let mut download_argv = vec![
        "text-generation-server".to_string(),
        "download-weights".to_string(),
        args.model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Model optional revision
    if let Some(revision) = &args.revision {
        download_argv.push("--revision".to_string());
        download_argv.push(revision.to_string())
    }

    // Copy current process env
    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

    // If huggingface_hub_cache is set, pass it to the shard
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    env.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // Start process
    tracing::info!("Starting download process.");
    let mut download_process = match Popen::create(
        &download_argv,
        PopenConfig {
            stdout: Redirection::Pipe,
            stderr: Redirection::Pipe,
            // Needed for the shutdown procedure
            setpgid: true,
            env: Some(env),
            ..Default::default()
        },
    ) {
        Ok(p) => p,
        Err(err) => {
            if let PopenError::IoError(ref err) = err {
                if err.kind() == io::ErrorKind::NotFound {
                    tracing::error!("text-generation-server not found in PATH");
                    tracing::error!("Please install it with `make install-server`")
                }
            }
            return Err(LauncherError::DownloadError);
        }
    };

    // Redirect STDOUT to the console
    let download_stdout = download_process.stdout.take().unwrap();
    thread::spawn(move || {
        // Enter download tracing span
        let stdout = BufReader::new(download_stdout);
        let _span = tracing::span!(tracing::Level::INFO, "download").entered();
        for line in stdout.lines() {
            // Parse loguru logs
            if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
                log.trace();
            }
        }
    });

    loop {
        if let Some(status) = download_process.poll() {
            match status {
                ExitStatus::Exited(exit_code) => {
                    if exit_code == 0 {
                        tracing::info!("Successfully downloaded weights.");
                        break;
                    } else {
                        let mut err = String::new();
                        download_process
                            .stderr
                            .take()
                            .unwrap()
                            .read_to_string(&mut err)
                            .unwrap();
                        tracing::error!("Download encountered an error: {err}");
                        return Err(LauncherError::DownloadError);
                    }
                }
                _ => {
                    tracing::error!("Download process exited with an unknown status.");
                    return Err(LauncherError::DownloadError);
                }
            }
        }
        if !running.load(Ordering::SeqCst) {
            download_process.terminate().unwrap();
            tracing::info!("Waiting for download process to gracefully shutdown");
            download_process
                .wait_timeout(Duration::from_secs(90))
                .unwrap();
            tracing::info!("Download process terminated");
            return Ok(());
        }
        sleep(Duration::from_millis(100));
    }
    Ok(())
}

fn spawn_shards(
    num_shard: usize,
    args: &Args,
    shutdown: Arc<Mutex<bool>>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let quantize = args.quantize.clone();
        let master_port = args.master_port.clone();
        let disable_custom_kernels = args.disable_custom_kernels.clone();
        let watermark_gamma = args.watermark_gamma.clone();
        let watermark_delta = args.watermark_delta.clone();
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                otlp_endpoint,
                status_sender,
                shutdown,
                shutdown_sender,
            )
        });
    }
    drop(shutdown_sender);

    // Wait for shard to start
    let mut shard_ready = 0;
    while running.load(Ordering::SeqCst) {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                shard_ready += 1;
                if shard_ready == num_shard {
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed((rank, err))) => {
                tracing::error!("Shard {} failed to start:\n{}", rank, err);
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        }
    }
    Ok(())
}

fn spawn_webserver(
    args: Args,
    shutdown: Arc<Mutex<bool>>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Popen, LauncherError> {
    // All shard started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut argv = vec![
        "text-generation-router".to_string(),
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-input-length".to_string(),
        args.max_input_length.to_string(),
        "--max-total-tokens".to_string(),
        args.max_total_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--port".to_string(),
        args.port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
    ];

    // Deprecate max_batch_size
    if let Some(max_batch_size) = args.max_batch_size {
        argv.push("--max-batch-size".to_string());
        argv.push(max_batch_size.to_string())
    } else {
        argv.push("--max-batch-total-tokens".to_string());
        argv.push(args.max_batch_total_tokens.to_string())
    }

    // Model optional revision
    if let Some(ref revision) = args.revision {
        argv.push("--revision".to_string());
        argv.push(revision.to_string())
    }

    if args.json_output {
        argv.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        argv.push("--otlp-endpoint".to_string());
        argv.push(otlp_endpoint);
    }

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        argv.push("--cors-allow-origin".to_string());
        argv.push(origin);
    }

    // Copy current process env
    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    let mut webserver = match Popen::create(
        &argv,
        PopenConfig {
            stdout: Redirection::Pipe,
            stderr: Redirection::Pipe,
            // Needed for the shutdown procedure
            setpgid: true,
            env: Some(env),
            ..Default::default()
        },
    ) {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if let PopenError::IoError(err) = err {
                if err.kind() == io::ErrorKind::NotFound {
                    tracing::error!("text-generation-router not found in PATH");
                    tracing::error!("Please install it with `make install-router`")
                }
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, &shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });
    Ok(webserver)
}

fn main() -> Result<(), LauncherError> {
    // Pattern match configuration
    let args = Args::parse();

    if args.json_output {
        tracing_subscriber::fmt().json().init();
    } else {
        tracing_subscriber::fmt().compact().init();
    }

    tracing::info!("{:?}", args);

    let num_shard = find_num_shards(args.sharded, args.num_shard);
    if num_shard > 1 {
        tracing::info!("Sharding model on {num_shard} processes");
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Check if model_id is a local model
    let local_path = Path::new(&args.model_id);
    let is_local_model = local_path.exists() && local_path.is_dir();

    // Download weights for sharded models
    if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
        download_model(&args, running.clone())?;
    }

    // Shared shutdown bool
    let shutdown = Arc::new(Mutex::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

    spawn_shards(
        num_shard,
        &args,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

    let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;

    // Default exit code
    let mut exit_code = Ok(());

    while running.load(Ordering::SeqCst) {
        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} failed:\n{err}");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.poll() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    webserver.terminate().unwrap();
    tracing::info!("Waiting for webserver to gracefully shutdown");
    webserver.wait_timeout(Duration::from_secs(90)).unwrap();
    tracing::info!("Webserver terminated");
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}