2022-10-18 07:19:03 -06:00
|
|
|
use clap::Parser;
|
2023-04-16 16:26:47 -06:00
|
|
|
use serde::Deserialize;
|
2022-10-22 12:00:15 -06:00
|
|
|
use std::env;
|
2023-02-16 03:20:23 -07:00
|
|
|
use std::ffi::OsString;
|
2022-10-18 07:19:03 -06:00
|
|
|
use std::io::{BufRead, BufReader, Read};
|
|
|
|
use std::path::Path;
|
|
|
|
use std::sync::atomic::{AtomicBool, Ordering};
|
|
|
|
use std::sync::mpsc::TryRecvError;
|
|
|
|
use std::sync::Arc;
|
|
|
|
use std::sync::{mpsc, Mutex};
|
|
|
|
use std::thread;
|
|
|
|
use std::thread::sleep;
|
|
|
|
use std::time::{Duration, Instant};
|
|
|
|
use std::{fs, io};
|
2023-02-14 05:02:16 -07:00
|
|
|
use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
|
2022-10-18 07:19:03 -06:00
|
|
|
|
2023-05-02 07:43:19 -06:00
|
|
|
mod env_runtime;
|
|
|
|
|
2022-10-18 07:19:03 -06:00
|
|
|
/// App Configuration
|
|
|
|
#[derive(Parser, Debug)]
|
|
|
|
#[clap(author, version, about, long_about = None)]
|
|
|
|
struct Args {
|
2023-04-29 03:53:42 -06:00
|
|
|
/// The name of the model to load.
|
|
|
|
/// Can be a MODEL_ID as listed on <https://hf.co/models> like
|
|
|
|
/// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`.
|
|
|
|
/// Or it can be a local directory containing the necessary files
|
|
|
|
/// as saved by `save_pretrained(...)` methods of transformers
|
2022-10-18 07:19:03 -06:00
|
|
|
#[clap(default_value = "bigscience/bloom-560m", long, env)]
|
2023-02-03 04:43:37 -07:00
|
|
|
model_id: String,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// The actual revision of the model if you're referring to a model
|
|
|
|
/// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
|
2022-10-18 07:19:03 -06:00
|
|
|
#[clap(long, env)]
|
2023-01-31 10:53:56 -07:00
|
|
|
revision: Option<String>,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// Wether to shard or not the model across multiple GPUs
|
|
|
|
/// By default text-generation-inference will use all available GPUs to run
|
|
|
|
/// the model. Setting it to `false` deactivates `num_shard`.
|
2023-03-08 03:06:59 -07:00
|
|
|
#[clap(long, env)]
|
|
|
|
sharded: Option<bool>,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// The number of shards to use if you don't want to use all GPUs on a given machine.
|
|
|
|
/// You can use `CUDA_VISIBLE_DEVICE=0,1 text-generation-launcher... --num_shard 2`
|
|
|
|
/// and `CUDA_VISIBLE_DEVICE=2,3 text-generation-launcher... --num_shard 2` to
|
|
|
|
/// launch 2 copies with 2 shard each on a given machine with 4 GPUs for instance.
|
2023-03-08 03:06:59 -07:00
|
|
|
#[clap(long, env)]
|
|
|
|
num_shard: Option<usize>,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// Wether you want the model to be quantized or not. This will use bitsandbytes for
|
|
|
|
/// quantization on the fly.
|
2022-10-27 06:25:29 -06:00
|
|
|
#[clap(long, env)]
|
|
|
|
quantize: bool,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// The maximum amount of concurrent requests for this particular deployment.
|
|
|
|
/// Having a low limit will refuse clients requests instead of having them
|
|
|
|
/// wait for too long and is usually good to handle backpressure correctly.
|
2022-10-18 07:19:03 -06:00
|
|
|
#[clap(default_value = "128", long, env)]
|
|
|
|
max_concurrent_requests: usize,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// This is the maximum allowed value for clients to set `best_of`.
|
|
|
|
/// Best of makes `n` generations at the same time, and return the best
|
|
|
|
/// in terms of overall log probability over the entire generated sequence
|
2023-03-09 07:30:54 -07:00
|
|
|
#[clap(default_value = "2", long, env)]
|
|
|
|
max_best_of: usize,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// This is the maximum allowed value for clients to set `stop_sequences`.
|
|
|
|
/// Stop sequences are used to allow the model to stop on more than just
|
|
|
|
/// the EOS token, and enable more complex "prompting" where users can preprompt
|
|
|
|
/// the model in a specific way and define their "own" stop token aligned with
|
|
|
|
/// their prompt.
|
2023-03-03 08:01:25 -07:00
|
|
|
#[clap(default_value = "4", long, env)]
|
|
|
|
max_stop_sequences: usize,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// This is the maximum allowed input length (expressed in number of tokens)
|
|
|
|
/// for users. The larger this value, the longer prompt users can send which
|
|
|
|
/// can impact the overall memory required to handle the load.
|
|
|
|
/// Please note that some models have a finite range of sequence they can handle.
|
2022-10-18 07:19:03 -06:00
|
|
|
#[clap(default_value = "1000", long, env)]
|
|
|
|
max_input_length: usize,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// This is the most important value to set as it defines the "memory budget"
|
|
|
|
/// of running clients requests.
|
|
|
|
/// Clients will send input sequences and ask to generate `max_new_tokens`
|
|
|
|
/// on top. with a value of `1512` users can send either a prompt of
|
|
|
|
/// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for
|
|
|
|
/// `1511` max_new_tokens.
|
|
|
|
/// The larger this value, the larger amount each request will be in your RAM
|
|
|
|
/// and the less effective batching can be.
|
2023-03-03 08:01:25 -07:00
|
|
|
#[clap(default_value = "1512", long, env)]
|
|
|
|
max_total_tokens: usize,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// The maximum allowed batch size during dynamic batching.
|
|
|
|
/// Using `max_batch_total_tokens` should be favored in general
|
|
|
|
/// as it's a finer way to control RAM usage.
|
2023-04-24 09:59:00 -06:00
|
|
|
#[clap(long, env)]
|
|
|
|
max_batch_size: Option<usize>,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// This represents the ratio of waiting queries vs running queries where
|
|
|
|
/// you want to start considering pausing the running queries to include the waiting
|
|
|
|
/// ones into the same batch.
|
|
|
|
/// `waiting_served_ratio=1.2` Means when 12 queries are waiting and there's
|
|
|
|
/// only 10 queries left in the current batch we check if we can fit those 12
|
|
|
|
/// waiting queries into the batching strategy, and if yes, then batching happens
|
|
|
|
/// delaying the 10 running queries by a `prefill` run.
|
|
|
|
///
|
|
|
|
/// This setting is only applied if there is room in the batch
|
|
|
|
/// as defined by `max_batch_total_tokens`.
|
2023-04-24 09:59:00 -06:00
|
|
|
#[clap(default_value = "1.2", long, env)]
|
|
|
|
waiting_served_ratio: f32,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// **IMPORTANT** This is one critical control to allow maximum usage
|
|
|
|
/// of the available hardware.
|
|
|
|
///
|
|
|
|
/// This represents the total amount of potential tokens within a batch.
|
|
|
|
/// When using padding (not recommended) this would be equivalent of
|
|
|
|
/// `batch_size` * `max_total_tokens`.
|
|
|
|
///
|
|
|
|
/// However in the non-padded (flash attention) version this can be much finer.
|
|
|
|
///
|
|
|
|
/// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
|
|
|
|
/// or a single query of `1000` tokens.
|
|
|
|
///
|
|
|
|
/// So you don't have to control that finely
|
|
|
|
/// `max_batch_size` or `max_total_tokens`. In fact you could mostly relax them if you
|
|
|
|
/// want maximum flexibility. However, for your users if they are asking for the full amount of
|
|
|
|
/// total tokens, they are likely to wait for a very long time to get a spot
|
|
|
|
/// in the batch (since they are going to be alone) so setting `max_batch_size`
|
|
|
|
/// and `max_total_tokens` can still be useful to prevent those long waiting times.
|
|
|
|
///
|
|
|
|
/// Overall this number should be the largest possible amount that fits the
|
|
|
|
/// remaining memory (after the model is loaded). Since the actual memory overhead
|
|
|
|
/// depends on other parameters like if you're using quantization, flash attention
|
|
|
|
/// or the model implementation, text-generation-inference cannot infer this number
|
|
|
|
/// automatically.
|
2023-04-24 09:59:00 -06:00
|
|
|
#[clap(default_value = "32000", long, env)]
|
|
|
|
max_batch_total_tokens: u32,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// This setting defines how many tokens can be passed before forcing the waiting
|
|
|
|
/// queries to be put on the batch (if the size of the batch allows for it).
|
|
|
|
/// New queries require 1 `prefill` forward, which is different from `decode`
|
|
|
|
/// and therefore you need to pause the running batch in order to run `prefill`
|
|
|
|
/// to create the correct values for the waiting queries to be able to join the batch.
|
|
|
|
///
|
|
|
|
/// With a value too small, queries will always "steal" the compute to run `prefill`
|
|
|
|
/// and running queries will be delayed by a lot.
|
|
|
|
///
|
|
|
|
/// With a value too big, waiting queries could wait for a very long time
|
|
|
|
/// before being allowed a slot in the running batch. If your server is busy
|
|
|
|
/// that means that requests that could run in ~2s on an empty server could
|
|
|
|
/// end up running in ~20s because the query had to wait for 18s.
|
|
|
|
///
|
|
|
|
/// This number is expressed in number of tokens to make it a bit more
|
|
|
|
/// "model" agnostic, but what should really matter is the overall latency
|
|
|
|
/// for end users.
|
2022-10-21 08:40:05 -06:00
|
|
|
#[clap(default_value = "20", long, env)]
|
|
|
|
max_waiting_tokens: usize,
|
2022-10-18 07:19:03 -06:00
|
|
|
#[clap(default_value = "3000", long, short, env)]
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// The port to listen on.
|
2022-10-18 07:19:03 -06:00
|
|
|
port: u16,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// The name of the socket for gRPC communication between the webserver
|
|
|
|
/// and the shards.
|
2022-10-18 07:19:03 -06:00
|
|
|
#[clap(default_value = "/tmp/text-generation-server", long, env)]
|
|
|
|
shard_uds_path: String,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// The address the master shard will listen on. (setting used by torch distributed)
|
2023-02-08 09:53:33 -07:00
|
|
|
#[clap(default_value = "localhost", long, env)]
|
2022-10-18 07:19:03 -06:00
|
|
|
master_addr: String,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// The address the master port will listen on. (setting used by torch distributed)
|
2023-02-08 09:53:33 -07:00
|
|
|
#[clap(default_value = "29500", long, env)]
|
2022-10-18 07:19:03 -06:00
|
|
|
master_port: usize,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// The location of the huggingface hub cache.
|
|
|
|
/// Used to override the location if you want to provide a mounted disk for instance
|
2022-11-02 10:29:56 -06:00
|
|
|
#[clap(long, env)]
|
2023-02-14 05:02:16 -07:00
|
|
|
huggingface_hub_cache: Option<String>,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// The location of the huggingface hub cache.
|
|
|
|
/// Used to override the location if you want to provide a mounted disk for instance
|
2023-02-14 05:02:16 -07:00
|
|
|
#[clap(long, env)]
|
|
|
|
weights_cache_override: Option<String>,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// For some models (like bloom), text-generation-inference implemented custom
|
|
|
|
/// cuda kernels to speed up inference. Those kernels were only tested on A100.
|
|
|
|
/// Use this flag to disable them if you're running on different hardware and
|
|
|
|
/// encounter issues.
|
2023-02-14 05:02:16 -07:00
|
|
|
#[clap(long, env)]
|
2023-02-15 08:23:45 -07:00
|
|
|
disable_custom_kernels: bool,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
|
|
|
/// Outputs the logs in JSON format (useful for telemetry)
|
2023-02-15 08:23:45 -07:00
|
|
|
#[clap(long, env)]
|
2022-11-02 10:29:56 -06:00
|
|
|
json_output: bool,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
2023-02-13 05:02:45 -07:00
|
|
|
#[clap(long, env)]
|
|
|
|
otlp_endpoint: Option<String>,
|
2023-04-29 03:53:42 -06:00
|
|
|
|
2023-02-17 10:22:00 -07:00
|
|
|
#[clap(long, env)]
|
|
|
|
cors_allow_origin: Vec<String>,
|
2023-03-02 04:30:41 -07:00
|
|
|
#[clap(long, env)]
|
|
|
|
watermark_gamma: Option<f32>,
|
|
|
|
#[clap(long, env)]
|
|
|
|
watermark_delta: Option<f32>,
|
2023-05-02 07:43:19 -06:00
|
|
|
|
|
|
|
/// Display a lot of information about your runtime environment
|
|
|
|
#[clap(long, short, action)]
|
|
|
|
env: bool,
|
2022-10-18 07:19:03 -06:00
|
|
|
}
|
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
/// Messages a shard manager thread sends back to the launcher about its shard.
#[derive(Debug)]
enum ShardStatus {
    /// The shard's `{uds_path}-{rank}` socket file has appeared.
    Ready,
    /// The shard could not be spawned or its process exited: `(rank, error output)`.
    Failed((usize, String)),
}
|
2023-02-15 08:11:32 -07:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
#[allow(clippy::too_many_arguments)]
|
|
|
|
fn shard_manager(
|
|
|
|
model_id: String,
|
|
|
|
revision: Option<String>,
|
|
|
|
quantize: bool,
|
|
|
|
uds_path: String,
|
|
|
|
rank: usize,
|
|
|
|
world_size: usize,
|
|
|
|
master_addr: String,
|
|
|
|
master_port: usize,
|
|
|
|
huggingface_hub_cache: Option<String>,
|
|
|
|
weights_cache_override: Option<String>,
|
|
|
|
disable_custom_kernels: bool,
|
|
|
|
watermark_gamma: Option<f32>,
|
|
|
|
watermark_delta: Option<f32>,
|
|
|
|
otlp_endpoint: Option<String>,
|
|
|
|
status_sender: mpsc::Sender<ShardStatus>,
|
|
|
|
shutdown: Arc<Mutex<bool>>,
|
|
|
|
_shutdown_sender: mpsc::Sender<()>,
|
|
|
|
) {
|
|
|
|
// Get UDS path
|
|
|
|
let uds_string = format!("{uds_path}-{rank}");
|
|
|
|
let uds = Path::new(&uds_string);
|
|
|
|
// Clean previous runs
|
|
|
|
fs::remove_file(uds).unwrap_or_default();
|
|
|
|
|
|
|
|
// Process args
|
|
|
|
let mut shard_argv = vec![
|
|
|
|
"text-generation-server".to_string(),
|
|
|
|
"serve".to_string(),
|
|
|
|
model_id,
|
|
|
|
"--uds-path".to_string(),
|
|
|
|
uds_path,
|
|
|
|
"--logger-level".to_string(),
|
|
|
|
"INFO".to_string(),
|
|
|
|
"--json-output".to_string(),
|
|
|
|
];
|
|
|
|
|
|
|
|
// Activate tensor parallelism
|
|
|
|
if world_size > 1 {
|
|
|
|
shard_argv.push("--sharded".to_string());
|
2023-02-15 08:11:32 -07:00
|
|
|
}
|
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
if quantize {
|
|
|
|
shard_argv.push("--quantize".to_string())
|
|
|
|
}
|
2023-02-15 08:11:32 -07:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Model optional revision
|
|
|
|
if let Some(revision) = revision {
|
|
|
|
shard_argv.push("--revision".to_string());
|
|
|
|
shard_argv.push(revision)
|
|
|
|
}
|
2022-10-18 07:19:03 -06:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// OpenTelemetry
|
|
|
|
if let Some(otlp_endpoint) = otlp_endpoint {
|
|
|
|
shard_argv.push("--otlp-endpoint".to_string());
|
|
|
|
shard_argv.push(otlp_endpoint);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Copy current process env
|
|
|
|
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
|
|
|
|
|
|
|
// Torch Distributed Env vars
|
|
|
|
env.push(("RANK".into(), rank.to_string().into()));
|
|
|
|
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
|
|
|
|
env.push(("MASTER_ADDR".into(), master_addr.into()));
|
|
|
|
env.push(("MASTER_PORT".into(), master_port.to_string().into()));
|
|
|
|
env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
|
|
|
|
|
|
|
|
// Safetensors load fast
|
|
|
|
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
|
|
|
|
|
|
|
|
// Enable hf transfer for insane download speeds
|
|
|
|
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
|
|
|
|
env.push((
|
|
|
|
"HF_HUB_ENABLE_HF_TRANSFER".into(),
|
|
|
|
enable_hf_transfer.into(),
|
|
|
|
));
|
|
|
|
|
|
|
|
// Parse Inference API token
|
|
|
|
if let Ok(api_token) = env::var("HF_API_TOKEN") {
|
|
|
|
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
|
|
|
};
|
|
|
|
|
|
|
|
// If huggingface_hub_cache is some, pass it to the shard
|
|
|
|
// Useful when running inside a docker container
|
|
|
|
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
|
|
|
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
|
|
|
};
|
|
|
|
|
|
|
|
// If weights_cache_override is some, pass it to the shard
|
|
|
|
// Useful when running inside a HuggingFace Inference Endpoint
|
|
|
|
if let Some(weights_cache_override) = weights_cache_override {
|
|
|
|
env.push((
|
|
|
|
"WEIGHTS_CACHE_OVERRIDE".into(),
|
|
|
|
weights_cache_override.into(),
|
|
|
|
));
|
|
|
|
};
|
|
|
|
|
|
|
|
// If disable_custom_kernels is true, pass it to the shard as an env var
|
|
|
|
if disable_custom_kernels {
|
|
|
|
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
|
|
|
|
}
|
|
|
|
|
|
|
|
// Watermark Gamma
|
|
|
|
if let Some(watermark_gamma) = watermark_gamma {
|
|
|
|
env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
|
|
|
|
}
|
|
|
|
|
|
|
|
// Watermark Delta
|
|
|
|
if let Some(watermark_delta) = watermark_delta {
|
|
|
|
env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
|
|
|
|
}
|
|
|
|
|
|
|
|
// Start process
|
|
|
|
tracing::info!("Starting shard {rank}");
|
|
|
|
let mut p = match Popen::create(
|
|
|
|
&shard_argv,
|
|
|
|
PopenConfig {
|
|
|
|
stdout: Redirection::Pipe,
|
|
|
|
stderr: Redirection::Pipe,
|
|
|
|
// Needed for the shutdown procedure
|
|
|
|
setpgid: true,
|
|
|
|
// NCCL env vars
|
|
|
|
env: Some(env),
|
|
|
|
..Default::default()
|
|
|
|
},
|
|
|
|
) {
|
|
|
|
Ok(p) => p,
|
|
|
|
Err(err) => {
|
|
|
|
if let PopenError::IoError(ref err) = err {
|
|
|
|
if err.kind() == io::ErrorKind::NotFound {
|
|
|
|
tracing::error!("text-generation-server not found in PATH");
|
|
|
|
tracing::error!("Please install it with `make install-server`")
|
2023-03-08 03:06:59 -07:00
|
|
|
}
|
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
status_sender
|
|
|
|
.send(ShardStatus::Failed((rank, err.to_string())))
|
|
|
|
.unwrap();
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
// Redirect STDOUT to the console
|
|
|
|
let shard_stdout = p.stdout.take().unwrap();
|
|
|
|
|
|
|
|
thread::spawn(move || {
|
|
|
|
// Enter shard-manager tracing span
|
|
|
|
let stdout = BufReader::new(shard_stdout);
|
|
|
|
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
|
|
|
|
for line in stdout.lines() {
|
|
|
|
// Parse loguru logs
|
|
|
|
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
|
|
|
|
log.trace();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
});
|
|
|
|
|
|
|
|
let mut ready = false;
|
|
|
|
let start_time = Instant::now();
|
|
|
|
let mut wait_time = Instant::now();
|
|
|
|
loop {
|
|
|
|
// Process exited
|
|
|
|
if p.poll().is_some() {
|
|
|
|
let mut err = String::new();
|
|
|
|
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
|
|
|
|
status_sender
|
|
|
|
.send(ShardStatus::Failed((rank, err)))
|
|
|
|
.unwrap();
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// We received a shutdown signal
|
|
|
|
if *shutdown.lock().unwrap() {
|
|
|
|
p.terminate().unwrap();
|
|
|
|
let _ = p.wait_timeout(Duration::from_secs(90));
|
|
|
|
tracing::info!("Shard {rank} terminated");
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Shard is ready
|
|
|
|
if uds.exists() && !ready {
|
|
|
|
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
|
|
|
|
status_sender.send(ShardStatus::Ready).unwrap();
|
|
|
|
ready = true;
|
|
|
|
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
|
|
|
|
tracing::info!("Waiting for shard {rank} to be ready...");
|
|
|
|
wait_time = Instant::now();
|
|
|
|
}
|
|
|
|
sleep(Duration::from_millis(100));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
|
|
|
|
tracing::info!("Shutting down shards");
|
|
|
|
// Update shutdown value to true
|
|
|
|
// This will be picked up by the shard manager
|
|
|
|
{
|
|
|
|
let mut shutdown = shutdown.lock().unwrap();
|
|
|
|
*shutdown = true;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Wait for shards to shutdown
|
|
|
|
// This will block till all shutdown_sender are dropped
|
|
|
|
let _ = shutdown_receiver.recv();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Number of GPUs listed in `CUDA_VISIBLE_DEVICES`.
///
/// Blank entries (for example from a trailing comma in `"0,1,"`) and surrounding whitespace
/// are ignored. Returns `None` when the variable is unset, not valid Unicode, or lists no
/// device at all, so callers fall back to their own default instead of a count of 0.
fn num_cuda_devices() -> Option<usize> {
    let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES").ok()?;
    let n_devices = cuda_visible_devices
        .split(',')
        .filter(|device| !device.trim().is_empty())
        .count();
    if n_devices == 0 {
        None
    } else {
        Some(n_devices)
    }
}
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
#[serde(rename_all = "UPPERCASE")]
|
|
|
|
enum PythonLogLevelEnum {
|
|
|
|
Trace,
|
|
|
|
Debug,
|
|
|
|
Info,
|
|
|
|
Success,
|
|
|
|
Warning,
|
|
|
|
Error,
|
|
|
|
Critical,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
struct PythonLogLevel {
|
|
|
|
name: PythonLogLevelEnum,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
struct PythonLogRecord {
|
|
|
|
level: PythonLogLevel,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
struct PythonLogMessage {
|
|
|
|
text: String,
|
|
|
|
record: PythonLogRecord,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl PythonLogMessage {
|
|
|
|
fn trace(&self) {
|
|
|
|
match self.record.level.name {
|
|
|
|
PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
|
|
|
|
PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
|
|
|
|
PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
|
|
|
|
PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
|
|
|
|
PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
|
|
|
|
PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
|
|
|
|
PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
|
|
|
|
// get the number of shards given `sharded` and `num_shard`
|
|
|
|
let num_shard = match (sharded, num_shard) {
|
|
|
|
(Some(true), None) => {
|
|
|
|
// try to default to the number of available GPUs
|
|
|
|
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
|
|
|
|
let n_devices =
|
|
|
|
num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
|
|
|
|
if n_devices <= 1 {
|
|
|
|
panic!("`sharded` is true but only found {n_devices} CUDA devices");
|
2023-03-08 03:06:59 -07:00
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
n_devices
|
2023-03-08 03:06:59 -07:00
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
(Some(true), Some(num_shard)) => {
|
|
|
|
// we can't have only one shard while sharded
|
|
|
|
if num_shard <= 1 {
|
|
|
|
panic!("`sharded` is true but `num_shard` <= 1");
|
|
|
|
}
|
|
|
|
num_shard
|
2023-03-08 05:53:41 -07:00
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
(Some(false), Some(num_shard)) => num_shard,
|
|
|
|
(Some(false), None) => 1,
|
|
|
|
(None, None) => num_cuda_devices().unwrap_or(1),
|
|
|
|
(None, Some(num_shard)) => num_shard,
|
2023-03-08 03:06:59 -07:00
|
|
|
};
|
2023-03-08 05:53:41 -07:00
|
|
|
if num_shard < 1 {
|
|
|
|
panic!("`num_shard` cannot be < 1");
|
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
num_shard
|
|
|
|
}
|
2023-03-08 05:53:41 -07:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
/// Reasons the launcher stops with an error.
#[derive(Debug)]
enum LauncherError {
    /// The weight download process could not be started or did not finish successfully.
    DownloadError,
    /// A shard reported a failure before every shard became ready.
    ShardCannotStart,
    /// The shard status channel disconnected while waiting for the shards to start.
    ShardDisconnected,
    // NOTE(review): the variants below are raised outside this view — confirm their exact
    // triggers before documenting them.
    ShardFailed,
    WebserverFailed,
    WebserverCannotStart,
}
|
2022-10-18 07:19:03 -06:00
|
|
|
|
2023-05-03 03:36:24 -06:00
|
|
|
fn download_convert_model(
|
|
|
|
args: &Args,
|
|
|
|
auto_convert: bool,
|
|
|
|
running: Arc<AtomicBool>,
|
|
|
|
) -> Result<(), LauncherError> {
|
2023-04-26 06:43:36 -06:00
|
|
|
let mut download_argv = vec![
|
|
|
|
"text-generation-server".to_string(),
|
|
|
|
"download-weights".to_string(),
|
|
|
|
args.model_id.to_string(),
|
|
|
|
"--extension".to_string(),
|
|
|
|
".safetensors".to_string(),
|
|
|
|
"--logger-level".to_string(),
|
|
|
|
"INFO".to_string(),
|
|
|
|
"--json-output".to_string(),
|
|
|
|
];
|
2023-03-06 06:39:36 -07:00
|
|
|
|
2023-05-03 03:36:24 -06:00
|
|
|
// Auto convert weights to safetensors
|
|
|
|
if auto_convert {
|
|
|
|
download_argv.push("--auto-convert".to_string());
|
|
|
|
}
|
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Model optional revision
|
|
|
|
if let Some(revision) = &args.revision {
|
|
|
|
download_argv.push("--revision".to_string());
|
|
|
|
download_argv.push(revision.to_string())
|
|
|
|
}
|
2023-02-14 05:02:16 -07:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Copy current process env
|
|
|
|
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
2023-02-14 05:02:16 -07:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// If huggingface_hub_cache is set, pass it to the shard
|
|
|
|
// Useful when running inside a docker container
|
|
|
|
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
|
|
|
|
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
|
|
|
};
|
2023-02-14 05:02:16 -07:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Enable hf transfer for insane download speeds
|
|
|
|
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
|
|
|
|
env.push((
|
|
|
|
"HF_HUB_ENABLE_HF_TRANSFER".into(),
|
|
|
|
enable_hf_transfer.into(),
|
|
|
|
));
|
2023-04-09 12:00:05 -06:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Parse Inference API token
|
|
|
|
if let Ok(api_token) = env::var("HF_API_TOKEN") {
|
|
|
|
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
|
|
|
};
|
2023-02-18 06:04:11 -07:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Start process
|
|
|
|
tracing::info!("Starting download process.");
|
|
|
|
let mut download_process = match Popen::create(
|
|
|
|
&download_argv,
|
|
|
|
PopenConfig {
|
|
|
|
stdout: Redirection::Pipe,
|
|
|
|
stderr: Redirection::Pipe,
|
|
|
|
// Needed for the shutdown procedure
|
|
|
|
setpgid: true,
|
|
|
|
env: Some(env),
|
|
|
|
..Default::default()
|
|
|
|
},
|
|
|
|
) {
|
|
|
|
Ok(p) => p,
|
|
|
|
Err(err) => {
|
|
|
|
if let PopenError::IoError(ref err) = err {
|
|
|
|
if err.kind() == io::ErrorKind::NotFound {
|
|
|
|
tracing::error!("text-generation-server not found in PATH");
|
|
|
|
tracing::error!("Please install it with `make install-server`")
|
2023-02-14 05:02:16 -07:00
|
|
|
}
|
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
return Err(LauncherError::DownloadError);
|
|
|
|
}
|
|
|
|
};
|
2023-02-14 05:02:16 -07:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Redirect STDOUT to the console
|
|
|
|
let download_stdout = download_process.stdout.take().unwrap();
|
|
|
|
thread::spawn(move || {
|
|
|
|
// Enter download tracing span
|
|
|
|
let stdout = BufReader::new(download_stdout);
|
|
|
|
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
|
|
|
for line in stdout.lines() {
|
|
|
|
// Parse loguru logs
|
|
|
|
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
|
|
|
|
log.trace();
|
2023-02-14 05:02:16 -07:00
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
}
|
|
|
|
});
|
2023-02-14 05:02:16 -07:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
loop {
|
|
|
|
if let Some(status) = download_process.poll() {
|
|
|
|
match status {
|
|
|
|
ExitStatus::Exited(exit_code) => {
|
|
|
|
if exit_code == 0 {
|
|
|
|
tracing::info!("Successfully downloaded weights.");
|
|
|
|
break;
|
|
|
|
} else {
|
|
|
|
let mut err = String::new();
|
|
|
|
download_process
|
|
|
|
.stderr
|
|
|
|
.take()
|
|
|
|
.unwrap()
|
|
|
|
.read_to_string(&mut err)
|
|
|
|
.unwrap();
|
|
|
|
tracing::error!("Download encountered an error: {err}");
|
|
|
|
return Err(LauncherError::DownloadError);
|
2023-02-14 05:02:16 -07:00
|
|
|
}
|
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
_ => {
|
|
|
|
tracing::error!("Download process exited with an unknown status.");
|
|
|
|
return Err(LauncherError::DownloadError);
|
|
|
|
}
|
2023-02-14 05:02:16 -07:00
|
|
|
}
|
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
if !running.load(Ordering::SeqCst) {
|
|
|
|
download_process.terminate().unwrap();
|
|
|
|
tracing::info!("Waiting for download process to gracefully shutdown");
|
|
|
|
download_process
|
|
|
|
.wait_timeout(Duration::from_secs(90))
|
|
|
|
.unwrap();
|
|
|
|
tracing::info!("Download process terminated");
|
|
|
|
return Ok(());
|
|
|
|
}
|
|
|
|
sleep(Duration::from_millis(100));
|
2023-02-14 05:02:16 -07:00
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
Ok(())
|
|
|
|
}
|
2023-02-14 05:02:16 -07:00
|
|
|
|
2023-04-26 12:23:54 -06:00
|
|
|
#[allow(clippy::too_many_arguments)]
|
2023-04-26 06:43:36 -06:00
|
|
|
fn spawn_shards(
|
|
|
|
num_shard: usize,
|
|
|
|
args: &Args,
|
|
|
|
shutdown: Arc<Mutex<bool>>,
|
|
|
|
shutdown_receiver: &mpsc::Receiver<()>,
|
|
|
|
shutdown_sender: mpsc::Sender<()>,
|
|
|
|
status_receiver: &mpsc::Receiver<ShardStatus>,
|
|
|
|
status_sender: mpsc::Sender<ShardStatus>,
|
|
|
|
running: Arc<AtomicBool>,
|
|
|
|
) -> Result<(), LauncherError> {
|
2022-10-18 07:19:03 -06:00
|
|
|
// Start shard processes
|
|
|
|
for rank in 0..num_shard {
|
2023-04-26 06:43:36 -06:00
|
|
|
let model_id = args.model_id.clone();
|
|
|
|
let revision = args.revision.clone();
|
|
|
|
let uds_path = args.shard_uds_path.clone();
|
|
|
|
let master_addr = args.master_addr.clone();
|
|
|
|
let huggingface_hub_cache = args.huggingface_hub_cache.clone();
|
|
|
|
let weights_cache_override = args.weights_cache_override.clone();
|
2022-10-18 07:19:03 -06:00
|
|
|
let status_sender = status_sender.clone();
|
|
|
|
let shutdown = shutdown.clone();
|
|
|
|
let shutdown_sender = shutdown_sender.clone();
|
2023-04-26 06:43:36 -06:00
|
|
|
let otlp_endpoint = args.otlp_endpoint.clone();
|
2023-04-26 12:23:54 -06:00
|
|
|
let quantize = args.quantize;
|
|
|
|
let master_port = args.master_port;
|
|
|
|
let disable_custom_kernels = args.disable_custom_kernels;
|
|
|
|
let watermark_gamma = args.watermark_gamma;
|
|
|
|
let watermark_delta = args.watermark_delta;
|
2022-10-18 07:19:03 -06:00
|
|
|
thread::spawn(move || {
|
|
|
|
shard_manager(
|
2023-02-03 04:43:37 -07:00
|
|
|
model_id,
|
2023-01-31 10:53:56 -07:00
|
|
|
revision,
|
2022-10-27 06:25:29 -06:00
|
|
|
quantize,
|
2022-10-18 07:19:03 -06:00
|
|
|
uds_path,
|
|
|
|
rank,
|
|
|
|
num_shard,
|
|
|
|
master_addr,
|
|
|
|
master_port,
|
2023-02-14 05:02:16 -07:00
|
|
|
huggingface_hub_cache,
|
|
|
|
weights_cache_override,
|
2023-02-15 08:23:45 -07:00
|
|
|
disable_custom_kernels,
|
2023-03-02 04:30:41 -07:00
|
|
|
watermark_gamma,
|
|
|
|
watermark_delta,
|
2023-02-13 05:02:45 -07:00
|
|
|
otlp_endpoint,
|
2022-10-18 07:19:03 -06:00
|
|
|
status_sender,
|
|
|
|
shutdown,
|
|
|
|
shutdown_sender,
|
|
|
|
)
|
|
|
|
});
|
|
|
|
}
|
|
|
|
drop(shutdown_sender);
|
|
|
|
|
|
|
|
// Wait for shard to start
|
|
|
|
let mut shard_ready = 0;
|
|
|
|
while running.load(Ordering::SeqCst) {
|
|
|
|
match status_receiver.try_recv() {
|
|
|
|
Ok(ShardStatus::Ready) => {
|
|
|
|
shard_ready += 1;
|
|
|
|
if shard_ready == num_shard {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Err(TryRecvError::Empty) => {
|
|
|
|
sleep(Duration::from_millis(100));
|
|
|
|
}
|
2023-04-13 03:07:11 -06:00
|
|
|
Ok(ShardStatus::Failed((rank, err))) => {
|
|
|
|
tracing::error!("Shard {} failed to start:\n{}", rank, err);
|
2023-04-26 12:23:54 -06:00
|
|
|
shutdown_shards(shutdown, shutdown_receiver);
|
2023-04-26 06:43:36 -06:00
|
|
|
return Err(LauncherError::ShardCannotStart);
|
2022-10-18 07:19:03 -06:00
|
|
|
}
|
|
|
|
Err(TryRecvError::Disconnected) => {
|
|
|
|
tracing::error!("Shard status channel disconnected");
|
2023-04-26 12:23:54 -06:00
|
|
|
shutdown_shards(shutdown, shutdown_receiver);
|
2023-04-26 06:43:36 -06:00
|
|
|
return Err(LauncherError::ShardDisconnected);
|
2022-10-18 07:19:03 -06:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
Ok(())
|
|
|
|
}
|
2022-10-18 07:19:03 -06:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
fn spawn_webserver(
|
|
|
|
args: Args,
|
|
|
|
shutdown: Arc<Mutex<bool>>,
|
|
|
|
shutdown_receiver: &mpsc::Receiver<()>,
|
|
|
|
) -> Result<Popen, LauncherError> {
|
2022-10-18 07:19:03 -06:00
|
|
|
// All shard started
|
|
|
|
// Start webserver
|
|
|
|
tracing::info!("Starting Webserver");
|
2022-11-02 10:29:56 -06:00
|
|
|
let mut argv = vec![
|
|
|
|
"text-generation-router".to_string(),
|
|
|
|
"--max-concurrent-requests".to_string(),
|
2023-04-26 06:43:36 -06:00
|
|
|
args.max_concurrent_requests.to_string(),
|
2023-03-09 07:30:54 -07:00
|
|
|
"--max-best-of".to_string(),
|
2023-04-26 06:43:36 -06:00
|
|
|
args.max_best_of.to_string(),
|
2023-03-03 08:01:25 -07:00
|
|
|
"--max-stop-sequences".to_string(),
|
2023-04-26 06:43:36 -06:00
|
|
|
args.max_stop_sequences.to_string(),
|
2022-11-02 10:29:56 -06:00
|
|
|
"--max-input-length".to_string(),
|
2023-04-26 06:43:36 -06:00
|
|
|
args.max_input_length.to_string(),
|
2023-03-03 08:01:25 -07:00
|
|
|
"--max-total-tokens".to_string(),
|
2023-04-26 06:43:36 -06:00
|
|
|
args.max_total_tokens.to_string(),
|
2023-04-24 09:59:00 -06:00
|
|
|
"--waiting-served-ratio".to_string(),
|
2023-04-26 06:43:36 -06:00
|
|
|
args.waiting_served_ratio.to_string(),
|
2022-11-02 10:29:56 -06:00
|
|
|
"--max-waiting-tokens".to_string(),
|
2023-04-26 06:43:36 -06:00
|
|
|
args.max_waiting_tokens.to_string(),
|
2022-11-02 10:29:56 -06:00
|
|
|
"--port".to_string(),
|
2023-04-26 06:43:36 -06:00
|
|
|
args.port.to_string(),
|
2022-11-02 10:29:56 -06:00
|
|
|
"--master-shard-uds-path".to_string(),
|
2023-04-26 06:43:36 -06:00
|
|
|
format!("{}-0", args.shard_uds_path),
|
2022-11-02 10:29:56 -06:00
|
|
|
"--tokenizer-name".to_string(),
|
2023-04-26 06:43:36 -06:00
|
|
|
args.model_id,
|
2022-11-02 10:29:56 -06:00
|
|
|
];
|
|
|
|
|
2023-04-24 09:59:00 -06:00
|
|
|
// Deprecate max_batch_size
|
2023-04-26 06:43:36 -06:00
|
|
|
if let Some(max_batch_size) = args.max_batch_size {
|
2023-04-24 09:59:00 -06:00
|
|
|
argv.push("--max-batch-size".to_string());
|
2023-04-26 06:43:36 -06:00
|
|
|
argv.push(max_batch_size.to_string())
|
|
|
|
} else {
|
|
|
|
argv.push("--max-batch-total-tokens".to_string());
|
|
|
|
argv.push(args.max_batch_total_tokens.to_string())
|
2022-10-18 07:19:03 -06:00
|
|
|
}
|
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Model optional revision
|
|
|
|
if let Some(ref revision) = args.revision {
|
|
|
|
argv.push("--revision".to_string());
|
|
|
|
argv.push(revision.to_string())
|
2022-10-27 06:25:29 -06:00
|
|
|
}
|
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
if args.json_output {
|
|
|
|
argv.push("--json-output".to_string());
|
2023-01-31 10:53:56 -07:00
|
|
|
}
|
|
|
|
|
2023-02-13 05:02:45 -07:00
|
|
|
// OpenTelemetry
|
2023-04-26 06:43:36 -06:00
|
|
|
if let Some(otlp_endpoint) = args.otlp_endpoint {
|
|
|
|
argv.push("--otlp-endpoint".to_string());
|
|
|
|
argv.push(otlp_endpoint);
|
|
|
|
}
|
|
|
|
|
|
|
|
// CORS origins
|
|
|
|
for origin in args.cors_allow_origin.into_iter() {
|
|
|
|
argv.push("--cors-allow-origin".to_string());
|
|
|
|
argv.push(origin);
|
2023-02-13 05:02:45 -07:00
|
|
|
}
|
|
|
|
|
2023-02-16 03:20:23 -07:00
|
|
|
// Copy current process env
|
|
|
|
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
|
|
|
|
2023-04-09 12:00:05 -06:00
|
|
|
// Parse Inference API token
|
|
|
|
if let Ok(api_token) = env::var("HF_API_TOKEN") {
|
|
|
|
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
|
|
|
};
|
2023-02-18 06:04:11 -07:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
let mut webserver = match Popen::create(
|
|
|
|
&argv,
|
2022-10-18 07:19:03 -06:00
|
|
|
PopenConfig {
|
|
|
|
stdout: Redirection::Pipe,
|
|
|
|
stderr: Redirection::Pipe,
|
|
|
|
// Needed for the shutdown procedure
|
|
|
|
setpgid: true,
|
2022-10-22 12:00:15 -06:00
|
|
|
env: Some(env),
|
2022-10-18 07:19:03 -06:00
|
|
|
..Default::default()
|
|
|
|
},
|
|
|
|
) {
|
|
|
|
Ok(p) => p,
|
|
|
|
Err(err) => {
|
2023-04-26 06:43:36 -06:00
|
|
|
tracing::error!("Failed to start webserver: {}", err);
|
|
|
|
if let PopenError::IoError(err) = err {
|
2022-10-18 07:19:03 -06:00
|
|
|
if err.kind() == io::ErrorKind::NotFound {
|
2023-04-26 06:43:36 -06:00
|
|
|
tracing::error!("text-generation-router not found in PATH");
|
|
|
|
tracing::error!("Please install it with `make install-router`")
|
2022-10-18 07:19:03 -06:00
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
} else {
|
|
|
|
tracing::error!("{}", err);
|
2022-10-18 07:19:03 -06:00
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
|
2023-04-26 12:23:54 -06:00
|
|
|
shutdown_shards(shutdown, shutdown_receiver);
|
2023-04-26 06:43:36 -06:00
|
|
|
return Err(LauncherError::WebserverCannotStart);
|
2022-10-18 07:19:03 -06:00
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Redirect STDOUT and STDERR to the console
|
|
|
|
let webserver_stdout = webserver.stdout.take().unwrap();
|
|
|
|
let webserver_stderr = webserver.stderr.take().unwrap();
|
2023-01-05 04:01:23 -07:00
|
|
|
|
|
|
|
thread::spawn(move || {
|
2023-04-26 06:43:36 -06:00
|
|
|
let stdout = BufReader::new(webserver_stdout);
|
|
|
|
let stderr = BufReader::new(webserver_stderr);
|
2023-01-05 04:01:23 -07:00
|
|
|
for line in stdout.lines() {
|
2023-04-26 06:43:36 -06:00
|
|
|
println!("{}", line.unwrap());
|
2023-01-05 04:01:23 -07:00
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
for line in stderr.lines() {
|
|
|
|
println!("{}", line.unwrap());
|
2022-10-18 07:19:03 -06:00
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
});
|
|
|
|
Ok(webserver)
|
|
|
|
}
|
2022-10-18 07:19:03 -06:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
fn main() -> Result<(), LauncherError> {
|
|
|
|
// Pattern match configuration
|
|
|
|
let args = Args::parse();
|
2022-10-18 07:19:03 -06:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
if args.json_output {
|
|
|
|
tracing_subscriber::fmt().json().init();
|
|
|
|
} else {
|
|
|
|
tracing_subscriber::fmt().compact().init();
|
2022-10-18 07:19:03 -06:00
|
|
|
}
|
|
|
|
|
2023-05-02 07:43:19 -06:00
|
|
|
if args.env {
|
|
|
|
let env_runtime = env_runtime::Env::new();
|
|
|
|
tracing::info!("{}", env_runtime);
|
|
|
|
}
|
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
tracing::info!("{:?}", args);
|
|
|
|
|
|
|
|
let num_shard = find_num_shards(args.sharded, args.num_shard);
|
|
|
|
if num_shard > 1 {
|
|
|
|
tracing::info!("Sharding model on {num_shard} processes");
|
2022-10-18 07:19:03 -06:00
|
|
|
}
|
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Signal handler
|
|
|
|
let running = Arc::new(AtomicBool::new(true));
|
|
|
|
let r = running.clone();
|
|
|
|
ctrlc::set_handler(move || {
|
|
|
|
r.store(false, Ordering::SeqCst);
|
|
|
|
})
|
|
|
|
.expect("Error setting Ctrl-C handler");
|
2023-03-08 05:53:41 -07:00
|
|
|
|
2023-05-03 03:36:24 -06:00
|
|
|
// auto_convert is only needed for sharded models as we do not require safetensors in
|
|
|
|
// single shard mode
|
|
|
|
let auto_convert = num_shard > 1;
|
|
|
|
// Download and convert model weights
|
|
|
|
download_convert_model(&args, auto_convert, running.clone())?;
|
2023-04-16 16:26:47 -06:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Shared shutdown bool
|
|
|
|
let shutdown = Arc::new(Mutex::new(false));
|
|
|
|
// Shared shutdown channel
|
|
|
|
// When shutting down, the main thread will wait for all senders to be dropped
|
|
|
|
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
|
2023-04-16 16:26:47 -06:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// Shared channel to track shard status
|
|
|
|
let (status_sender, status_receiver) = mpsc::channel();
|
2023-04-16 16:26:47 -06:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
spawn_shards(
|
|
|
|
num_shard,
|
|
|
|
&args,
|
|
|
|
shutdown.clone(),
|
|
|
|
&shutdown_receiver,
|
|
|
|
shutdown_sender,
|
|
|
|
&status_receiver,
|
|
|
|
status_sender,
|
|
|
|
running.clone(),
|
|
|
|
)?;
|
2023-04-16 16:26:47 -06:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
// We might have received a termination signal
|
|
|
|
if !running.load(Ordering::SeqCst) {
|
|
|
|
shutdown_shards(shutdown, &shutdown_receiver);
|
|
|
|
return Ok(());
|
|
|
|
}
|
2023-04-16 16:26:47 -06:00
|
|
|
|
2023-04-26 06:43:36 -06:00
|
|
|
let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;
|
|
|
|
|
|
|
|
// Default exit code
|
|
|
|
let mut exit_code = Ok(());
|
|
|
|
|
|
|
|
while running.load(Ordering::SeqCst) {
|
|
|
|
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
|
|
|
|
tracing::error!("Shard {rank} failed:\n{err}");
|
|
|
|
exit_code = Err(LauncherError::ShardFailed);
|
|
|
|
break;
|
|
|
|
};
|
|
|
|
|
|
|
|
match webserver.poll() {
|
|
|
|
Some(_) => {
|
|
|
|
tracing::error!("Webserver Crashed");
|
|
|
|
shutdown_shards(shutdown, &shutdown_receiver);
|
|
|
|
return Err(LauncherError::WebserverFailed);
|
|
|
|
}
|
|
|
|
None => {
|
|
|
|
sleep(Duration::from_millis(100));
|
|
|
|
}
|
|
|
|
};
|
2023-04-16 16:26:47 -06:00
|
|
|
}
|
2023-04-26 06:43:36 -06:00
|
|
|
|
|
|
|
// Graceful termination
|
|
|
|
webserver.terminate().unwrap();
|
|
|
|
tracing::info!("Waiting for webserver to gracefully shutdown");
|
|
|
|
webserver.wait_timeout(Duration::from_secs(90)).unwrap();
|
|
|
|
tracing::info!("Webserver terminated");
|
|
|
|
shutdown_shards(shutdown, &shutdown_receiver);
|
|
|
|
|
|
|
|
exit_code
|
2023-04-16 16:26:47 -06:00
|
|
|
}
|