From 30f4deba77770b0c594b675d7b202e61835d6cbc Mon Sep 17 00:00:00 2001
From: drbh
Date: Sat, 18 May 2024 01:22:25 +0000
Subject: [PATCH] feat: bundle launcher and refactor cli wrappers

---
 launcher/src/lib.rs  | 443 ++++++------------------------------------
 launcher/src/main.rs |  52 +----
 router/src/lib.rs    | 103 ++++++++++
 router/src/main.rs   |  97 +---------
 tgi/Makefile         |   2 +
 tgi/pyproject.toml   |   2 +
 tgi/src/lib.rs       | 316 +++---------------------------
 tgi/tgi/__init__.py  |  13 +-
 8 files changed, 209 insertions(+), 819 deletions(-)

diff --git a/launcher/src/lib.rs b/launcher/src/lib.rs
index d917a504..a4ffae81 100644
--- a/launcher/src/lib.rs
+++ b/launcher/src/lib.rs
@@ -1260,8 +1260,64 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
     Ok(exit_status)
 }
 
+pub fn internal_main_args() -> Result<(), LauncherError> {
+    let args: Vec<String> = std::env::args()
+        // skips the first arg if it's python
+        .skip_while(|a| a.contains("python"))
+        .collect();
+    let args = Args::parse_from(args);
+
+    internal_main(
+        args.model_id,
+        args.revision,
+        args.validation_workers,
+        args.sharded,
+        args.num_shard,
+        args.quantize,
+        args.speculate,
+        args.dtype,
+        args.trust_remote_code,
+        args.max_concurrent_requests,
+        args.max_best_of,
+        args.max_stop_sequences,
+        args.max_top_n_tokens,
+        args.max_input_tokens,
+        args.max_input_length,
+        args.max_total_tokens,
+        args.waiting_served_ratio,
+        args.max_batch_prefill_tokens,
+        args.max_batch_total_tokens,
+        args.max_waiting_tokens,
+        args.max_batch_size,
+        args.cuda_graphs,
+        args.hostname,
+        args.port,
+        args.shard_uds_path,
+        args.master_addr,
+        args.master_port,
+        args.huggingface_hub_cache,
+        args.weights_cache_override,
+        args.disable_custom_kernels,
+        args.cuda_memory_fraction,
+        args.rope_scaling,
+        args.rope_factor,
+        args.json_output,
+        args.otlp_endpoint,
+        args.cors_allow_origin,
+        args.watermark_gamma,
+        args.watermark_delta,
+        args.ngrok,
+        args.ngrok_authtoken,
+        args.ngrok_edge,
+        args.tokenizer_config_path,
+        args.disable_grammar_support,
+        args.env,
+        args.max_client_batch_size,
+    )
+}
+
 #[allow(clippy::too_many_arguments)]
-pub fn launcher_main(
+pub fn internal_main(
     model_id: String,
     revision: Option<String>,
     validation_workers: usize,
@@ -1639,388 +1695,3 @@ pub fn launcher_main(
     exit_code
 }
-
-#[allow(clippy::too_many_arguments)]
-pub fn launcher_main_without_server(
-    model_id: String,
-    revision: Option<String>,
-    validation_workers: usize,
-    sharded: Option<bool>,
-    num_shard: Option<usize>,
-    quantize: Option<Quantization>,
-    speculate: Option<usize>,
-    dtype: Option<Dtype>,
-    trust_remote_code: bool,
-    max_concurrent_requests: usize,
-    max_best_of: usize,
-    max_stop_sequences: usize,
-    max_top_n_tokens: u32,
-    max_input_tokens: Option<usize>,
-    max_input_length: Option<usize>,
-    max_total_tokens: Option<usize>,
-    waiting_served_ratio: f32,
-    max_batch_prefill_tokens: Option<u32>,
-    max_batch_total_tokens: Option<u32>,
-    max_waiting_tokens: usize,
-    max_batch_size: Option<usize>,
-    cuda_graphs: Option<Vec<usize>>,
-    hostname: String,
-    port: u16,
-    shard_uds_path: String,
-    master_addr: String,
-    master_port: usize,
-    huggingface_hub_cache: Option<String>,
-    weights_cache_override: Option<String>,
-    disable_custom_kernels: bool,
-    cuda_memory_fraction: f32,
-    rope_scaling: Option<RopeScaling>,
-    rope_factor: Option<f32>,
-    json_output: bool,
-    otlp_endpoint: Option<String>,
-    cors_allow_origin: Vec<String>,
-    watermark_gamma: Option<f32>,
-    watermark_delta: Option<f32>,
-    ngrok: bool,
-    ngrok_authtoken: Option<String>,
-    ngrok_edge: Option<String>,
-    tokenizer_config_path: Option<String>,
-    disable_grammar_support: bool,
-    env: bool,
-    max_client_batch_size: usize,
-    webserver_callback: Box<dyn FnOnce() ->
Result<(), LauncherError>>, -) -> Result<(), LauncherError> { - let args = Args { - model_id, - revision, - validation_workers, - sharded, - num_shard, - quantize, - speculate, - dtype, - trust_remote_code, - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_tokens, - max_input_length, - max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - cuda_graphs, - hostname, - port, - shard_uds_path, - master_addr, - master_port, - huggingface_hub_cache, - weights_cache_override, - disable_custom_kernels, - cuda_memory_fraction, - rope_scaling, - rope_factor, - json_output, - otlp_endpoint, - cors_allow_origin, - watermark_gamma, - watermark_delta, - ngrok, - ngrok_authtoken, - ngrok_edge, - tokenizer_config_path, - disable_grammar_support, - env, - max_client_batch_size, - }; - - // Filter events with LOG_LEVEL - let env_filter = - EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); - - if args.json_output { - tracing_subscriber::fmt() - .with_env_filter(env_filter) - .json() - .init(); - } else { - tracing_subscriber::fmt() - .with_env_filter(env_filter) - .compact() - .init(); - } - - if args.env { - let env_runtime = env_runtime::Env::new(); - tracing::info!("{}", env_runtime); - } - - tracing::info!("{:#?}", args); - - let get_max_position_embeddings = || -> Result> { - let model_id = args.model_id.clone(); - let mut path = std::path::Path::new(&args.model_id).to_path_buf(); - let filename = if !path.exists() { - // Assume it's a hub id - let api = Api::new()?; - let repo = if let Some(ref revision) = args.revision { - api.repo(Repo::with_revision( - model_id, - RepoType::Model, - revision.to_string(), - )) - } else { - api.model(model_id) - }; - repo.get("config.json")? - } else { - path.push("config.json"); - path - }; - - let content = std::fs::read_to_string(filename)?; - let config: Config = serde_json::from_str(&content)?; - - // Quantization usually means you're even more RAM constrained. - let max_default = 4096; - - let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) { - (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => { - if max_position_embeddings > max_default { - let max = max_position_embeddings; - if args.max_input_tokens.is_none() - && args.max_total_tokens.is_none() - && args.max_batch_prefill_tokens.is_none() - { - tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); - } - max_default - } else { - max_position_embeddings - } - } - _ => { - return Err(Box::new(LauncherError::ArgumentValidation( - "no max defined".to_string(), - ))); - } - }; - Ok(max_position_embeddings) - }; - let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096); - - let max_input_tokens = { - match (args.max_input_tokens, args.max_input_length) { - (Some(max_input_tokens), Some(max_input_length)) => { - return Err(LauncherError::ArgumentValidation( - format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. 
Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.", - ))); - } - (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens, - (None, None) => { - let value = max_position_embeddings - 1; - tracing::info!("Default `max_input_tokens` to {value}"); - value - } - } - }; - let max_total_tokens = { - match args.max_total_tokens { - Some(max_total_tokens) => max_total_tokens, - None => { - let value = max_position_embeddings; - tracing::info!("Default `max_total_tokens` to {value}"); - value - } - } - }; - let max_batch_prefill_tokens = { - match args.max_batch_prefill_tokens { - Some(max_batch_prefill_tokens) => max_batch_prefill_tokens, - None => { - let value: u32 = if let Some(max_batch_size) = args.max_batch_size { - max_batch_size * max_input_tokens - } else { - // Adding some edge in order to account for potential block_size alignement - // issue. - max_input_tokens + 50 - } as u32; - tracing::info!("Default `max_batch_prefill_tokens` to {value}"); - value - } - } - }; - - // Validate args - if max_input_tokens >= max_total_tokens { - return Err(LauncherError::ArgumentValidation( - "`max_input_tokens must be < `max_total_tokens`".to_string(), - )); - } - if max_input_tokens as u32 > max_batch_prefill_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}", - max_batch_prefill_tokens, max_input_tokens - ))); - } - - let cuda_graphs = match (&args.cuda_graphs, &args.quantize) { - (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(), - #[allow(deprecated)] - ( - None, - Some( - Quantization::Bitsandbytes - | Quantization::BitsandbytesNF4 - | Quantization::BitsandbytesFP4, - ), - ) => { - tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); - vec![] - } - _ => { - let cuda_graphs = vec![1, 2, 4, 8, 16, 32]; - tracing::info!("Using default cuda graphs {cuda_graphs:?}"); - cuda_graphs - } - }; - - if args.validation_workers == 0 { - return Err(LauncherError::ArgumentValidation( - "`validation_workers` must be > 0".to_string(), - )); - } - if args.trust_remote_code { - tracing::warn!( - "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.", - args.model_id - ); - } - - let num_shard = find_num_shards(args.sharded, args.num_shard)?; - if num_shard > 1 { - tracing::info!("Sharding model on {num_shard} processes"); - } - - if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { - if max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - max_batch_prefill_tokens, max_batch_total_tokens - ))); - } - if max_total_tokens as u32 > *max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {} and {}", - max_total_tokens, max_batch_total_tokens - ))); - } - } - - if args.ngrok { - if args.ngrok_authtoken.is_none() { - return Err(LauncherError::ArgumentValidation( - "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(), - )); - } - - if args.ngrok_edge.is_none() { - return Err(LauncherError::ArgumentValidation( - "`ngrok-edge` must be set when using ngrok tunneling".to_string(), - )); - } - } - - // Signal handler - let running = Arc::new(AtomicBool::new(true)); - let r = running.clone(); - ctrlc::set_handler(move || { - r.store(false, Ordering::SeqCst); - }) - .expect("Error setting Ctrl-C handler"); - - // Download and convert model weights - download_convert_model(&args, running.clone())?; - - if !running.load(Ordering::SeqCst) { - // Launcher was asked to stop - return Ok(()); - } - - // Shared shutdown bool - let shutdown = Arc::new(AtomicBool::new(false)); - // Shared shutdown channel - // When shutting down, the main thread will wait for all senders to be dropped - let (shutdown_sender, shutdown_receiver) = mpsc::channel(); - - // Shared channel to track shard status - let (status_sender, status_receiver) = mpsc::channel(); - - spawn_shards( - num_shard, - &args, - cuda_graphs, - max_total_tokens, - shutdown.clone(), - &shutdown_receiver, - shutdown_sender, - &status_receiver, - status_sender, - running.clone(), - )?; - - // We might have received a termination signal - if !running.load(Ordering::SeqCst) { - shutdown_shards(shutdown, &shutdown_receiver); - return Ok(()); - } - - // let mut webserver = spawn_webserver( - // num_shard, - // args, - // max_input_tokens, - // max_total_tokens, - // max_batch_prefill_tokens, - // shutdown.clone(), - // &shutdown_receiver, - // ) - // .map_err(|err| { - // shutdown_shards(shutdown.clone(), &shutdown_receiver); - // err - // })?; - - webserver_callback()?; - - println!("Webserver started"); - - // Default exit code - let mut exit_code = Ok(()); - - while running.load(Ordering::SeqCst) { - if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() { - tracing::error!("Shard {rank} crashed"); - exit_code = Err(LauncherError::ShardFailed); - break; - }; - - // match webserver.try_wait().unwrap() { - // Some(_) => { - // tracing::error!("Webserver Crashed"); - // shutdown_shards(shutdown, &shutdown_receiver); - // return Err(LauncherError::WebserverFailed); - // } - // None => { - // sleep(Duration::from_millis(100)); - // } - // }; - } - - // Graceful termination - // terminate("webserver", webserver, Duration::from_secs(90)).unwrap(); - shutdown_shards(shutdown, &shutdown_receiver); - - exit_code -} diff --git a/launcher/src/main.rs b/launcher/src/main.rs index b9113478..2e885544 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1,53 +1,5 @@ -use clap::Parser; -use text_generation_launcher::{launcher_main, Args, LauncherError}; +use text_generation_launcher::{internal_main_args, LauncherError}; fn main() -> Result<(), LauncherError> { - let args = Args::parse(); - launcher_main( - args.model_id, - args.revision, - args.validation_workers, - args.sharded, - args.num_shard, - args.quantize, - args.speculate, - args.dtype, - args.trust_remote_code, - args.max_concurrent_requests, - args.max_best_of, - args.max_stop_sequences, - args.max_top_n_tokens, - args.max_input_tokens, - args.max_input_length, - args.max_total_tokens, - args.waiting_served_ratio, - args.max_batch_prefill_tokens, - args.max_batch_total_tokens, - args.max_waiting_tokens, - args.max_batch_size, - args.cuda_graphs, - 
args.hostname,
-        args.port,
-        args.shard_uds_path,
-        args.master_addr,
-        args.master_port,
-        args.huggingface_hub_cache,
-        args.weights_cache_override,
-        args.disable_custom_kernels,
-        args.cuda_memory_fraction,
-        args.rope_scaling,
-        args.rope_factor,
-        args.json_output,
-        args.otlp_endpoint,
-        args.cors_allow_origin,
-        args.watermark_gamma,
-        args.watermark_delta,
-        args.ngrok,
-        args.ngrok_authtoken,
-        args.ngrok_edge,
-        args.tokenizer_config_path,
-        args.disable_grammar_support,
-        args.env,
-        args.max_client_batch_size,
-    )
+    internal_main_args()
 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 69c96d65..0cdd7873 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -7,6 +7,7 @@ pub mod server;
 mod validation;
 
 use axum::http::HeaderValue;
+use clap::Parser;
 use config::Config;
 use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
 use hf_hub::{Cache, Repo, RepoType};
@@ -175,6 +176,108 @@ pub enum RouterError {
     Axum(#[from] axum::BoxError),
 }
 
+/// App Configuration
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+pub struct Args {
+    #[clap(default_value = "128", long, env)]
+    max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
+    #[clap(default_value = "4", long, env)]
+    max_stop_sequences: usize,
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+    #[clap(default_value = "1024", long, env)]
+    max_input_tokens: usize,
+    #[clap(default_value = "2048", long, env)]
+    max_total_tokens: usize,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+    #[clap(default_value = "3000", long, short, env)]
+    port: u16,
+    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
+    master_shard_uds_path: String,
+    #[clap(default_value = "bigscience/bloom", long, env)]
+    tokenizer_name: String,
+    #[clap(long, env)]
+    tokenizer_config_path: Option<String>,
+    #[clap(long, env)]
+    revision: Option<String>,
+    #[clap(default_value = "2", long, env)]
+    validation_workers: usize,
+    #[clap(long, env)]
+    json_output: bool,
+    #[clap(long, env)]
+    otlp_endpoint: Option<String>,
+    #[clap(long, env)]
+    cors_allow_origin: Option<Vec<String>>,
+    #[clap(long, env)]
+    ngrok: bool,
+    #[clap(long, env)]
+    ngrok_authtoken: Option<String>,
+    #[clap(long, env)]
+    ngrok_edge: Option<String>,
+    #[clap(long, env, default_value_t = false)]
+    messages_api_enabled: bool,
+    #[clap(long, env, default_value_t = false)]
+    disable_grammar_support: bool,
+    #[clap(default_value = "4", long, env)]
+    max_client_batch_size: usize,
+}
+
+pub async fn internal_main_args() -> Result<(), RouterError> {
+    let args: Vec<String> = std::env::args()
+        // skips the first arg if it's python
+        .skip_while(|a| a.contains("python"))
+        .collect();
+    let args = Args::parse_from(args);
+
+    println!("{:?}", args);
+    let out = internal_main(
+        args.max_concurrent_requests,
+        args.max_best_of,
+        args.max_stop_sequences,
+        args.max_top_n_tokens,
+        args.max_input_tokens,
+        args.max_total_tokens,
+        args.waiting_served_ratio,
+        args.max_batch_prefill_tokens,
+        args.max_batch_total_tokens,
+        args.max_waiting_tokens,
+        args.max_batch_size,
+        args.hostname,
+        args.port,
+        args.master_shard_uds_path,
+        args.tokenizer_name,
+        args.tokenizer_config_path,
+        args.revision,
+
args.validation_workers, + args.json_output, + args.otlp_endpoint, + args.cors_allow_origin, + args.ngrok, + args.ngrok_authtoken, + args.ngrok_edge, + args.messages_api_enabled, + args.disable_grammar_support, + args.max_client_batch_size, + ) + .await; + println!("[internal_main_args] {:?}", out); + out +} + #[allow(clippy::too_many_arguments)] pub async fn internal_main( max_concurrent_requests: usize, diff --git a/router/src/main.rs b/router/src/main.rs index ca11801c..b59ec8e7 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,100 +1,7 @@ -use clap::Parser; -use text_generation_router::{internal_main, RouterError}; - -/// App Configuration -#[derive(Parser, Debug)] -#[clap(author, version, about, long_about = None)] -struct Args { - #[clap(default_value = "128", long, env)] - max_concurrent_requests: usize, - #[clap(default_value = "2", long, env)] - max_best_of: usize, - #[clap(default_value = "4", long, env)] - max_stop_sequences: usize, - #[clap(default_value = "5", long, env)] - max_top_n_tokens: u32, - #[clap(default_value = "1024", long, env)] - max_input_tokens: usize, - #[clap(default_value = "2048", long, env)] - max_total_tokens: usize, - #[clap(default_value = "1.2", long, env)] - waiting_served_ratio: f32, - #[clap(default_value = "4096", long, env)] - max_batch_prefill_tokens: u32, - #[clap(long, env)] - max_batch_total_tokens: Option, - #[clap(default_value = "20", long, env)] - max_waiting_tokens: usize, - #[clap(long, env)] - max_batch_size: Option, - #[clap(default_value = "0.0.0.0", long, env)] - hostname: String, - #[clap(default_value = "3000", long, short, env)] - port: u16, - #[clap(default_value = "/tmp/text-generation-server-0", long, env)] - master_shard_uds_path: String, - #[clap(default_value = "bigscience/bloom", long, env)] - tokenizer_name: String, - #[clap(long, env)] - tokenizer_config_path: Option, - #[clap(long, env)] - revision: Option, - #[clap(default_value = "2", long, env)] - validation_workers: usize, - #[clap(long, env)] - json_output: bool, - #[clap(long, env)] - otlp_endpoint: Option, - #[clap(long, env)] - cors_allow_origin: Option>, - #[clap(long, env)] - ngrok: bool, - #[clap(long, env)] - ngrok_authtoken: Option, - #[clap(long, env)] - ngrok_edge: Option, - #[clap(long, env, default_value_t = false)] - messages_api_enabled: bool, - #[clap(long, env, default_value_t = false)] - disable_grammar_support: bool, - #[clap(default_value = "4", long, env)] - max_client_batch_size: usize, -} +use text_generation_router::{internal_main_args, RouterError}; #[tokio::main] async fn main() -> Result<(), RouterError> { - // Get args - let args = Args::parse(); - - internal_main( - args.max_concurrent_requests, - args.max_best_of, - args.max_stop_sequences, - args.max_top_n_tokens, - args.max_input_tokens, - args.max_total_tokens, - args.waiting_served_ratio, - args.max_batch_prefill_tokens, - args.max_batch_total_tokens, - args.max_waiting_tokens, - args.max_batch_size, - args.hostname, - args.port, - args.master_shard_uds_path, - args.tokenizer_name, - args.tokenizer_config_path, - args.revision, - args.validation_workers, - args.json_output, - args.otlp_endpoint, - args.cors_allow_origin, - args.ngrok, - args.ngrok_authtoken, - args.ngrok_edge, - args.messages_api_enabled, - args.disable_grammar_support, - args.max_client_batch_size, - ) - .await?; + internal_main_args().await?; Ok(()) } diff --git a/tgi/Makefile b/tgi/Makefile index b1c844df..2f39cf5b 100644 --- a/tgi/Makefile +++ b/tgi/Makefile @@ -13,3 +13,5 @@ library-install: pip 
install -e . install: build comment-gitignore library-install remove-comment-gitignore + +quick-install: build library-install diff --git a/tgi/pyproject.toml b/tgi/pyproject.toml index 1c824608..8c05e4a4 100644 --- a/tgi/pyproject.toml +++ b/tgi/pyproject.toml @@ -31,3 +31,5 @@ python-packages = ["tgi", "text_generation_server"] [project.scripts] text-generation-server = "tgi:text_generation_server_cli_main" +text-generation-router = "tgi:text_generation_router_cli_main" +text-generation-launcher = "tgi:text_generation_launcher_cli_main" diff --git a/tgi/src/lib.rs b/tgi/src/lib.rs index 793dc64e..358aa6ab 100644 --- a/tgi/src/lib.rs +++ b/tgi/src/lib.rs @@ -1,6 +1,8 @@ use pyo3::{prelude::*, wrap_pyfunction}; -use text_generation_launcher::{launcher_main, launcher_main_without_server}; -use text_generation_router::internal_main; +use std::thread; +use text_generation_launcher::{internal_main, internal_main_args as internal_main_args_launcher}; +use text_generation_router::internal_main_args as internal_main_args_router; +use tokio::runtime::Runtime; #[allow(clippy::too_many_arguments)] #[pyfunction] @@ -100,7 +102,7 @@ fn rust_launcher( max_client_batch_size: usize, ) -> PyResult<&PyAny> { pyo3_asyncio::tokio::future_into_py(py, async move { - launcher_main( + internal_main( model_id, revision, validation_workers, @@ -153,251 +155,6 @@ fn rust_launcher( }) } -#[allow(clippy::too_many_arguments)] -#[pyfunction] -#[pyo3(signature = ( - model_id, - revision, - validation_workers, - sharded, - num_shard, - _quantize, - speculate, - _dtype, - trust_remote_code, - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_tokens, - max_input_length, - max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - cuda_graphs, - hostname, - port, - shard_uds_path, - master_addr, - master_port, - huggingface_hub_cache, - weights_cache_override, - disable_custom_kernels, - cuda_memory_fraction, - _rope_scaling, - rope_factor, - json_output, - otlp_endpoint, - cors_allow_origin, - watermark_gamma, - watermark_delta, - ngrok, - ngrok_authtoken, - ngrok_edge, - tokenizer_config_path, - disable_grammar_support, - env, - max_client_batch_size, -))] -fn fully_packaged( - py: Python<'_>, - model_id: String, - revision: Option, - validation_workers: usize, - sharded: Option, - num_shard: Option, - _quantize: Option, // Option, - speculate: Option, - _dtype: Option, // Option, - trust_remote_code: bool, - max_concurrent_requests: usize, - max_best_of: usize, - max_stop_sequences: usize, - max_top_n_tokens: u32, - max_input_tokens: Option, - max_input_length: Option, - max_total_tokens: Option, - waiting_served_ratio: f32, - max_batch_prefill_tokens: Option, - max_batch_total_tokens: Option, - max_waiting_tokens: usize, - max_batch_size: Option, - cuda_graphs: Option>, - hostname: String, - port: u16, - shard_uds_path: String, - master_addr: String, - master_port: usize, - huggingface_hub_cache: Option, - weights_cache_override: Option, - disable_custom_kernels: bool, - cuda_memory_fraction: f32, - _rope_scaling: Option, // Option, - rope_factor: Option, - json_output: bool, - otlp_endpoint: Option, - cors_allow_origin: Vec, - watermark_gamma: Option, - watermark_delta: Option, - ngrok: bool, - ngrok_authtoken: Option, - ngrok_edge: Option, - tokenizer_config_path: Option, - disable_grammar_support: bool, - env: bool, - max_client_batch_size: usize, -) -> PyResult<&PyAny> { - 
pyo3_asyncio::tokio::future_into_py(py, async move { - use std::thread; - use tokio::runtime::Runtime; - - let model_id_clone = model_id.clone(); - let max_concurrent_requests_clone = max_concurrent_requests; - let max_best_of_clone = max_best_of; - let max_stop_sequences_clone = max_stop_sequences; - let max_top_n_tokens_clone = max_top_n_tokens; - let max_input_tokens_clone = max_input_tokens.unwrap_or(1024); - let max_total_tokens_clone = max_total_tokens.unwrap_or(2048); - let waiting_served_ratio_clone = waiting_served_ratio; - - let max_batch_prefill_tokens_clone = max_batch_prefill_tokens.unwrap_or(4096); - let max_batch_total_tokens_clone = max_batch_total_tokens; - let max_waiting_tokens_clone = max_waiting_tokens; - let max_batch_size_clone = max_batch_size; - let hostname_clone = hostname.clone(); - let port_clone = port; - - // TODO: fix this - let _shard_uds_path_clone = shard_uds_path.clone(); - - let tokenizer_config_path = tokenizer_config_path.clone(); - let revision = revision.clone(); - let validation_workers = validation_workers; - let json_output = json_output; - - let otlp_endpoint = otlp_endpoint.clone(); - let cors_allow_origin = cors_allow_origin.clone(); - let ngrok = ngrok; - let ngrok_authtoken = ngrok_authtoken.clone(); - let ngrok_edge = ngrok_edge.clone(); - let messages_api_enabled = true; - let disable_grammar_support = disable_grammar_support; - let max_client_batch_size = max_client_batch_size; - - let ngrok_clone = ngrok; - let ngrok_authtoken_clone = ngrok_authtoken.clone(); - let ngrok_edge_clone = ngrok_edge.clone(); - let messages_api_enabled_clone = messages_api_enabled; - let disable_grammar_support_clone = disable_grammar_support; - let max_client_batch_size_clone = max_client_batch_size; - - let tokenizer_config_path_clone = tokenizer_config_path.clone(); - let revision_clone = revision.clone(); - let validation_workers_clone = validation_workers; - let json_output_clone = json_output; - let otlp_endpoint_clone = otlp_endpoint.clone(); - - let webserver_callback = move || { - let handle = thread::spawn(move || { - let rt = Runtime::new().unwrap(); - rt.block_on(async { - internal_main( - max_concurrent_requests_clone, - max_best_of_clone, - max_stop_sequences_clone, - max_top_n_tokens_clone, - max_input_tokens_clone, - max_total_tokens_clone, - waiting_served_ratio_clone, - max_batch_prefill_tokens_clone, - max_batch_total_tokens_clone, - max_waiting_tokens_clone, - max_batch_size_clone, - hostname_clone, - port_clone, - "/tmp/text-generation-server-0".to_string(), - model_id_clone, - tokenizer_config_path_clone, - revision_clone, - validation_workers_clone, - json_output_clone, - otlp_endpoint_clone, - None, - ngrok_clone, - ngrok_authtoken_clone, - ngrok_edge_clone, - messages_api_enabled_clone, - disable_grammar_support_clone, - max_client_batch_size_clone, - ) - .await - }) - }); - match handle.join() { - Ok(_) => println!("Server exited successfully"), - Err(e) => println!("Server exited with error: {:?}", e), - } - Ok(()) - }; - - // parse the arguments and run the main function - launcher_main_without_server( - model_id, - revision, - validation_workers, - sharded, - num_shard, - None, - speculate, - None, - trust_remote_code, - max_concurrent_requests, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_tokens, - max_input_length, - max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - cuda_graphs, - hostname, - port, - shard_uds_path, - 
master_addr,
-            master_port,
-            huggingface_hub_cache,
-            weights_cache_override,
-            disable_custom_kernels,
-            cuda_memory_fraction,
-            None,
-            rope_factor,
-            json_output,
-            otlp_endpoint,
-            cors_allow_origin,
-            watermark_gamma,
-            watermark_delta,
-            ngrok,
-            ngrok_authtoken,
-            ngrok_edge,
-            tokenizer_config_path,
-            disable_grammar_support,
-            env,
-            max_client_batch_size,
-            Box::new(webserver_callback),
-        )
-        .unwrap();
-
-        Ok(Python::with_gil(|py| py.None()))
-    })
-}
-
 /// Asynchronous sleep function.
 #[pyfunction]
 fn rust_sleep(py: Python<'_>) -> PyResult<&PyAny> {
@@ -407,49 +164,38 @@ fn rust_sleep(py: Python<'_>) -> PyResult<&PyAny> {
     })
 }
 
-// TODO: remove hardcoding
 #[pyfunction]
-fn rust_server(py: Python<'_>) -> PyResult<&PyAny> {
-    pyo3_asyncio::tokio::future_into_py(py, async {
-        let _ = internal_main(
-            128, // max_concurrent_requests: usize,
-            2, // max_best_of: usize,
-            4, // max_stop_sequences: usize,
-            5, // max_top_n_tokens: u32,
-            1024, // max_input_tokens: usize,
-            2048, // max_total_tokens: usize,
-            1.2, // waiting_served_ratio: f32,
-            4096, // max_batch_prefill_tokens: u32,
-            None, // max_batch_total_tokens: Option<u32>,
-            20, // max_waiting_tokens: usize,
-            None, // max_batch_size: Option<usize>,
-            "0.0.0.0".to_string(), // hostname: String,
-            3000, // port: u16,
-            "/tmp/text-generation-server-0".to_string(), // master_shard_uds_path: String,
-            "llava-hf/llava-v1.6-mistral-7b-hf".to_string(), // tokenizer_name: String,
-            None, // tokenizer_config_path: Option<String>,
-            None, // revision: Option<String>,
-            2, // validation_workers: usize,
-            false, // json_output: bool,
-            None, // otlp_endpoint: Option<String>,
-            None, // cors_allow_origin: Option<Vec<String>>,
-            false, // ngrok: bool,
-            None, // ngrok_authtoken: Option<String>,
-            None, // ngrok_edge: Option<String>,
-            false, // messages_api_enabled: bool,
-            false, // disable_grammar_support: bool,
-            4, // max_client_batch_size: usize,
-        )
-        .await;
-        Ok(Python::with_gil(|py| py.None()))
-    })
+fn rust_router(_py: Python<'_>) -> PyResult<String> {
+    let handle = thread::spawn(move || {
+        let rt = Runtime::new().unwrap();
+        rt.block_on(async { internal_main_args_router().await })
+    });
+    match handle.join() {
+        Ok(thread_output) => match thread_output {
+            Ok(_) => println!("Inner server exited successfully"),
+            Err(e) => println!("Inner server exited with error: {:?}", e),
+        },
+        Err(e) => {
+            println!("Server exited with error: {:?}", e);
+        }
+    }
+    Ok("Completed".to_string())
+}
+
+#[pyfunction]
+fn rust_launcher_cli(_py: Python<'_>) -> PyResult<String> {
+    match internal_main_args_launcher() {
+        Ok(_) => println!("Server exited successfully"),
+        Err(e) => println!("Server exited with error: {:?}", e),
+    }
+    Ok("Completed".to_string())
 }
 
 #[pymodule]
 fn tgi(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(rust_sleep, m)?)?;
-    m.add_function(wrap_pyfunction!(rust_server, m)?)?;
+    m.add_function(wrap_pyfunction!(rust_router, m)?)?;
     m.add_function(wrap_pyfunction!(rust_launcher, m)?)?;
-    m.add_function(wrap_pyfunction!(fully_packaged, m)?)?;
+    m.add_function(wrap_pyfunction!(rust_launcher_cli, m)?)?;
     Ok(())
 }
diff --git a/tgi/tgi/__init__.py b/tgi/tgi/__init__.py
index 2e9684a5..b7555c96 100644
--- a/tgi/tgi/__init__.py
+++ b/tgi/tgi/__init__.py
@@ -1,9 +1,8 @@
 from .tgi import *
 import threading
-from tgi import rust_launcher, rust_sleep, fully_packaged
+from tgi import rust_router, rust_launcher, rust_launcher_cli
 import asyncio
 from dataclasses import dataclass, asdict
-import sys
 from text_generation_server.cli import app
 
 # add the rust_launcher coroutine to the __all__ list
@@
-17,6 +16,14 @@ def text_generation_server_cli_main(): app() +def text_generation_router_cli_main(): + rust_router() + + +def text_generation_launcher_cli_main(): + rust_launcher_cli() + + @dataclass class Args: model_id = "google/gemma-2b-it" @@ -81,7 +88,7 @@ class TGI(object): print(args) args = Args(**args) try: - await fully_packaged( + await rust_launcher( args.model_id, args.revision, args.validation_workers,