feat: bundle launcher and refactor cli wrappers

parent af2b2e8388
commit 30f4deba77
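One pattern repeats throughout this diff: each binary's clap parsing moves out of `main` and into a library-level `internal_main_args()`, so the native CLI binaries and the bundled Python (pyo3) wrappers can share a single entry point. A minimal sketch of the resulting shape (not the actual TGI code; `Args` and `run` here are stand-ins):

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    #[clap(long, env, default_value = "128")]
    max_concurrent_requests: usize,
}

// Library-level entry point: parses argv itself so it can be called
// from a plain binary or from a Python wrapper whose argv starts with
// the interpreter path.
pub fn internal_main_args() -> Result<(), Box<dyn std::error::Error>> {
    let args: Vec<String> = std::env::args()
        .skip_while(|a| a.contains("python"))
        .collect();
    let args = Args::parse_from(args);
    run(args)
}

fn run(args: Args) -> Result<(), Box<dyn std::error::Error>> {
    println!("{args:?}");
    Ok(())
}

// The binary shrinks to a one-line delegate.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    internal_main_args()
}

The binary keeps its CLI behavior unchanged, while the Python wheel can call the same library function after trimming the interpreter from argv.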
@@ -1260,8 +1260,64 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
    Ok(exit_status)
}

pub fn internal_main_args() -> Result<(), LauncherError> {
    let args: Vec<String> = std::env::args()
        // skips the first arg if it's python
        .skip_while(|a| a.contains("python"))
        .collect();
    let args = Args::parse_from(args);

    internal_main(
        args.model_id,
        args.revision,
        args.validation_workers,
        args.sharded,
        args.num_shard,
        args.quantize,
        args.speculate,
        args.dtype,
        args.trust_remote_code,
        args.max_concurrent_requests,
        args.max_best_of,
        args.max_stop_sequences,
        args.max_top_n_tokens,
        args.max_input_tokens,
        args.max_input_length,
        args.max_total_tokens,
        args.waiting_served_ratio,
        args.max_batch_prefill_tokens,
        args.max_batch_total_tokens,
        args.max_waiting_tokens,
        args.max_batch_size,
        args.cuda_graphs,
        args.hostname,
        args.port,
        args.shard_uds_path,
        args.master_addr,
        args.master_port,
        args.huggingface_hub_cache,
        args.weights_cache_override,
        args.disable_custom_kernels,
        args.cuda_memory_fraction,
        args.rope_scaling,
        args.rope_factor,
        args.json_output,
        args.otlp_endpoint,
        args.cors_allow_origin,
        args.watermark_gamma,
        args.watermark_delta,
        args.ngrok,
        args.ngrok_authtoken,
        args.ngrok_edge,
        args.tokenizer_config_path,
        args.disable_grammar_support,
        args.env,
        args.max_client_batch_size,
    )
}
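A note on the argv trimming above: `skip_while` drops only the leading run of elements matching the predicate, so an interpreter path such as `/usr/bin/python3` at the front is removed and the next element becomes the program name that `Args::parse_from` expects; a later argument containing "python" is kept once the run is broken. A small self-contained check of that behavior (the argv values are made up):

fn main() {
    let argv = [
        "/usr/bin/python3",         // hypothetical interpreter path
        "text-generation-launcher", // becomes the program name clap sees
        "--model-id",
        "my-org/python-model",      // matches the predicate but is NOT skipped
    ];
    let trimmed: Vec<&str> = argv
        .iter()
        .copied()
        .skip_while(|a| a.contains("python"))
        .collect();
    assert_eq!(
        trimmed,
        ["text-generation-launcher", "--model-id", "my-org/python-model"]
    );
}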

#[allow(clippy::too_many_arguments)]
pub fn launcher_main(
pub fn internal_main(
    model_id: String,
    revision: Option<String>,
    validation_workers: usize,

@@ -1639,388 +1695,3 @@ pub fn launcher_main(
    exit_code
}

#[allow(clippy::too_many_arguments)]
pub fn launcher_main_without_server(
    model_id: String,
    revision: Option<String>,
    validation_workers: usize,
    sharded: Option<bool>,
    num_shard: Option<usize>,
    quantize: Option<Quantization>,
    speculate: Option<usize>,
    dtype: Option<Dtype>,
    trust_remote_code: bool,
    max_concurrent_requests: usize,
    max_best_of: usize,
    max_stop_sequences: usize,
    max_top_n_tokens: u32,
    max_input_tokens: Option<usize>,
    max_input_length: Option<usize>,
    max_total_tokens: Option<usize>,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: Option<u32>,
    max_batch_total_tokens: Option<u32>,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
    cuda_graphs: Option<Vec<usize>>,
    hostname: String,
    port: u16,
    shard_uds_path: String,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    cuda_memory_fraction: f32,
    rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,
    json_output: bool,
    otlp_endpoint: Option<String>,
    cors_allow_origin: Vec<String>,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    ngrok: bool,
    ngrok_authtoken: Option<String>,
    ngrok_edge: Option<String>,
    tokenizer_config_path: Option<String>,
    disable_grammar_support: bool,
    env: bool,
    max_client_batch_size: usize,
    webserver_callback: Box<dyn FnOnce() -> Result<(), LauncherError>>,
) -> Result<(), LauncherError> {
    let args = Args {
        model_id,
        revision,
        validation_workers,
        sharded,
        num_shard,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_input_length,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        cuda_graphs,
        hostname,
        port,
        shard_uds_path,
        master_addr,
        master_port,
        huggingface_hub_cache,
        weights_cache_override,
        disable_custom_kernels,
        cuda_memory_fraction,
        rope_scaling,
        rope_factor,
        json_output,
        otlp_endpoint,
        cors_allow_origin,
        watermark_gamma,
        watermark_delta,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        tokenizer_config_path,
        disable_grammar_support,
        env,
        max_client_batch_size,
    };

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    if args.json_output {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .json()
            .init();
    } else {
        tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .compact()
            .init();
    }

    if args.env {
        let env_runtime = env_runtime::Env::new();
        tracing::info!("{}", env_runtime);
    }

    tracing::info!("{:#?}", args);

    let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
        let model_id = args.model_id.clone();
        let mut path = std::path::Path::new(&args.model_id).to_path_buf();
        let filename = if !path.exists() {
            // Assume it's a hub id
            let api = Api::new()?;
            let repo = if let Some(ref revision) = args.revision {
                api.repo(Repo::with_revision(
                    model_id,
                    RepoType::Model,
                    revision.to_string(),
                ))
            } else {
                api.model(model_id)
            };
            repo.get("config.json")?
        } else {
            path.push("config.json");
            path
        };

        let content = std::fs::read_to_string(filename)?;
        let config: Config = serde_json::from_str(&content)?;

        // Quantization usually means you're even more RAM constrained.
        let max_default = 4096;

        let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
            (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
                if max_position_embeddings > max_default {
                    let max = max_position_embeddings;
                    if args.max_input_tokens.is_none()
                        && args.max_total_tokens.is_none()
                        && args.max_batch_prefill_tokens.is_none()
                    {
                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
                    }
                    max_default
                } else {
                    max_position_embeddings
                }
            }
            _ => {
                return Err(Box::new(LauncherError::ArgumentValidation(
                    "no max defined".to_string(),
                )));
            }
        };
        Ok(max_position_embeddings)
    };
    let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);

    let max_input_tokens = {
        match (args.max_input_tokens, args.max_input_length) {
            (Some(max_input_tokens), Some(max_input_length)) => {
                return Err(LauncherError::ArgumentValidation(
                    format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length` is deprecated for naming consistency.",
                )));
            }
            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
            (None, None) => {
                let value = max_position_embeddings - 1;
                tracing::info!("Default `max_input_tokens` to {value}");
                value
            }
        }
    };
    let max_total_tokens = {
        match args.max_total_tokens {
            Some(max_total_tokens) => max_total_tokens,
            None => {
                let value = max_position_embeddings;
                tracing::info!("Default `max_total_tokens` to {value}");
                value
            }
        }
    };
    let max_batch_prefill_tokens = {
        match args.max_batch_prefill_tokens {
            Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
            None => {
                let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
                    max_batch_size * max_input_tokens
                } else {
                    // Adding some edge in order to account for potential block_size alignment
                    // issue.
                    max_input_tokens + 50
                } as u32;
                tracing::info!("Default `max_batch_prefill_tokens` to {value}");
                value
            }
        }
    };

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(LauncherError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(LauncherError::ArgumentValidation(format!(
            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
            max_batch_prefill_tokens, max_input_tokens
        )));
    }

    let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
        #[allow(deprecated)]
        (
            None,
            Some(
                Quantization::Bitsandbytes
                | Quantization::BitsandbytesNF4
                | Quantization::BitsandbytesFP4,
            ),
        ) => {
            tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
            vec![]
        }
        _ => {
            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
            cuda_graphs
        }
    };

    if args.validation_workers == 0 {
        return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
    if args.trust_remote_code {
        tracing::warn!(
            "`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
            args.model_id
        );
    }

    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
    if num_shard > 1 {
        tracing::info!("Sharding model on {num_shard} processes");
    }

    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                max_batch_prefill_tokens, max_batch_total_tokens
            )));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                max_total_tokens, max_batch_total_tokens
            )));
        }
    }

    if args.ngrok {
        if args.ngrok_authtoken.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
            ));
        }

        if args.ngrok_edge.is_none() {
            return Err(LauncherError::ArgumentValidation(
                "`ngrok-edge` must be set when using ngrok tunneling".to_string(),
            ));
        }
    }

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
    .expect("Error setting Ctrl-C handler");

    // Download and convert model weights
    download_convert_model(&args, running.clone())?;

    if !running.load(Ordering::SeqCst) {
        // Launcher was asked to stop
        return Ok(());
    }

    // Shared shutdown bool
    let shutdown = Arc::new(AtomicBool::new(false));
    // Shared shutdown channel
    // When shutting down, the main thread will wait for all senders to be dropped
    let (shutdown_sender, shutdown_receiver) = mpsc::channel();

    // Shared channel to track shard status
    let (status_sender, status_receiver) = mpsc::channel();

    spawn_shards(
        num_shard,
        &args,
        cuda_graphs,
        max_total_tokens,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
        &status_receiver,
        status_sender,
        running.clone(),
    )?;

    // We might have received a termination signal
    if !running.load(Ordering::SeqCst) {
        shutdown_shards(shutdown, &shutdown_receiver);
        return Ok(());
    }

    // let mut webserver = spawn_webserver(
    //     num_shard,
    //     args,
    //     max_input_tokens,
    //     max_total_tokens,
    //     max_batch_prefill_tokens,
    //     shutdown.clone(),
    //     &shutdown_receiver,
    // )
    // .map_err(|err| {
    //     shutdown_shards(shutdown.clone(), &shutdown_receiver);
    //     err
    // })?;

    webserver_callback()?;

    println!("Webserver started");

    // Default exit code
    let mut exit_code = Ok(());

    while running.load(Ordering::SeqCst) {
        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
            tracing::error!("Shard {rank} crashed");
            exit_code = Err(LauncherError::ShardFailed);
            break;
        };

        // match webserver.try_wait().unwrap() {
        //     Some(_) => {
        //         tracing::error!("Webserver Crashed");
        //         shutdown_shards(shutdown, &shutdown_receiver);
        //         return Err(LauncherError::WebserverFailed);
        //     }
        //     None => {
        //         sleep(Duration::from_millis(100));
        //     }
        // };
    }

    // Graceful termination
    // terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
    shutdown_shards(shutdown, &shutdown_receiver);

    exit_code
}
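Worth noting in the removed `launcher_main_without_server`: instead of spawning and supervising its own webserver like `launcher_main`, it accepted a `webserver_callback` and invoked it once the shards were up, so an embedding host (here, the pyo3 wheel) could supply the server. A reduced sketch of that inversion of control, with `String` standing in for `LauncherError`:

type LauncherResult = Result<(), String>; // stand-in for Result<(), LauncherError>

fn launch_with(webserver_callback: Box<dyn FnOnce() -> LauncherResult>) -> LauncherResult {
    // ... download weights, spawn shards, wait for readiness ...
    webserver_callback()?; // the host decides what "start the webserver" means
    // ... supervise shards until shutdown ...
    Ok(())
}

fn main() {
    let cb: Box<dyn FnOnce() -> LauncherResult> = Box::new(|| {
        println!("embedded webserver would start here");
        Ok(())
    });
    launch_with(cb).unwrap();
}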

@@ -1,53 +1,5 @@
use clap::Parser;
use text_generation_launcher::{launcher_main, Args, LauncherError};
use text_generation_launcher::{internal_main_args, LauncherError};

fn main() -> Result<(), LauncherError> {
    let args = Args::parse();
    launcher_main(
        args.model_id,
        args.revision,
        args.validation_workers,
        args.sharded,
        args.num_shard,
        args.quantize,
        args.speculate,
        args.dtype,
        args.trust_remote_code,
        args.max_concurrent_requests,
        args.max_best_of,
        args.max_stop_sequences,
        args.max_top_n_tokens,
        args.max_input_tokens,
        args.max_input_length,
        args.max_total_tokens,
        args.waiting_served_ratio,
        args.max_batch_prefill_tokens,
        args.max_batch_total_tokens,
        args.max_waiting_tokens,
        args.max_batch_size,
        args.cuda_graphs,
        args.hostname,
        args.port,
        args.shard_uds_path,
        args.master_addr,
        args.master_port,
        args.huggingface_hub_cache,
        args.weights_cache_override,
        args.disable_custom_kernels,
        args.cuda_memory_fraction,
        args.rope_scaling,
        args.rope_factor,
        args.json_output,
        args.otlp_endpoint,
        args.cors_allow_origin,
        args.watermark_gamma,
        args.watermark_delta,
        args.ngrok,
        args.ngrok_authtoken,
        args.ngrok_edge,
        args.tokenizer_config_path,
        args.disable_grammar_support,
        args.env,
        args.max_client_batch_size,
    )
    internal_main_args()
}

@@ -7,6 +7,7 @@ pub mod server;
mod validation;

use axum::http::HeaderValue;
use clap::Parser;
use config::Config;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};

@@ -175,6 +176,108 @@ pub enum RouterError {
    Axum(#[from] axum::BoxError),
}

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
pub struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}
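Every field in this struct carries clap's `env` attribute, so each `--flag` can also be supplied through an environment variable; by default clap's derive macro derives the variable name by upper-casing the field name (`MAX_INPUT_TOKENS` for `max_input_tokens`). A small sketch of that fallback, assuming a clap build with the `derive` and `env` features:

use clap::Parser;

#[derive(Parser, Debug)]
struct Demo {
    // Precedence: explicit --max-input-tokens flag, then the
    // MAX_INPUT_TOKENS environment variable, then the default.
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
}

fn main() {
    // For demonstration only; prefer setting the variable outside the process.
    std::env::set_var("MAX_INPUT_TOKENS", "2048");
    let demo = Demo::parse_from(["demo"]); // no flag on the command line
    assert_eq!(demo.max_input_tokens, 2048); // env var beats the default
}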

pub async fn internal_main_args() -> Result<(), RouterError> {
    let args: Vec<String> = std::env::args()
        // skips the first arg if it's python
        .skip_while(|a| a.contains("python"))
        .collect();
    let args = Args::parse_from(args);

    println!("{:?}", args);
    let out = internal_main(
        args.max_concurrent_requests,
        args.max_best_of,
        args.max_stop_sequences,
        args.max_top_n_tokens,
        args.max_input_tokens,
        args.max_total_tokens,
        args.waiting_served_ratio,
        args.max_batch_prefill_tokens,
        args.max_batch_total_tokens,
        args.max_waiting_tokens,
        args.max_batch_size,
        args.hostname,
        args.port,
        args.master_shard_uds_path,
        args.tokenizer_name,
        args.tokenizer_config_path,
        args.revision,
        args.validation_workers,
        args.json_output,
        args.otlp_endpoint,
        args.cors_allow_origin,
        args.ngrok,
        args.ngrok_authtoken,
        args.ngrok_edge,
        args.messages_api_enabled,
        args.disable_grammar_support,
        args.max_client_batch_size,
    )
    .await;
    println!("[internal_main_args] {:?}", out);
    out
}

#[allow(clippy::too_many_arguments)]
pub async fn internal_main(
    max_concurrent_requests: usize,

@@ -1,100 +1,7 @@
use clap::Parser;
use text_generation_router::{internal_main, RouterError};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
}
use text_generation_router::{internal_main_args, RouterError};

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    // Get args
    let args = Args::parse();

    internal_main(
        args.max_concurrent_requests,
        args.max_best_of,
        args.max_stop_sequences,
        args.max_top_n_tokens,
        args.max_input_tokens,
        args.max_total_tokens,
        args.waiting_served_ratio,
        args.max_batch_prefill_tokens,
        args.max_batch_total_tokens,
        args.max_waiting_tokens,
        args.max_batch_size,
        args.hostname,
        args.port,
        args.master_shard_uds_path,
        args.tokenizer_name,
        args.tokenizer_config_path,
        args.revision,
        args.validation_workers,
        args.json_output,
        args.otlp_endpoint,
        args.cors_allow_origin,
        args.ngrok,
        args.ngrok_authtoken,
        args.ngrok_edge,
        args.messages_api_enabled,
        args.disable_grammar_support,
        args.max_client_batch_size,
    )
    .await?;
    internal_main_args().await?;
    Ok(())
}

@@ -13,3 +13,5 @@ library-install:
	pip install -e .

install: build comment-gitignore library-install remove-comment-gitignore

quick-install: build library-install

@@ -31,3 +31,5 @@ python-packages = ["tgi", "text_generation_server"]

[project.scripts]
text-generation-server = "tgi:text_generation_server_cli_main"
text-generation-router = "tgi:text_generation_router_cli_main"
text-generation-launcher = "tgi:text_generation_launcher_cli_main"

tgi/src/lib.rs

@@ -1,6 +1,8 @@
use pyo3::{prelude::*, wrap_pyfunction};
use text_generation_launcher::{launcher_main, launcher_main_without_server};
use text_generation_router::internal_main;
use std::thread;
use text_generation_launcher::{internal_main, internal_main_args as internal_main_args_launcher};
use text_generation_router::internal_main_args as internal_main_args_router;
use tokio::runtime::Runtime;

#[allow(clippy::too_many_arguments)]
#[pyfunction]

@@ -100,7 +102,7 @@ fn rust_launcher(
    max_client_batch_size: usize,
) -> PyResult<&PyAny> {
    pyo3_asyncio::tokio::future_into_py(py, async move {
        launcher_main(
        internal_main(
            model_id,
            revision,
            validation_workers,

@@ -153,251 +155,6 @@ fn rust_launcher(
    })
}

#[allow(clippy::too_many_arguments)]
#[pyfunction]
#[pyo3(signature = (
    model_id,
    revision,
    validation_workers,
    sharded,
    num_shard,
    _quantize,
    speculate,
    _dtype,
    trust_remote_code,
    max_concurrent_requests,
    max_best_of,
    max_stop_sequences,
    max_top_n_tokens,
    max_input_tokens,
    max_input_length,
    max_total_tokens,
    waiting_served_ratio,
    max_batch_prefill_tokens,
    max_batch_total_tokens,
    max_waiting_tokens,
    max_batch_size,
    cuda_graphs,
    hostname,
    port,
    shard_uds_path,
    master_addr,
    master_port,
    huggingface_hub_cache,
    weights_cache_override,
    disable_custom_kernels,
    cuda_memory_fraction,
    _rope_scaling,
    rope_factor,
    json_output,
    otlp_endpoint,
    cors_allow_origin,
    watermark_gamma,
    watermark_delta,
    ngrok,
    ngrok_authtoken,
    ngrok_edge,
    tokenizer_config_path,
    disable_grammar_support,
    env,
    max_client_batch_size,
))]
fn fully_packaged(
    py: Python<'_>,
    model_id: String,
    revision: Option<String>,
    validation_workers: usize,
    sharded: Option<bool>,
    num_shard: Option<usize>,
    _quantize: Option<String>, // Option<Quantization>,
    speculate: Option<usize>,
    _dtype: Option<String>, // Option<Dtype>,
    trust_remote_code: bool,
    max_concurrent_requests: usize,
    max_best_of: usize,
    max_stop_sequences: usize,
    max_top_n_tokens: u32,
    max_input_tokens: Option<usize>,
    max_input_length: Option<usize>,
    max_total_tokens: Option<usize>,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: Option<u32>,
    max_batch_total_tokens: Option<u32>,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
    cuda_graphs: Option<Vec<usize>>,
    hostname: String,
    port: u16,
    shard_uds_path: String,
    master_addr: String,
    master_port: usize,
    huggingface_hub_cache: Option<String>,
    weights_cache_override: Option<String>,
    disable_custom_kernels: bool,
    cuda_memory_fraction: f32,
    _rope_scaling: Option<f32>, // Option<RopeScaling>,
    rope_factor: Option<f32>,
    json_output: bool,
    otlp_endpoint: Option<String>,
    cors_allow_origin: Vec<String>,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    ngrok: bool,
    ngrok_authtoken: Option<String>,
    ngrok_edge: Option<String>,
    tokenizer_config_path: Option<String>,
    disable_grammar_support: bool,
    env: bool,
    max_client_batch_size: usize,
) -> PyResult<&PyAny> {
    pyo3_asyncio::tokio::future_into_py(py, async move {
        use std::thread;
        use tokio::runtime::Runtime;

        let model_id_clone = model_id.clone();
        let max_concurrent_requests_clone = max_concurrent_requests;
        let max_best_of_clone = max_best_of;
        let max_stop_sequences_clone = max_stop_sequences;
        let max_top_n_tokens_clone = max_top_n_tokens;
        let max_input_tokens_clone = max_input_tokens.unwrap_or(1024);
        let max_total_tokens_clone = max_total_tokens.unwrap_or(2048);
        let waiting_served_ratio_clone = waiting_served_ratio;

        let max_batch_prefill_tokens_clone = max_batch_prefill_tokens.unwrap_or(4096);
        let max_batch_total_tokens_clone = max_batch_total_tokens;
        let max_waiting_tokens_clone = max_waiting_tokens;
        let max_batch_size_clone = max_batch_size;
        let hostname_clone = hostname.clone();
        let port_clone = port;

        // TODO: fix this
        let _shard_uds_path_clone = shard_uds_path.clone();

        let tokenizer_config_path = tokenizer_config_path.clone();
        let revision = revision.clone();
        let validation_workers = validation_workers;
        let json_output = json_output;

        let otlp_endpoint = otlp_endpoint.clone();
        let cors_allow_origin = cors_allow_origin.clone();
        let ngrok = ngrok;
        let ngrok_authtoken = ngrok_authtoken.clone();
        let ngrok_edge = ngrok_edge.clone();
        let messages_api_enabled = true;
        let disable_grammar_support = disable_grammar_support;
        let max_client_batch_size = max_client_batch_size;

        let ngrok_clone = ngrok;
        let ngrok_authtoken_clone = ngrok_authtoken.clone();
        let ngrok_edge_clone = ngrok_edge.clone();
        let messages_api_enabled_clone = messages_api_enabled;
        let disable_grammar_support_clone = disable_grammar_support;
        let max_client_batch_size_clone = max_client_batch_size;

        let tokenizer_config_path_clone = tokenizer_config_path.clone();
        let revision_clone = revision.clone();
        let validation_workers_clone = validation_workers;
        let json_output_clone = json_output;
        let otlp_endpoint_clone = otlp_endpoint.clone();

        let webserver_callback = move || {
            let handle = thread::spawn(move || {
                let rt = Runtime::new().unwrap();
                rt.block_on(async {
                    internal_main(
                        max_concurrent_requests_clone,
                        max_best_of_clone,
                        max_stop_sequences_clone,
                        max_top_n_tokens_clone,
                        max_input_tokens_clone,
                        max_total_tokens_clone,
                        waiting_served_ratio_clone,
                        max_batch_prefill_tokens_clone,
                        max_batch_total_tokens_clone,
                        max_waiting_tokens_clone,
                        max_batch_size_clone,
                        hostname_clone,
                        port_clone,
                        "/tmp/text-generation-server-0".to_string(),
                        model_id_clone,
                        tokenizer_config_path_clone,
                        revision_clone,
                        validation_workers_clone,
                        json_output_clone,
                        otlp_endpoint_clone,
                        None,
                        ngrok_clone,
                        ngrok_authtoken_clone,
                        ngrok_edge_clone,
                        messages_api_enabled_clone,
                        disable_grammar_support_clone,
                        max_client_batch_size_clone,
                    )
                    .await
                })
            });
            match handle.join() {
                Ok(_) => println!("Server exited successfully"),
                Err(e) => println!("Server exited with error: {:?}", e),
            }
            Ok(())
        };

        // parse the arguments and run the main function
        launcher_main_without_server(
            model_id,
            revision,
            validation_workers,
            sharded,
            num_shard,
            None,
            speculate,
            None,
            trust_remote_code,
            max_concurrent_requests,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_tokens,
            max_input_length,
            max_total_tokens,
            waiting_served_ratio,
            max_batch_prefill_tokens,
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
            cuda_graphs,
            hostname,
            port,
            shard_uds_path,
            master_addr,
            master_port,
            huggingface_hub_cache,
            weights_cache_override,
            disable_custom_kernels,
            cuda_memory_fraction,
            None,
            rope_factor,
            json_output,
            otlp_endpoint,
            cors_allow_origin,
            watermark_gamma,
            watermark_delta,
            ngrok,
            ngrok_authtoken,
            ngrok_edge,
            tokenizer_config_path,
            disable_grammar_support,
            env,
            max_client_batch_size,
            Box::new(webserver_callback),
        )
        .unwrap();

        Ok(Python::with_gil(|py| py.None()))
    })
}
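A side note on the wall of `*_clone` bindings in the removed `fully_packaged`: `webserver_callback` is a `move` closure shipped to another thread, so it must own copies of everything it touches, while the originals are still consumed afterwards by the `launcher_main_without_server` call. The shape, reduced to a single variable:

use std::thread;

fn consume(model_id: String) {
    println!("launcher got {model_id}");
}

fn main() {
    let model_id = "some/model".to_string();
    let model_id_clone = model_id.clone();

    // The move closure takes ownership of the clone...
    let callback = move || {
        let handle = thread::spawn(move || {
            println!("webserver got {model_id_clone}");
        });
        handle.join().unwrap();
    };

    // ...so the original is still available for the launcher itself.
    consume(model_id);
    callback();
}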

/// Asynchronous sleep function.
#[pyfunction]
fn rust_sleep(py: Python<'_>) -> PyResult<&PyAny> {

@@ -407,49 +164,38 @@ fn rust_sleep(py: Python<'_>) -> PyResult<&PyAny> {
    })
}

// TODO: remove hardcoding
#[pyfunction]
fn rust_server(py: Python<'_>) -> PyResult<&PyAny> {
    pyo3_asyncio::tokio::future_into_py(py, async {
        let _ = internal_main(
            128, // max_concurrent_requests: usize,
            2, // max_best_of: usize,
            4, // max_stop_sequences: usize,
            5, // max_top_n_tokens: u32,
            1024, // max_input_tokens: usize,
            2048, // max_total_tokens: usize,
            1.2, // waiting_served_ratio: f32,
            4096, // max_batch_prefill_tokens: u32,
            None, // max_batch_total_tokens: Option<u32>,
            20, // max_waiting_tokens: usize,
            None, // max_batch_size: Option<usize>,
            "0.0.0.0".to_string(), // hostname: String,
            3000, // port: u16,
            "/tmp/text-generation-server-0".to_string(), // master_shard_uds_path: String,
            "llava-hf/llava-v1.6-mistral-7b-hf".to_string(), // tokenizer_name: String,
            None, // tokenizer_config_path: Option<String>,
            None, // revision: Option<String>,
            2, // validation_workers: usize,
            false, // json_output: bool,
            None, // otlp_endpoint: Option<String>,
            None, // cors_allow_origin: Option<Vec<String>>,
            false, // ngrok: bool,
            None, // ngrok_authtoken: Option<String>,
            None, // ngrok_edge: Option<String>,
            false, // messages_api_enabled: bool,
            false, // disable_grammar_support: bool,
            4, // max_client_batch_size: usize,
        )
        .await;
        Ok(Python::with_gil(|py| py.None()))
    })
fn rust_router(_py: Python<'_>) -> PyResult<String> {
    let handle = thread::spawn(move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(async { internal_main_args_router().await })
    });
    match handle.join() {
        Ok(thread_output) => match thread_output {
            Ok(_) => println!("Inner server exited successfully"),
            Err(e) => println!("Inner server exited with error: {:?}", e),
        },
        Err(e) => {
            println!("Server exited with error: {:?}", e);
        }
    }
    Ok("Completed".to_string())
}

#[pyfunction]
fn rust_launcher_cli(_py: Python<'_>) -> PyResult<String> {
    match internal_main_args_launcher() {
        Ok(_) => println!("Server exited successfully"),
        Err(e) => println!("Server exited with error: {:?}", e),
    }
    Ok("Completed".to_string())
}

#[pymodule]
fn tgi(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(rust_sleep, m)?)?;
    m.add_function(wrap_pyfunction!(rust_server, m)?)?;
    m.add_function(wrap_pyfunction!(rust_router, m)?)?;
    m.add_function(wrap_pyfunction!(rust_launcher, m)?)?;
    m.add_function(wrap_pyfunction!(fully_packaged, m)?)?;
    m.add_function(wrap_pyfunction!(rust_launcher_cli, m)?)?;
    Ok(())
}
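Both new wrappers run the async entry points by spawning a dedicated OS thread that owns a fresh tokio `Runtime`, then joining it. Presumably this avoids calling `block_on` from a thread that may already sit inside a runtime (pyo3-asyncio drives one for the async bindings), which would panic. The pattern in isolation, with a stub task standing in for `internal_main_args_router()`:

use std::thread;
use tokio::runtime::Runtime;

async fn serve() -> Result<(), String> {
    // stand-in for internal_main_args_router().await
    Ok(())
}

fn run_blocking() -> Result<(), String> {
    // A dedicated thread owns its own runtime, so this is safe to call
    // from a context (e.g. the Python interpreter) that must not block
    // or nest an existing tokio runtime.
    let handle = thread::spawn(move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(serve())
    });
    handle.join().expect("server thread panicked")
}

fn main() {
    run_blocking().unwrap();
}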

@@ -1,9 +1,8 @@
from .tgi import *
import threading
from tgi import rust_launcher, rust_sleep, fully_packaged
from tgi import rust_router, rust_launcher, rust_launcher_cli
import asyncio
from dataclasses import dataclass, asdict
import sys
from text_generation_server.cli import app

# add the rust_launcher coroutine to the __all__ list

@@ -17,6 +16,14 @@ def text_generation_server_cli_main():
    app()


def text_generation_router_cli_main():
    rust_router()


def text_generation_launcher_cli_main():
    rust_launcher_cli()


@dataclass
class Args:
    model_id = "google/gemma-2b-it"

@@ -81,7 +88,7 @@ class TGI(object):
        print(args)
        args = Args(**args)
        try:
            await fully_packaged(
            await rust_launcher(
                args.model_id,
                args.revision,
                args.validation_workers,