// hf_text-generation-inference/launcher/src/main.rs
use clap::Parser;
use serde_json::Value;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Read};
use std::path::Path;
use std::process::ExitCode;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::Arc;
use std::sync::{mpsc, Mutex};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
2022-10-18 07:19:03 -06:00
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "bigscience/bloom-560m", long, env)]
model_id: String,
2022-10-18 07:19:03 -06:00
#[clap(long, env)]
2023-01-31 10:53:56 -07:00
revision: Option<String>,
#[clap(default_value = "1", long, env)]
num_shard: usize,
2022-10-27 06:25:29 -06:00
#[clap(long, env)]
quantize: bool,
2022-10-18 07:19:03 -06:00
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "1000", long, env)]
max_input_length: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
2022-10-21 08:40:05 -06:00
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
2022-10-18 07:19:03 -06:00
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server", long, env)]
shard_uds_path: String,
2023-02-08 09:53:33 -07:00
#[clap(default_value = "localhost", long, env)]
2022-10-18 07:19:03 -06:00
master_addr: String,
2023-02-08 09:53:33 -07:00
#[clap(default_value = "29500", long, env)]
2022-10-18 07:19:03 -06:00
master_port: usize,
#[clap(long, env)]
2023-02-14 05:02:16 -07:00
huggingface_hub_cache: Option<String>,
#[clap(long, env)]
weights_cache_override: Option<String>,
#[clap(long, env)]
disable_custom_kernels: bool,
#[clap(long, env)]
json_output: bool,
2023-02-13 05:02:45 -07:00
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Vec<String>,
2022-10-18 07:19:03 -06:00
}
fn main() -> ExitCode {
// Pattern match configuration
let args = Args::parse();
if args.json_output {
tracing_subscriber::fmt().json().init();
} else {
tracing_subscriber::fmt().compact().init();
}
tracing::info!("{:?}", args);
2022-10-18 07:19:03 -06:00
let Args {
model_id,
2023-01-31 10:53:56 -07:00
revision,
2022-10-18 07:19:03 -06:00
num_shard,
2022-10-27 06:25:29 -06:00
quantize,
2022-10-18 07:19:03 -06:00
max_concurrent_requests,
max_input_length,
max_batch_size,
2022-10-21 08:40:05 -06:00
max_waiting_tokens,
2022-10-18 07:19:03 -06:00
port,
shard_uds_path,
master_addr,
master_port,
2023-02-14 05:02:16 -07:00
huggingface_hub_cache,
weights_cache_override,
disable_custom_kernels,
json_output,
2023-02-13 05:02:45 -07:00
otlp_endpoint,
cors_allow_origin,
} = args;
2022-10-18 07:19:03 -06:00
// Signal handler
let running = Arc::new(AtomicBool::new(true));
let r = running.clone();
ctrlc::set_handler(move || {
r.store(false, Ordering::SeqCst);
})
.expect("Error setting Ctrl-C handler");
2023-02-14 05:02:16 -07:00
// Download weights
if weights_cache_override.is_none() {
let mut download_argv = vec![
"text-generation-server".to_string(),
"download-weights".to_string(),
model_id.clone(),
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
if num_shard == 1 {
download_argv.push("--extension".to_string());
download_argv.push(".bin".to_string());
} else {
download_argv.push("--extension".to_string());
download_argv.push(".safetensors".to_string());
}
// Model optional revision
if let Some(ref revision) = revision {
download_argv.push("--revision".to_string());
download_argv.push(revision.to_string())
}
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
2023-02-14 05:02:16 -07:00
// If huggingface_hub_cache is set, pass it to the shard
2023-02-14 05:02:16 -07:00
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// Start process
tracing::info!("Starting download process.");
2023-02-14 05:02:16 -07:00
let mut download_process = match Popen::create(
&download_argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
env: Some(env),
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
}
}
return ExitCode::FAILURE;
}
};
// Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap();
thread::spawn(move || {
// Enter download tracing span
let stdout = BufReader::new(download_stdout);
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
if let Some(text) = value.get("text") {
// Format escaped newlines
tracing::info!("{}", text.to_string().replace("\\n", ""));
}
}
}
});
loop {
if let Some(status) = download_process.poll() {
match status {
ExitStatus::Exited(exit_code) => {
if exit_code == 0 {
tracing::info!("Successfully downloaded weights.");
break;
} else {
let mut err = String::new();
download_process
.stderr
.take()
.unwrap()
.read_to_string(&mut err)
.unwrap();
tracing::error!("Download encountered an error: {err}");
return ExitCode::FAILURE;
}
}
_ => {
tracing::error!("Download process exited with an unknown status.");
2023-02-14 05:02:16 -07:00
return ExitCode::FAILURE;
}
}
}
if !running.load(Ordering::SeqCst) {
download_process.terminate().unwrap();
tracing::info!("Waiting for download process to gracefully shutdown");
download_process
.wait_timeout(Duration::from_secs(90))
.unwrap();
tracing::info!("Download process terminated");
return ExitCode::SUCCESS;
}
sleep(Duration::from_millis(100));
}
} else {
tracing::info!(
"weights_cache_override is set to {:?}.",
weights_cache_override
);
tracing::info!("Skipping download.")
}
2022-10-18 07:19:03 -06:00
// Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false));
// Shared shutdown channel
// When shutting down, the main thread will wait for all senders to be dropped
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
// Shared channel to track shard status
let (status_sender, status_receiver) = mpsc::channel();
// Start shard processes
for rank in 0..num_shard {
let model_id = model_id.clone();
2023-01-31 10:53:56 -07:00
let revision = revision.clone();
2022-10-18 07:19:03 -06:00
let uds_path = shard_uds_path.clone();
let master_addr = master_addr.clone();
2023-02-14 05:02:16 -07:00
let huggingface_hub_cache = huggingface_hub_cache.clone();
let weights_cache_override = weights_cache_override.clone();
2022-10-18 07:19:03 -06:00
let status_sender = status_sender.clone();
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
2023-02-13 05:02:45 -07:00
let otlp_endpoint = otlp_endpoint.clone();
2022-10-18 07:19:03 -06:00
thread::spawn(move || {
shard_manager(
model_id,
2023-01-31 10:53:56 -07:00
revision,
2022-10-27 06:25:29 -06:00
quantize,
2022-10-18 07:19:03 -06:00
uds_path,
rank,
num_shard,
master_addr,
master_port,
2023-02-14 05:02:16 -07:00
huggingface_hub_cache,
weights_cache_override,
disable_custom_kernels,
2023-02-13 05:02:45 -07:00
otlp_endpoint,
2022-10-18 07:19:03 -06:00
status_sender,
shutdown,
shutdown_sender,
)
});
}
drop(shutdown_sender);
// Wait for shard to start
let mut shard_ready = 0;
while running.load(Ordering::SeqCst) {
match status_receiver.try_recv() {
Ok(ShardStatus::Ready) => {
shard_ready += 1;
if shard_ready == num_shard {
break;
}
}
Err(TryRecvError::Empty) => {
sleep(Duration::from_millis(100));
}
Ok(ShardStatus::Failed((rank, err))) => {
tracing::error!("Shard {} failed to start:\n{}", rank, err);
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
Err(TryRecvError::Disconnected) => {
tracing::error!("Shard status channel disconnected");
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
}
}
// We might have received a termination signal
if !running.load(Ordering::SeqCst) {
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::SUCCESS;
}
// All shard started
// Start webserver
tracing::info!("Starting Webserver");
let mut argv = vec![
"text-generation-router".to_string(),
"--max-concurrent-requests".to_string(),
max_concurrent_requests.to_string(),
"--max-input-length".to_string(),
max_input_length.to_string(),
"--max-batch-size".to_string(),
max_batch_size.to_string(),
"--max-waiting-tokens".to_string(),
max_waiting_tokens.to_string(),
"--port".to_string(),
port.to_string(),
"--master-shard-uds-path".to_string(),
2023-02-13 05:02:45 -07:00
format!("{shard_uds_path}-0"),
"--tokenizer-name".to_string(),
model_id,
];
if json_output {
argv.push("--json-output".to_string());
}
2023-02-13 05:02:45 -07:00
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
argv.push("--otlp-endpoint".to_string());
argv.push(otlp_endpoint);
}
// CORS origins
for origin in cors_allow_origin.into_iter() {
argv.push("--cors-allow-origin".to_string());
argv.push(origin);
}
2022-10-18 07:19:03 -06:00
let mut webserver = match Popen::create(
&argv,
2022-10-18 07:19:03 -06:00
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
tracing::error!("Failed to start webserver: {}", err);
if let PopenError::IoError(err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-router not found in PATH");
tracing::error!("Please install it with `make install-router`")
}
2022-10-27 06:25:29 -06:00
} else {
tracing::error!("{}", err);
2022-10-18 07:19:03 -06:00
}
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
};
// Redirect STDOUT and STDERR to the console
let webserver_stdout = webserver.stdout.take().unwrap();
let webserver_stderr = webserver.stderr.take().unwrap();
thread::spawn(move || {
let stdout = BufReader::new(webserver_stdout);
let stderr = BufReader::new(webserver_stderr);
for line in stdout.lines() {
println!("{}", line.unwrap());
}
for line in stderr.lines() {
println!("{}", line.unwrap());
}
});
// Default exit code
let mut exit_code = ExitCode::SUCCESS;
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
2023-02-14 05:02:16 -07:00
tracing::error!("Shard {rank} failed:\n{err}");
2022-10-18 07:19:03 -06:00
exit_code = ExitCode::FAILURE;
break;
};
match webserver.poll() {
Some(_) => {
tracing::error!("Webserver Crashed");
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
None => {
sleep(Duration::from_millis(100));
}
};
}
// Graceful termination
webserver.terminate().unwrap();
tracing::info!("Waiting for webserver to gracefully shutdown");
webserver.wait_timeout(Duration::from_secs(90)).unwrap();
tracing::info!("Webserver terminated");
shutdown_shards(shutdown, &shutdown_receiver);
exit_code
}
/// Message sent from a `shard_manager` thread back to `main` over the
/// status channel.
#[derive(Debug)]
enum ShardStatus {
    /// The shard's unix socket exists; it is ready to serve.
    Ready,
    /// The shard failed to start or exited: (rank, captured error text).
    Failed((usize, String)),
}
#[allow(clippy::too_many_arguments)]
fn shard_manager(
model_id: String,
2023-01-31 10:53:56 -07:00
revision: Option<String>,
2022-10-27 06:25:29 -06:00
quantize: bool,
2022-10-18 07:19:03 -06:00
uds_path: String,
rank: usize,
world_size: usize,
master_addr: String,
master_port: usize,
2023-02-14 05:02:16 -07:00
huggingface_hub_cache: Option<String>,
weights_cache_override: Option<String>,
disable_custom_kernels: bool,
2023-02-13 05:02:45 -07:00
otlp_endpoint: Option<String>,
2022-10-18 07:19:03 -06:00
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>,
_shutdown_sender: mpsc::Sender<()>,
) {
// Get UDS path
2023-02-13 05:02:45 -07:00
let uds_string = format!("{uds_path}-{rank}");
2022-10-18 07:19:03 -06:00
let uds = Path::new(&uds_string);
// Clean previous runs
fs::remove_file(uds).unwrap_or_default();
// Process args
let mut shard_argv = vec![
"text-generation-server".to_string(),
2022-10-18 07:19:03 -06:00
"serve".to_string(),
model_id,
2022-10-18 07:19:03 -06:00
"--uds-path".to_string(),
uds_path,
"--logger-level".to_string(),
"ERROR".to_string(),
"--json-output".to_string(),
2022-10-18 07:19:03 -06:00
];
2023-02-13 05:02:45 -07:00
// Activate tensor parallelism
2022-10-18 07:19:03 -06:00
if world_size > 1 {
shard_argv.push("--sharded".to_string());
}
2022-10-27 06:25:29 -06:00
if quantize {
shard_argv.push("--quantize".to_string())
}
2023-02-13 05:02:45 -07:00
// Model optional revision
2023-01-31 10:53:56 -07:00
if let Some(revision) = revision {
shard_argv.push("--revision".to_string());
shard_argv.push(revision)
}
2023-02-13 05:02:45 -07:00
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
shard_argv.push("--otlp-endpoint".to_string());
shard_argv.push(otlp_endpoint);
}
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// Torch Distributed Env vars
env.push(("RANK".into(), rank.to_string().into()));
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
env.push(("MASTER_ADDR".into(), master_addr.into()));
env.push(("MASTER_PORT".into(), master_port.to_string().into()));
env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
// Safetensors load fast
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
2023-02-14 05:02:16 -07:00
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
2023-02-14 05:02:16 -07:00
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
2023-01-20 04:24:39 -07:00
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
2022-10-18 07:19:03 -06:00
2023-02-14 05:02:16 -07:00
// If weights_cache_override is some, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint
2023-02-14 05:02:16 -07:00
if let Some(weights_cache_override) = weights_cache_override {
env.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
));
};
// If disable_custom_kernels is true, pass it to the shard as an env var
if disable_custom_kernels {
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
}
2022-10-18 07:19:03 -06:00
// Start process
2023-02-14 05:02:16 -07:00
tracing::info!("Starting shard {rank}");
2022-10-18 07:19:03 -06:00
let mut p = match Popen::create(
&shard_argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
// NCCL env vars
env: Some(env),
2022-10-18 07:19:03 -06:00
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
2022-10-18 07:19:03 -06:00
tracing::error!("Please install it with `make install-server`")
}
}
status_sender
.send(ShardStatus::Failed((rank, err.to_string())))
.unwrap();
return;
}
};
// Redirect STDOUT to the console
let shard_stdout = p.stdout.take().unwrap();
thread::spawn(move || {
// Enter shard-manager tracing span
let stdout = BufReader::new(shard_stdout);
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
if let Some(text) = value.get("text") {
// Format escaped newlines
tracing::error!("{}", text.to_string().replace("\\n", "\n"));
}
}
}
});
2022-10-18 07:19:03 -06:00
let mut ready = false;
let start_time = Instant::now();
let mut wait_time = Instant::now();
2022-10-18 07:19:03 -06:00
loop {
// Process exited
if p.poll().is_some() {
let mut err = String::new();
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
status_sender
.send(ShardStatus::Failed((rank, err)))
.unwrap();
return;
}
// We received a shutdown signal
if *shutdown.lock().unwrap() {
p.terminate().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90));
2023-02-14 05:02:16 -07:00
tracing::info!("Shard {rank} terminated");
2022-10-18 07:19:03 -06:00
return;
}
// Shard is ready
if uds.exists() && !ready {
2023-02-14 05:02:16 -07:00
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
2022-10-18 07:19:03 -06:00
status_sender.send(ShardStatus::Ready).unwrap();
ready = true;
2022-10-27 06:25:29 -06:00
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
2023-02-14 05:02:16 -07:00
tracing::info!("Waiting for shard {rank} to be ready...");
2022-10-27 06:25:29 -06:00
wait_time = Instant::now();
2022-10-18 07:19:03 -06:00
}
sleep(Duration::from_millis(100));
2022-10-18 07:19:03 -06:00
}
}
/// Asks every shard to stop and blocks until they have all exited.
fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");
    // Flip the shared flag; each shard_manager polls it and terminates its
    // child process. The guard is dropped at the end of the statement.
    *shutdown.lock().unwrap() = true;
    // Every shard_manager thread holds a clone of the paired sender and never
    // sends on it; recv() therefore returns Err only once all of those clones
    // have been dropped, i.e. once every shard has shut down.
    let _ = shutdown_receiver.recv();
}