use clap::{Parser, ValueEnum};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Read};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};

mod env_runtime;
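
// `env_runtime::Env` collects runtime/system information; it is printed when
// the `--env` flag below is set.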

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
    Bitsandbytes,
    Gptq,
}

impl std::fmt::Display for Quantization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Kept in sync with the values expected by `server`.
        match self {
            Quantization::Bitsandbytes => {
                write!(f, "bitsandbytes")
            }
            Quantization::Gptq => {
                write!(f, "gptq")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    Float16,
    BFloat16,
}

impl std::fmt::Display for Dtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Kept in sync with the values expected by `server`.
        match self {
            Dtype::Float16 => {
                write!(f, "float16")
            }
            Dtype::BFloat16 => {
                write!(f, "bfloat16")
            }
        }
    }
}

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// The name of the model to load.
    /// Can be a MODEL_ID as listed on <https://hf.co/models> like
    /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`.
    /// Or it can be a local directory containing the necessary files
    /// as saved by the `save_pretrained(...)` methods of transformers.
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_id: String,

    /// The actual revision of the model if you're referring to a model
    /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
    #[clap(long, env)]
    revision: Option<String>,

    /// The number of tokenizer workers used for payload validation and truncation inside the
    /// router.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,

    /// Whether to shard the model across multiple GPUs.
    /// By default text-generation-inference will use all available GPUs to run
    /// the model. Setting it to `false` deactivates `num_shard`.
    #[clap(long, env)]
    sharded: Option<bool>,

    /// The number of shards to use if you don't want to use all GPUs on a given machine.
    /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num-shard 2`
    /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num-shard 2` to
    /// launch 2 copies with 2 shards each on a given machine with 4 GPUs for instance.
    #[clap(long, env)]
    num_shard: Option<usize>,

    /// Whether you want the model to be quantized. This will use `bitsandbytes` for
    /// quantization on the fly, or `gptq`.
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,

    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
    /// encouraged when loading a model with custom code to ensure no malicious code has been
    /// contributed in a newer revision.
    #[clap(long, env, value_enum)]
    trust_remote_code: bool,

    /// The maximum amount of concurrent requests for this particular deployment.
    /// Having a low limit will refuse client requests instead of having them
    /// wait for too long and is usually good to handle backpressure correctly.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,

    /// This is the maximum allowed value for clients to set `best_of`.
    /// Best of makes `n` generations at the same time, and returns the best
    /// in terms of overall log probability over the entire generated sequence.
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,

    /// This is the maximum allowed value for clients to set `stop_sequences`.
    /// Stop sequences are used to allow the model to stop on more than just
    /// the EOS token, and enable more complex "prompting" where users can preprompt
    /// the model in a specific way and define their "own" stop token aligned with
    /// their prompt.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,

    /// This is the maximum allowed input length (expressed in number of tokens)
    /// for users. The larger this value, the longer prompts users can send, which
    /// can impact the overall memory required to handle the load.
    /// Please note that some models have a finite range of sequence lengths they can handle.
    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,

    /// This is the most important value to set as it defines the "memory budget"
    /// of running client requests.
    /// Clients will send input sequences and ask to generate `max_new_tokens`
    /// on top. With a value of `1512` users can send either a prompt of
    /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for
    /// `1511` max_new_tokens.
    /// The larger this value, the larger amount each request will take in your RAM
    /// and the less effective batching can be.
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,

    /// This represents the ratio of waiting queries vs running queries where
    /// you want to start considering pausing the running queries to include the waiting
    /// ones into the same batch.
    /// `waiting_served_ratio=1.2` means that when 12 queries are waiting and there are
    /// only 10 queries left in the current batch, we check if we can fit those 12
    /// waiting queries into the batching strategy, and if yes, then batching happens,
    /// delaying the 10 running queries by a `prefill` run.
    ///
    /// This setting is only applied if there is room in the batch
    /// as defined by `max_batch_total_tokens`.
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,

    /// Limits the number of tokens for the prefill operation.
    /// Since this operation takes the most memory and is compute bound, it is worth
    /// limiting the number of requests that can be sent.
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,

    /// **IMPORTANT** This is one critical control to allow maximum usage
    /// of the available hardware.
    ///
    /// This represents the total amount of potential tokens within a batch.
    /// When using padding (not recommended) this would be equivalent to
    /// `batch_size` * `max_total_tokens`.
    ///
    /// However in the non-padded (flash attention) version this can be much finer.
    ///
    /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`
    /// or a single query of `1000` tokens.
    ///
    /// Overall this number should be the largest possible amount that fits the
    /// remaining memory (after the model is loaded). Since the actual memory overhead
    /// depends on other parameters like whether you're using quantization, flash attention
    /// or the model implementation, text-generation-inference cannot infer this number
    /// automatically.
    #[clap(default_value = "16000", long, env)]
    max_batch_total_tokens: u32,

    /// This setting defines how many tokens can be passed before forcing the waiting
    /// queries to be put on the batch (if the size of the batch allows for it).
    /// New queries require 1 `prefill` forward, which is different from `decode`,
    /// so you need to pause the running batch in order to run `prefill`
    /// to create the correct values for the waiting queries to be able to join the batch.
    ///
    /// With a value too small, queries will always "steal" the compute to run `prefill`,
    /// and running queries will be delayed by a lot.
    ///
    /// With a value too big, waiting queries could wait for a very long time
    /// before being allowed a slot in the running batch. If your server is busy
    /// that means that requests that could run in ~2s on an empty server could
    /// end up running in ~20s because the query had to wait for 18s.
    ///
    /// This number is expressed in number of tokens to make it a bit more
    /// "model" agnostic, but what should really matter is the overall latency
    /// for end users.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,

    /// The IP address to listen on.
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,

    /// The port to listen on.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,

    /// The name of the socket for gRPC communication between the webserver
    /// and the shards.
    #[clap(default_value = "/tmp/text-generation-server", long, env)]
    shard_uds_path: String,

    /// The address the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "localhost", long, env)]
    master_addr: String,

    /// The port the master shard will listen on. (setting used by torch distributed)
    #[clap(default_value = "29500", long, env)]
    master_port: usize,

    /// The location of the huggingface hub cache.
    /// Used to override the location if you want to provide a mounted disk for instance.
    #[clap(long, env)]
    huggingface_hub_cache: Option<String>,

    /// The location to load the model weights from, overriding the hub cache.
    /// Useful when running inside a HuggingFace Inference Endpoint for instance.
    #[clap(long, env)]
    weights_cache_override: Option<String>,

    /// For some models (like bloom), text-generation-inference implemented custom
    /// cuda kernels to speed up inference. Those kernels were only tested on A100.
    /// Use this flag to disable them if you're running on different hardware and
    /// encounter issues.
    #[clap(long, env)]
    disable_custom_kernels: bool,

    /// Outputs the logs in JSON format (useful for telemetry).
    #[clap(long, env)]
    json_output: bool,

    /// The OTLP endpoint to send OpenTelemetry traces to.
    #[clap(long, env)]
    otlp_endpoint: Option<String>,

    /// Origins allowed through CORS by the webserver.
    #[clap(long, env)]
    cors_allow_origin: Vec<String>,

    /// Gamma parameter for text watermarking, passed to the shards as `WATERMARK_GAMMA`.
    #[clap(long, env)]
    watermark_gamma: Option<f32>,

    /// Delta parameter for text watermarking, passed to the shards as `WATERMARK_DELTA`.
    #[clap(long, env)]
    watermark_delta: Option<f32>,

    /// Enable ngrok tunneling.
    #[clap(long, env)]
    ngrok: bool,

    /// ngrok authentication token.
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,

    /// ngrok domain name where the axum webserver will be available.
    #[clap(long, env)]
    ngrok_domain: Option<String>,

    /// ngrok basic auth username.
    #[clap(long, env)]
    ngrok_username: Option<String>,

    /// ngrok basic auth password.
    #[clap(long, env)]
    ngrok_password: Option<String>,

    /// Display a lot of information about your runtime environment.
    #[clap(long, short, action)]
    env: bool,
}
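
// A minimal launch sketch (hypothetical invocation; clap renames the fields
// above to kebab-case flags):
//
//   CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher \
//       --model-id bigscience/bloom-560m \
//       --num-shard 2 \
//       --port 3000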

#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed((usize, Option<String>)),
}

#[allow(clippy::too_many_arguments)]
fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    dtype: Option<Dtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    otlp_endpoint: Option<String>,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
    _shutdown_sender: mpsc::Sender<()>,
) {
    // Get UDS path
    let uds_string = format!("{uds_path}-{rank}");
    let uds = Path::new(&uds_string);
    // Clean previous runs
    if uds.exists() {
        fs::remove_file(uds).unwrap();
    }
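
    // With the default `--shard-uds-path`, rank 0 ends up on
    // "/tmp/text-generation-server-0", rank 1 on "/tmp/text-generation-server-1", etc.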

    // Process args
    let mut shard_argv = vec![
        "serve".to_string(),
        model_id,
        "--uds-path".to_string(),
        uds_path,
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Activate trust remote code
    if trust_remote_code {
        shard_argv.push("--trust-remote-code".to_string());
    }

    // Activate tensor parallelism
    if world_size > 1 {
        shard_argv.push("--sharded".to_string());
    }

    if let Some(quantize) = quantize {
        shard_argv.push("--quantize".to_string());
        shard_argv.push(quantize.to_string())
    }

    if let Some(dtype) = dtype {
        shard_argv.push("--dtype".to_string());
        shard_argv.push(dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_argv.push("--revision".to_string());
        shard_argv.push(revision)
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_argv.push("--otlp-endpoint".to_string());
        shard_argv.push(otlp_endpoint);
    }

    // Copy current process env
    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Use cuda allocator. It leads to less memory fragmentation
    env.push((
        "PYTORCH_CUDA_ALLOC_CONF".into(),
        "backend:cudaMallocAsync".into(),
    ));

    // Torch Distributed Env vars
    env.push(("RANK".into(), rank.to_string().into()));
    env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    env.push(("MASTER_ADDR".into(), master_addr.into()));
    env.push(("MASTER_PORT".into(), master_port.to_string().into()));
    // Fail fast on NCCL errors instead of hanging collectives
    env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));

    // Safetensors load fast
    env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    env.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // If huggingface_hub_cache is some, pass it to the shard
    // Useful when running inside a docker container
    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // If weights_cache_override is some, pass it to the shard
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = weights_cache_override {
        env.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
    }

    // Watermark Gamma
    if let Some(watermark_gamma) = watermark_gamma {
        env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
    }

    // Watermark Delta
    if let Some(watermark_delta) = watermark_delta {
        env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
    }

    // Start process
    tracing::info!("Starting shard {rank}");
    let mut p = match Command::new("text-generation-server")
        .args(shard_argv)
        .envs(env)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        // Run the shard in its own process group so terminal signals (e.g. Ctrl-C)
        // only reach the launcher, which then shuts the shards down explicitly
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            } else {
                tracing::error!("{}", err);
            }

            status_sender
                .send(ShardStatus::Failed((rank, Some(err.to_string()))))
                .unwrap();
            return;
        }
    };

    // Redirect STDOUT to the console
    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
    let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());

    thread::spawn(move || {
        // Enter shard-manager tracing span
        let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
        for line in shard_stdout_reader.lines() {
            // Parse loguru logs
            if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
                log.trace();
            }
        }
    });

    let mut ready = false;
    let start_time = Instant::now();
    let mut wait_time = Instant::now();
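
    // Monitor loop: poll every 100ms for (1) shard exit, (2) a launcher shutdown
    // request, and (3) readiness, which the shard signals by creating its UDS file.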
    loop {
        // Process exited
        if let Some(exit_status) = p.try_wait().unwrap() {
            // We read stderr in another thread as it seems that `read_to_string` can block
            // indefinitely in some cases
            let (err_sender, err_receiver) = mpsc::channel();
            thread::spawn(move || {
                let mut err = String::new();
                shard_stderr_reader.read_to_string(&mut err).unwrap();
                err_sender.send(err).unwrap_or(());
            });

            let err = err_receiver
                .recv_timeout(Duration::from_millis(100))
                .map_err(|err| {
                    tracing::error!("Unable to read shard {rank} error from stderr");
                    err
                })
                .ok();

            if let Some(signal) = exit_status.signal() {
                tracing::error!("Shard process was signaled to shutdown with signal {signal}");
            }

            status_sender
                .send(ShardStatus::Failed((rank, err)))
                .unwrap();
            return;
        }

        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            p.kill().unwrap();
            let _ = p.wait();
            tracing::info!("Shard {rank} terminated");
            return;
        }

        // Shard is ready
        if uds.exists() && !ready {
            tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
            tracing::info!("Waiting for shard {rank} to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
    }
}

fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
    tracing::info!("Shutting down shards");
    // Update shutdown value to true
    // This will be picked up by the shard manager
    shutdown.store(true, Ordering::SeqCst);

    // Wait for shards to shutdown
    // This will block until every `shutdown_sender` clone has been dropped
    let _ = shutdown_receiver.recv();
}
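
// Count visible GPUs from CUDA_VISIBLE_DEVICES, falling back to
// NVIDIA_VISIBLE_DEVICES; e.g. "0,1,2,3" yields Some(4), and None means
// neither variable is set.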
fn num_cuda_devices() -> Option<usize> {
    let devices = match env::var("CUDA_VISIBLE_DEVICES") {
        Ok(devices) => devices,
        Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
    };
    let n_devices = devices.split(',').count();
    Some(n_devices)
}

#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
    Trace,
    Debug,
    Info,
    Success,
    Warning,
    Error,
    Critical,
}

#[derive(Deserialize)]
struct PythonLogLevel {
    name: PythonLogLevelEnum,
}

#[derive(Deserialize)]
struct PythonLogRecord {
    level: PythonLogLevel,
}

#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    record: PythonLogRecord,
}

impl PythonLogMessage {
    fn trace(&self) {
        match self.record.level.name {
            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
            PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
            // `tracing` has no SUCCESS level; map it to info
            PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
            PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
            // `tracing` has no CRITICAL level; map it to error
            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
        }
    }
}
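
// A shard log line that deserializes into `PythonLogMessage` looks like
// (hypothetical example):
//   {"text": "Server started", "record": {"level": {"name": "INFO"}}}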

fn find_num_shards(
    sharded: Option<bool>,
    num_shard: Option<usize>,
) -> Result<usize, LauncherError> {
    // get the number of shards given `sharded` and `num_shard`
    let num_shard = match (sharded, num_shard) {
        (Some(true), None) => {
            // try to default to the number of available GPUs
            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES");
            let n_devices = num_cuda_devices()
                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
            if n_devices <= 1 {
                return Err(LauncherError::NotEnoughCUDADevices(format!(
                    "`sharded` is true but only found {n_devices} CUDA devices"
                )));
            }
            n_devices
        }
        (Some(true), Some(num_shard)) => {
            // we can't have only one shard while sharded
            if num_shard <= 1 {
                return Err(LauncherError::ArgumentValidation(
                    "`sharded` is true but `num_shard` <= 1".to_string(),
                ));
            }
            num_shard
        }
        (Some(false), Some(num_shard)) => num_shard,
        (Some(false), None) => 1,
        (None, None) => num_cuda_devices().unwrap_or(1),
        (None, Some(num_shard)) => num_shard,
    };
    if num_shard < 1 {
        return Err(LauncherError::ArgumentValidation(
            "`num_shard` cannot be < 1".to_string(),
        ));
    }
    Ok(num_shard)
}

#[derive(Debug)]
enum LauncherError {
    ArgumentValidation(String),
    NotEnoughCUDADevices(String),
    DownloadError,
    ShardCannotStart,
    ShardDisconnected,
    ShardFailed,
    WebserverFailed,
    WebserverCannotStart,
}

fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
    let mut download_argv = vec![
        "download-weights".to_string(),
        args.model_id.to_string(),
        "--extension".to_string(),
        ".safetensors".to_string(),
        "--logger-level".to_string(),
        "INFO".to_string(),
        "--json-output".to_string(),
    ];

    // Model optional revision
    if let Some(revision) = &args.revision {
        download_argv.push("--revision".to_string());
        download_argv.push(revision.to_string())
    }

    // Copy current process env
    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

    // If huggingface_hub_cache is set, pass it to the download process
    // Useful when running inside a docker container
    if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
    };

    // Enable hf transfer for insane download speeds
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    env.push((
        "HF_HUB_ENABLE_HF_TRANSFER".into(),
        enable_hf_transfer.into(),
    ));

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    // If args.weights_cache_override is some, pass it to the download process
    // Useful when running inside a HuggingFace Inference Endpoint
    if let Some(weights_cache_override) = &args.weights_cache_override {
        env.push((
            "WEIGHTS_CACHE_OVERRIDE".into(),
            weights_cache_override.into(),
        ));
    };

    // Start process
    tracing::info!("Starting download process.");
    let mut download_process = match Command::new("text-generation-server")
        .args(download_argv)
        .envs(env)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-server not found in PATH");
                tracing::error!("Please install it with `make install-server`")
            }

            return Err(LauncherError::DownloadError);
        }
    };

    // Redirect STDOUT to the console
    let download_stdout = download_process.stdout.take().unwrap();
    thread::spawn(move || {
        // Enter download tracing span
        let stdout = BufReader::new(download_stdout);
        let _span = tracing::span!(tracing::Level::INFO, "download").entered();
        for line in stdout.lines() {
            // Parse loguru logs
            if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
                log.trace();
            }
        }
    });

    loop {
        if let Some(status) = download_process.try_wait().unwrap() {
            if status.success() {
                tracing::info!("Successfully downloaded weights.");
                break;
            }

            let mut err = String::new();
            download_process
                .stderr
                .take()
                .unwrap()
                .read_to_string(&mut err)
                .unwrap();
            if let Some(signal) = status.signal() {
                tracing::error!(
                    "Download process was signaled to shutdown with signal {signal}: {err}"
                );
            } else {
                tracing::error!("Download encountered an error: {err}");
            }

            return Err(LauncherError::DownloadError);
        }

        if !running.load(Ordering::SeqCst) {
            signal::kill(Pid::from_raw(download_process.id() as i32), Signal::SIGTERM).unwrap();
            tracing::info!("Waiting for download process to gracefully shutdown");
            download_process.wait().unwrap();
            tracing::info!("Download process terminated");
            return Ok(());
        }
        sleep(Duration::from_millis(100));
    }
    Ok(())
}

#[allow(clippy::too_many_arguments)]
fn spawn_shards(
    num_shard: usize,
    args: &Args,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
    status_receiver: &mpsc::Receiver<ShardStatus>,
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
    if args.trust_remote_code {
        tracing::warn!(
            "`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
            args.model_id
        );
        if args.revision.is_none() {
            tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
        }
    }

    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
        let huggingface_hub_cache = args.huggingface_hub_cache.clone();
        let weights_cache_override = args.weights_cache_override.clone();
        let status_sender = status_sender.clone();
        let shutdown = shutdown.clone();
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let quantize = args.quantize;
        let dtype = args.dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
        thread::spawn(move || {
            shard_manager(
                model_id,
                revision,
                quantize,
                dtype,
                trust_remote_code,
                uds_path,
                rank,
                num_shard,
                master_addr,
                master_port,
                huggingface_hub_cache,
                weights_cache_override,
                disable_custom_kernels,
                watermark_gamma,
                watermark_delta,
                otlp_endpoint,
                status_sender,
                shutdown,
                shutdown_sender,
            )
        });
    }
    // Drop our copy of the sender so `shutdown_receiver.recv()` only returns
    // once every shard thread has dropped its `shutdown_sender` clone
    drop(shutdown_sender);

    // Wait for shard to start
    let mut shard_ready = 0;
    while running.load(Ordering::SeqCst) {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                shard_ready += 1;
                if shard_ready == num_shard {
                    break;
                }
            }
            Err(TryRecvError::Empty) => {
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed((rank, err))) => {
                tracing::error!("Shard {rank} failed to start");
                if let Some(err) = err {
                    tracing::error!("{err}");
                }
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardCannotStart);
            }
            Err(TryRecvError::Disconnected) => {
                tracing::error!("Shard status channel disconnected");
                shutdown_shards(shutdown, shutdown_receiver);
                return Err(LauncherError::ShardDisconnected);
            }
        }
    }
    Ok(())
}

fn spawn_webserver(
    args: Args,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> {
    // All shards started
    // Start webserver
    tracing::info!("Starting Webserver");
    let mut argv = vec![
        "--max-concurrent-requests".to_string(),
        args.max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        args.max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        args.max_stop_sequences.to_string(),
        "--max-input-length".to_string(),
        args.max_input_length.to_string(),
        "--max-total-tokens".to_string(),
        args.max_total_tokens.to_string(),
        "--max-batch-prefill-tokens".to_string(),
        args.max_batch_prefill_tokens.to_string(),
        "--max-batch-total-tokens".to_string(),
        args.max_batch_total_tokens.to_string(),
        "--waiting-served-ratio".to_string(),
        args.waiting_served_ratio.to_string(),
        "--max-waiting-tokens".to_string(),
        args.max_waiting_tokens.to_string(),
        "--validation-workers".to_string(),
        args.validation_workers.to_string(),
        "--hostname".to_string(),
        args.hostname.to_string(),
        "--port".to_string(),
        args.port.to_string(),
        "--master-shard-uds-path".to_string(),
        format!("{}-0", args.shard_uds_path),
        "--tokenizer-name".to_string(),
        args.model_id,
    ];
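
    // The webserver talks gRPC to the rank-0 shard over its UDS and loads the
    // tokenizer itself, hence the model id doubling as `--tokenizer-name`.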

    // Model optional revision
    if let Some(ref revision) = args.revision {
        argv.push("--revision".to_string());
        argv.push(revision.to_string())
    }

    if args.json_output {
        argv.push("--json-output".to_string());
    }

    // OpenTelemetry
    if let Some(otlp_endpoint) = args.otlp_endpoint {
        argv.push("--otlp-endpoint".to_string());
        argv.push(otlp_endpoint);
    }

    // CORS origins
    for origin in args.cors_allow_origin.into_iter() {
        argv.push("--cors-allow-origin".to_string());
        argv.push(origin);
    }

    // Ngrok
    if args.ngrok {
        let authtoken = args.ngrok_authtoken.ok_or_else(|| {
            tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling");
            LauncherError::WebserverCannotStart
        })?;

        argv.push("--ngrok".to_string());
        argv.push("--ngrok-authtoken".to_string());
        argv.push(authtoken);

        if let Some(domain) = args.ngrok_domain {
            argv.push("--ngrok-domain".to_string());
            argv.push(domain);
        }

        if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) {
            argv.push("--ngrok-username".to_string());
            argv.push(username);
            argv.push("--ngrok-password".to_string());
            argv.push(password);
        }
    }
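
    // Note: basic auth is only forwarded when both `--ngrok-username` and
    // `--ngrok-password` are provided; a lone username or password is ignored.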

    // Copy current process env
    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

    // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
    };

    let mut webserver = match Command::new("text-generation-router")
        .args(argv)
        .envs(env)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .process_group(0)
        .spawn()
    {
        Ok(p) => p,
        Err(err) => {
            tracing::error!("Failed to start webserver: {}", err);
            if err.kind() == io::ErrorKind::NotFound {
                tracing::error!("text-generation-router not found in PATH");
                tracing::error!("Please install it with `make install-router`")
            } else {
                tracing::error!("{}", err);
            }

            shutdown_shards(shutdown, shutdown_receiver);
            return Err(LauncherError::WebserverCannotStart);
        }
    };

    // Redirect STDOUT and STDERR to the console
    let webserver_stdout = webserver.stdout.take().unwrap();
    let webserver_stderr = webserver.stderr.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(webserver_stdout);
        let stderr = BufReader::new(webserver_stderr);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
        for line in stderr.lines() {
            println!("{}", line.unwrap());
        }
    });

    Ok(webserver)
}

fn main() -> Result<(), LauncherError> {
    // Pattern match configuration
    let args = Args::parse();

    if args.json_output {
        tracing_subscriber::fmt().json().init();
    } else {
        tracing_subscriber::fmt().compact().init();
    }

    if args.env {
        let env_runtime = env_runtime::Env::new();
        tracing::info!("{}", env_runtime);
    }

    tracing::info!("{:?}", args);

    // Validate args
    if args.max_input_length >= args.max_total_tokens {
        return Err(LauncherError::ArgumentValidation(
            "`max_input_length` must be < `max_total_tokens`".to_string(),
        ));
    }
    if args.max_input_length as u32 > args.max_batch_prefill_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
            args.max_batch_prefill_tokens, args.max_input_length
        )));
    }
    if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
            args.max_batch_prefill_tokens, args.max_batch_total_tokens
        )));
    }
    if args.max_total_tokens as u32 > args.max_batch_total_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
            args.max_total_tokens, args.max_batch_total_tokens
        )));
    }
    if args.validation_workers == 0 {
        return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
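
    // With the defaults above (max_input_length=1024, max_total_tokens=2048,
    // max_batch_prefill_tokens=4096, max_batch_total_tokens=16000,
    // validation_workers=2) every check passes.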

    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
    if num_shard > 1 {
        tracing::info!("Sharding model on {num_shard} processes");
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");
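
    // From here on, `running` flips to false on Ctrl-C and every polling loop
    // below checks it to exit gracefully.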

    // Download and convert model weights
    download_convert_model(&args, running.clone())?;

    // Shared shutdown bool
    let shutdown = Arc::new(AtomicBool::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

    spawn_shards(
        num_shard,
        &args,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

    let mut webserver =
        spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
            shutdown_shards(shutdown.clone(), &shutdown_receiver);
            err
        })?;

    // Default exit code
    let mut exit_code = Ok(());

    while running.load(Ordering::SeqCst) {
        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} crashed");
            if let Some(err) = err {
                tracing::error!("{err}");
            }
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        match webserver.try_wait().unwrap() {
            Some(_) => {
                tracing::error!("Webserver Crashed");
                shutdown_shards(shutdown, &shutdown_receiver);
                return Err(LauncherError::WebserverFailed);
            }
            None => {
                sleep(Duration::from_millis(100));
            }
        };
    }

    // Graceful termination
    signal::kill(Pid::from_raw(webserver.id() as i32), Signal::SIGTERM).unwrap();
    tracing::info!("Waiting for webserver to gracefully shutdown");
    webserver.wait().unwrap();
    tracing::info!("Webserver terminated");
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}