feat: bundle launcher and refactor cli wrappers

drbh 2024-05-18 01:22:25 +00:00
parent af2b2e8388
commit 30f4deba77
8 changed files with 209 additions and 819 deletions

View File

@@ -1260,8 +1260,64 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
Ok(exit_status)
}
pub fn internal_main_args() -> Result<(), LauncherError> {
let args: Vec<String> = std::env::args()
// skip leading args that contain "python" (e.g. the interpreter path when invoked from a Python entrypoint)
.skip_while(|a| a.contains("python"))
.collect();
let args = Args::parse_from(args);
internal_main(
args.model_id,
args.revision,
args.validation_workers,
args.sharded,
args.num_shard,
args.quantize,
args.speculate,
args.dtype,
args.trust_remote_code,
args.max_concurrent_requests,
args.max_best_of,
args.max_stop_sequences,
args.max_top_n_tokens,
args.max_input_tokens,
args.max_input_length,
args.max_total_tokens,
args.waiting_served_ratio,
args.max_batch_prefill_tokens,
args.max_batch_total_tokens,
args.max_waiting_tokens,
args.max_batch_size,
args.cuda_graphs,
args.hostname,
args.port,
args.shard_uds_path,
args.master_addr,
args.master_port,
args.huggingface_hub_cache,
args.weights_cache_override,
args.disable_custom_kernels,
args.cuda_memory_fraction,
args.rope_scaling,
args.rope_factor,
args.json_output,
args.otlp_endpoint,
args.cors_allow_origin,
args.watermark_gamma,
args.watermark_delta,
args.ngrok,
args.ngrok_authtoken,
args.ngrok_edge,
args.tokenizer_config_path,
args.disable_grammar_support,
args.env,
args.max_client_batch_size,
)
}
#[allow(clippy::too_many_arguments)]
pub fn launcher_main(
pub fn internal_main(
model_id: String,
revision: Option<String>,
validation_workers: usize,
@@ -1639,388 +1695,3 @@ pub fn launcher_main
exit_code
}
#[allow(clippy::too_many_arguments)]
pub fn launcher_main_without_server(
model_id: String,
revision: Option<String>,
validation_workers: usize,
sharded: Option<bool>,
num_shard: Option<usize>,
quantize: Option<Quantization>,
speculate: Option<usize>,
dtype: Option<Dtype>,
trust_remote_code: bool,
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_tokens: Option<usize>,
max_input_length: Option<usize>,
max_total_tokens: Option<usize>,
waiting_served_ratio: f32,
max_batch_prefill_tokens: Option<u32>,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
cuda_graphs: Option<Vec<usize>>,
hostname: String,
port: u16,
shard_uds_path: String,
master_addr: String,
master_port: usize,
huggingface_hub_cache: Option<String>,
weights_cache_override: Option<String>,
disable_custom_kernels: bool,
cuda_memory_fraction: f32,
rope_scaling: Option<RopeScaling>,
rope_factor: Option<f32>,
json_output: bool,
otlp_endpoint: Option<String>,
cors_allow_origin: Vec<String>,
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
ngrok: bool,
ngrok_authtoken: Option<String>,
ngrok_edge: Option<String>,
tokenizer_config_path: Option<String>,
disable_grammar_support: bool,
env: bool,
max_client_batch_size: usize,
webserver_callback: Box<dyn FnOnce() -> Result<(), LauncherError>>,
) -> Result<(), LauncherError> {
let args = Args {
model_id,
revision,
validation_workers,
sharded,
num_shard,
quantize,
speculate,
dtype,
trust_remote_code,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_input_length,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
cuda_graphs,
hostname,
port,
shard_uds_path,
master_addr,
master_port,
huggingface_hub_cache,
weights_cache_override,
disable_custom_kernels,
cuda_memory_fraction,
rope_scaling,
rope_factor,
json_output,
otlp_endpoint,
cors_allow_origin,
watermark_gamma,
watermark_delta,
ngrok,
ngrok_authtoken,
ngrok_edge,
tokenizer_config_path,
disable_grammar_support,
env,
max_client_batch_size,
};
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
if args.json_output {
tracing_subscriber::fmt()
.with_env_filter(env_filter)
.json()
.init();
} else {
tracing_subscriber::fmt()
.with_env_filter(env_filter)
.compact()
.init();
}
if args.env {
let env_runtime = env_runtime::Env::new();
tracing::info!("{}", env_runtime);
}
tracing::info!("{:#?}", args);
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
let model_id = args.model_id.clone();
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
let filename = if !path.exists() {
// Assume it's a hub id
let api = Api::new()?;
let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: Config = serde_json::from_str(&content)?;
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
(Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
if args.max_input_tokens.is_none()
&& args.max_total_tokens.is_none()
&& args.max_batch_prefill_tokens.is_none()
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
max_default
} else {
max_position_embeddings
}
}
_ => {
return Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)));
}
};
Ok(max_position_embeddings)
};
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) {
(Some(max_input_tokens), Some(max_input_length)) => {
return Err(LauncherError::ArgumentValidation(
format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length` is deprecated for naming consistency.",
)));
}
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
(None, None) => {
let value = max_position_embeddings - 1;
tracing::info!("Default `max_input_tokens` to {value}");
value
}
}
};
let max_total_tokens = {
match args.max_total_tokens {
Some(max_total_tokens) => max_total_tokens,
None => {
let value = max_position_embeddings;
tracing::info!("Default `max_total_tokens` to {value}");
value
}
}
};
let max_batch_prefill_tokens = {
match args.max_batch_prefill_tokens {
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
None => {
let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
max_batch_size * max_input_tokens
} else {
// Add some headroom to account for potential block_size alignment
// issues.
max_input_tokens + 50
} as u32;
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
value
}
}
};
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(LauncherError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_input_tokens
)));
}
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
(Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
#[allow(deprecated)]
(
None,
Some(
Quantization::Bitsandbytes
| Quantization::BitsandbytesNF4
| Quantization::BitsandbytesFP4,
),
) => {
tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
vec![]
}
_ => {
let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
tracing::info!("Using default cuda graphs {cuda_graphs:?}");
cuda_graphs
}
};
if args.validation_workers == 0 {
return Err(LauncherError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if args.trust_remote_code {
tracing::warn!(
"`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
args.model_id
);
}
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 {
tracing::info!("Sharding model on {num_shard} processes");
}
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
max_total_tokens, max_batch_total_tokens
)));
}
}
if args.ngrok {
if args.ngrok_authtoken.is_none() {
return Err(LauncherError::ArgumentValidation(
"`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
));
}
if args.ngrok_edge.is_none() {
return Err(LauncherError::ArgumentValidation(
"`ngrok-edge` must be set when using ngrok tunneling".to_string(),
));
}
}
// Signal handler
let running = Arc::new(AtomicBool::new(true));
let r = running.clone();
ctrlc::set_handler(move || {
r.store(false, Ordering::SeqCst);
})
.expect("Error setting Ctrl-C handler");
// Download and convert model weights
download_convert_model(&args, running.clone())?;
if !running.load(Ordering::SeqCst) {
// Launcher was asked to stop
return Ok(());
}
// Shared shutdown bool
let shutdown = Arc::new(AtomicBool::new(false));
// Shared shutdown channel
// When shutting down, the main thread will wait for all senders to be dropped
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
// Shared channel to track shard status
let (status_sender, status_receiver) = mpsc::channel();
spawn_shards(
num_shard,
&args,
cuda_graphs,
max_total_tokens,
shutdown.clone(),
&shutdown_receiver,
shutdown_sender,
&status_receiver,
status_sender,
running.clone(),
)?;
// We might have received a termination signal
if !running.load(Ordering::SeqCst) {
shutdown_shards(shutdown, &shutdown_receiver);
return Ok(());
}
// let mut webserver = spawn_webserver(
// num_shard,
// args,
// max_input_tokens,
// max_total_tokens,
// max_batch_prefill_tokens,
// shutdown.clone(),
// &shutdown_receiver,
// )
// .map_err(|err| {
// shutdown_shards(shutdown.clone(), &shutdown_receiver);
// err
// })?;
webserver_callback()?;
println!("Webserver started");
// Default exit code
let mut exit_code = Ok(());
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
tracing::error!("Shard {rank} crashed");
exit_code = Err(LauncherError::ShardFailed);
break;
};
// match webserver.try_wait().unwrap() {
// Some(_) => {
// tracing::error!("Webserver Crashed");
// shutdown_shards(shutdown, &shutdown_receiver);
// return Err(LauncherError::WebserverFailed);
// }
// None => {
// sleep(Duration::from_millis(100));
// }
// };
}
// Graceful termination
// terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
shutdown_shards(shutdown, &shutdown_receiver);
exit_code
}
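
The new `launcher_main_without_server` keeps the full launcher flow (weight download, shard spawning, status watching) but replaces the old `spawn_webserver` step, now commented out above, with a `webserver_callback: Box<dyn FnOnce() -> Result<(), LauncherError>>` supplied by the caller, so an embedder such as the pyo3 module later in this commit can run the router in-process. A minimal sketch of that control flow follows; only the callback's signature mirrors the real parameter, while the error type and the elided steps are stand-ins.

// Sketch only: `LauncherError` here is a stand-in, not the crate's real error type.
#[derive(Debug)]
enum LauncherError {
    WebserverFailed, // unused in this sketch
}

fn launch_without_server(
    webserver_callback: Box<dyn FnOnce() -> Result<(), LauncherError>>,
) -> Result<(), LauncherError> {
    // ... download/convert weights and spawn shards would happen here ...
    // Hand control of the webserver to the embedder instead of spawning a subprocess.
    webserver_callback()?;
    // ... watch shard status and shut down gracefully ...
    Ok(())
}

fn main() -> Result<(), LauncherError> {
    launch_without_server(Box::new(|| {
        println!("embedder-provided webserver starts here");
        Ok(())
    }))
}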

View File

@@ -1,53 +1,5 @@
use clap::Parser;
use text_generation_launcher::{launcher_main, Args, LauncherError};
use text_generation_launcher::{internal_main_args, LauncherError};
fn main() -> Result<(), LauncherError> {
let args = Args::parse();
launcher_main(
args.model_id,
args.revision,
args.validation_workers,
args.sharded,
args.num_shard,
args.quantize,
args.speculate,
args.dtype,
args.trust_remote_code,
args.max_concurrent_requests,
args.max_best_of,
args.max_stop_sequences,
args.max_top_n_tokens,
args.max_input_tokens,
args.max_input_length,
args.max_total_tokens,
args.waiting_served_ratio,
args.max_batch_prefill_tokens,
args.max_batch_total_tokens,
args.max_waiting_tokens,
args.max_batch_size,
args.cuda_graphs,
args.hostname,
args.port,
args.shard_uds_path,
args.master_addr,
args.master_port,
args.huggingface_hub_cache,
args.weights_cache_override,
args.disable_custom_kernels,
args.cuda_memory_fraction,
args.rope_scaling,
args.rope_factor,
args.json_output,
args.otlp_endpoint,
args.cors_allow_origin,
args.watermark_gamma,
args.watermark_delta,
args.ngrok,
args.ngrok_authtoken,
args.ngrok_edge,
args.tokenizer_config_path,
args.disable_grammar_support,
args.env,
args.max_client_batch_size,
)
internal_main_args()
}
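
The launcher binary is now a thin shim over `internal_main_args`, which re-reads `std::env::args()` and drops leading entries containing "python" before clap parsing, so the same code path serves both the native binary and the Python console script added by this commit. A self-contained sketch of that filtering, using a hypothetical `DemoArgs` rather than the real `Args` struct:

use clap::Parser;

#[derive(Parser, Debug)]
struct DemoArgs {
    // hypothetical flag, just to show `parse_from` in isolation
    #[clap(long, default_value = "demo-model")]
    model_id: String,
}

fn main() {
    let argv: Vec<String> = std::env::args()
        // when invoked through Python, argv can begin with the interpreter path
        .skip_while(|a| a.contains("python"))
        .collect();
    let args = DemoArgs::parse_from(argv);
    println!("{args:?}");
}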

View File

@@ -7,6 +7,7 @@ pub mod server;
mod validation;
use axum::http::HeaderValue;
use clap::Parser;
use config::Config;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
@@ -175,6 +176,108 @@ pub enum RouterError {
Axum(#[from] axum::BoxError),
}
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
pub struct Args {
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
}
pub async fn internal_main_args() -> Result<(), RouterError> {
let args: Vec<String> = std::env::args()
// skip leading args that contain "python" (e.g. the interpreter path when invoked from a Python entrypoint)
.skip_while(|a| a.contains("python"))
.collect();
let args = Args::parse_from(args);
println!("{:?}", args);
let out = internal_main(
args.max_concurrent_requests,
args.max_best_of,
args.max_stop_sequences,
args.max_top_n_tokens,
args.max_input_tokens,
args.max_total_tokens,
args.waiting_served_ratio,
args.max_batch_prefill_tokens,
args.max_batch_total_tokens,
args.max_waiting_tokens,
args.max_batch_size,
args.hostname,
args.port,
args.master_shard_uds_path,
args.tokenizer_name,
args.tokenizer_config_path,
args.revision,
args.validation_workers,
args.json_output,
args.otlp_endpoint,
args.cors_allow_origin,
args.ngrok,
args.ngrok_authtoken,
args.ngrok_edge,
args.messages_api_enabled,
args.disable_grammar_support,
args.max_client_batch_size,
)
.await;
println!("[internal_main_args] {:?}", out);
out
}
#[allow(clippy::too_many_arguments)]
pub async fn internal_main(
max_concurrent_requests: usize,
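
Every field in the relocated router `Args` struct is declared with `#[clap(long, env)]`, so a value can come from the command-line flag, from the matching upper-cased environment variable, or fall back to the literal `default_value`. A minimal sketch of that resolution, assuming clap's `derive` and `env` features are enabled (hypothetical `Demo` struct):

use clap::Parser;

#[derive(Parser, Debug)]
struct Demo {
    // settable as `--port 8080`, `-p 8080`, or via the `PORT` environment variable;
    // falls back to 3000 when neither is given
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
}

fn main() {
    let demo = Demo::parse();
    println!("router would listen on port {}", demo.port);
}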

View File

@@ -1,100 +1,7 @@
use clap::Parser;
use text_generation_router::{internal_main, RouterError};
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
}
use text_generation_router::{internal_main_args, RouterError};
#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
internal_main(
args.max_concurrent_requests,
args.max_best_of,
args.max_stop_sequences,
args.max_top_n_tokens,
args.max_input_tokens,
args.max_total_tokens,
args.waiting_served_ratio,
args.max_batch_prefill_tokens,
args.max_batch_total_tokens,
args.max_waiting_tokens,
args.max_batch_size,
args.hostname,
args.port,
args.master_shard_uds_path,
args.tokenizer_name,
args.tokenizer_config_path,
args.revision,
args.validation_workers,
args.json_output,
args.otlp_endpoint,
args.cors_allow_origin,
args.ngrok,
args.ngrok_authtoken,
args.ngrok_edge,
args.messages_api_enabled,
args.disable_grammar_support,
args.max_client_batch_size,
)
.await?;
internal_main_args().await?;
Ok(())
}

View File

@@ -13,3 +13,5 @@ library-install:
pip install -e .
install: build comment-gitignore library-install remove-comment-gitignore
quick-install: build library-install

View File

@@ -31,3 +31,5 @@ python-packages = ["tgi", "text_generation_server"]
[project.scripts]
text-generation-server = "tgi:text_generation_server_cli_main"
text-generation-router = "tgi:text_generation_router_cli_main"
text-generation-launcher = "tgi:text_generation_launcher_cli_main"

View File

@@ -1,6 +1,8 @@
use pyo3::{prelude::*, wrap_pyfunction};
use text_generation_launcher::{launcher_main, launcher_main_without_server};
use text_generation_router::internal_main;
use std::thread;
use text_generation_launcher::{internal_main, internal_main_args as internal_main_args_launcher};
use text_generation_router::internal_main_args as internal_main_args_router;
use tokio::runtime::Runtime;
#[allow(clippy::too_many_arguments)]
#[pyfunction]
@@ -100,7 +102,7 @@ fn rust_launcher(
max_client_batch_size: usize,
) -> PyResult<&PyAny> {
pyo3_asyncio::tokio::future_into_py(py, async move {
launcher_main(
internal_main(
model_id,
revision,
validation_workers,
@@ -153,251 +155,6 @@ fn rust_launcher(
})
}
#[allow(clippy::too_many_arguments)]
#[pyfunction]
#[pyo3(signature = (
model_id,
revision,
validation_workers,
sharded,
num_shard,
_quantize,
speculate,
_dtype,
trust_remote_code,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_input_length,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
cuda_graphs,
hostname,
port,
shard_uds_path,
master_addr,
master_port,
huggingface_hub_cache,
weights_cache_override,
disable_custom_kernels,
cuda_memory_fraction,
_rope_scaling,
rope_factor,
json_output,
otlp_endpoint,
cors_allow_origin,
watermark_gamma,
watermark_delta,
ngrok,
ngrok_authtoken,
ngrok_edge,
tokenizer_config_path,
disable_grammar_support,
env,
max_client_batch_size,
))]
fn fully_packaged(
py: Python<'_>,
model_id: String,
revision: Option<String>,
validation_workers: usize,
sharded: Option<bool>,
num_shard: Option<usize>,
_quantize: Option<String>, // Option<Quantization>,
speculate: Option<usize>,
_dtype: Option<String>, // Option<Dtype>,
trust_remote_code: bool,
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_tokens: Option<usize>,
max_input_length: Option<usize>,
max_total_tokens: Option<usize>,
waiting_served_ratio: f32,
max_batch_prefill_tokens: Option<u32>,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
cuda_graphs: Option<Vec<usize>>,
hostname: String,
port: u16,
shard_uds_path: String,
master_addr: String,
master_port: usize,
huggingface_hub_cache: Option<String>,
weights_cache_override: Option<String>,
disable_custom_kernels: bool,
cuda_memory_fraction: f32,
_rope_scaling: Option<f32>, // Option<RopeScaling>,
rope_factor: Option<f32>,
json_output: bool,
otlp_endpoint: Option<String>,
cors_allow_origin: Vec<String>,
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
ngrok: bool,
ngrok_authtoken: Option<String>,
ngrok_edge: Option<String>,
tokenizer_config_path: Option<String>,
disable_grammar_support: bool,
env: bool,
max_client_batch_size: usize,
) -> PyResult<&PyAny> {
pyo3_asyncio::tokio::future_into_py(py, async move {
use std::thread;
use tokio::runtime::Runtime;
let model_id_clone = model_id.clone();
let max_concurrent_requests_clone = max_concurrent_requests;
let max_best_of_clone = max_best_of;
let max_stop_sequences_clone = max_stop_sequences;
let max_top_n_tokens_clone = max_top_n_tokens;
let max_input_tokens_clone = max_input_tokens.unwrap_or(1024);
let max_total_tokens_clone = max_total_tokens.unwrap_or(2048);
let waiting_served_ratio_clone = waiting_served_ratio;
let max_batch_prefill_tokens_clone = max_batch_prefill_tokens.unwrap_or(4096);
let max_batch_total_tokens_clone = max_batch_total_tokens;
let max_waiting_tokens_clone = max_waiting_tokens;
let max_batch_size_clone = max_batch_size;
let hostname_clone = hostname.clone();
let port_clone = port;
// TODO: fix this
let _shard_uds_path_clone = shard_uds_path.clone();
let tokenizer_config_path = tokenizer_config_path.clone();
let revision = revision.clone();
let validation_workers = validation_workers;
let json_output = json_output;
let otlp_endpoint = otlp_endpoint.clone();
let cors_allow_origin = cors_allow_origin.clone();
let ngrok = ngrok;
let ngrok_authtoken = ngrok_authtoken.clone();
let ngrok_edge = ngrok_edge.clone();
let messages_api_enabled = true;
let disable_grammar_support = disable_grammar_support;
let max_client_batch_size = max_client_batch_size;
let ngrok_clone = ngrok;
let ngrok_authtoken_clone = ngrok_authtoken.clone();
let ngrok_edge_clone = ngrok_edge.clone();
let messages_api_enabled_clone = messages_api_enabled;
let disable_grammar_support_clone = disable_grammar_support;
let max_client_batch_size_clone = max_client_batch_size;
let tokenizer_config_path_clone = tokenizer_config_path.clone();
let revision_clone = revision.clone();
let validation_workers_clone = validation_workers;
let json_output_clone = json_output;
let otlp_endpoint_clone = otlp_endpoint.clone();
let webserver_callback = move || {
let handle = thread::spawn(move || {
let rt = Runtime::new().unwrap();
rt.block_on(async {
internal_main(
max_concurrent_requests_clone,
max_best_of_clone,
max_stop_sequences_clone,
max_top_n_tokens_clone,
max_input_tokens_clone,
max_total_tokens_clone,
waiting_served_ratio_clone,
max_batch_prefill_tokens_clone,
max_batch_total_tokens_clone,
max_waiting_tokens_clone,
max_batch_size_clone,
hostname_clone,
port_clone,
"/tmp/text-generation-server-0".to_string(),
model_id_clone,
tokenizer_config_path_clone,
revision_clone,
validation_workers_clone,
json_output_clone,
otlp_endpoint_clone,
None,
ngrok_clone,
ngrok_authtoken_clone,
ngrok_edge_clone,
messages_api_enabled_clone,
disable_grammar_support_clone,
max_client_batch_size_clone,
)
.await
})
});
match handle.join() {
Ok(_) => println!("Server exited successfully"),
Err(e) => println!("Server exited with error: {:?}", e),
}
Ok(())
};
// parse the arguments and run the main function
launcher_main_without_server(
model_id,
revision,
validation_workers,
sharded,
num_shard,
None,
speculate,
None,
trust_remote_code,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_input_length,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
cuda_graphs,
hostname,
port,
shard_uds_path,
master_addr,
master_port,
huggingface_hub_cache,
weights_cache_override,
disable_custom_kernels,
cuda_memory_fraction,
None,
rope_factor,
json_output,
otlp_endpoint,
cors_allow_origin,
watermark_gamma,
watermark_delta,
ngrok,
ngrok_authtoken,
ngrok_edge,
tokenizer_config_path,
disable_grammar_support,
env,
max_client_batch_size,
Box::new(webserver_callback),
)
.unwrap();
Ok(Python::with_gil(|py| py.None()))
})
}
/// Asynchronous sleep function.
#[pyfunction]
fn rust_sleep(py: Python<'_>) -> PyResult<&PyAny> {
@@ -407,49 +164,38 @@ fn rust_sleep(py: Python<'_>) -> PyResult<&PyAny> {
})
}
// TODO: remove hardcoding
#[pyfunction]
fn rust_server(py: Python<'_>) -> PyResult<&PyAny> {
pyo3_asyncio::tokio::future_into_py(py, async {
let _ = internal_main(
128, // max_concurrent_requests: usize,
2, // max_best_of: usize,
4, // max_stop_sequences: usize,
5, // max_top_n_tokens: u32,
1024, // max_input_tokens: usize,
2048, // max_total_tokens: usize,
1.2, // waiting_served_ratio: f32,
4096, // max_batch_prefill_tokens: u32,
None, // max_batch_total_tokens: Option<u32>,
20, // max_waiting_tokens: usize,
None, // max_batch_size: Option<usize>,
"0.0.0.0".to_string(), // hostname: String,
3000, // port: u16,
"/tmp/text-generation-server-0".to_string(), // master_shard_uds_path: String,
"llava-hf/llava-v1.6-mistral-7b-hf".to_string(), // tokenizer_name: String,
None, // tokenizer_config_path: Option<String>,
None, // revision: Option<String>,
2, // validation_workers: usize,
false, // json_output: bool,
None, // otlp_endpoint: Option<String>,
None, // cors_allow_origin: Option<Vec<String>>,
false, // ngrok: bool,
None, // ngrok_authtoken: Option<String>,
None, // ngrok_edge: Option<String>,
false, // messages_api_enabled: bool,
false, // disable_grammar_support: bool,
4, // max_client_batch_size: usize,
)
.await;
Ok(Python::with_gil(|py| py.None()))
})
fn rust_router(_py: Python<'_>) -> PyResult<String> {
let handle = thread::spawn(move || {
let rt = Runtime::new().unwrap();
rt.block_on(async { internal_main_args_router().await })
});
match handle.join() {
Ok(thread_output) => match thread_output {
Ok(_) => println!("Inner server exited successfully"),
Err(e) => println!("Inner server exited with error: {:?}", e),
},
Err(e) => {
println!("Server exited with error: {:?}", e);
}
}
Ok("Completed".to_string())
}
#[pyfunction]
fn rust_launcher_cli(_py: Python<'_>) -> PyResult<String> {
match internal_main_args_launcher() {
Ok(_) => println!("Server exited successfully"),
Err(e) => println!("Server exited with error: {:?}", e),
}
Ok("Completed".to_string())
}
#[pymodule]
fn tgi(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(rust_sleep, m)?)?;
m.add_function(wrap_pyfunction!(rust_server, m)?)?;
m.add_function(wrap_pyfunction!(rust_router, m)?)?;
m.add_function(wrap_pyfunction!(rust_launcher, m)?)?;
m.add_function(wrap_pyfunction!(fully_packaged, m)?)?;
m.add_function(wrap_pyfunction!(rust_launcher_cli, m)?)?;
Ok(())
}
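
Unlike `rust_launcher`, which bridges into Python's event loop via `pyo3_asyncio::tokio::future_into_py`, the new `rust_router` is a plain synchronous `#[pyfunction]`: it drives the async router entry point to completion on a dedicated thread with its own tokio `Runtime` (and `rust_launcher_cli` simply calls the synchronous launcher entry point). A stand-alone sketch of that thread-plus-runtime pattern, with `fake_router` standing in for `internal_main_args_router` and no pyo3 involved:

use std::thread;
use tokio::runtime::Runtime;

async fn fake_router() -> Result<(), String> {
    // stand-in for `internal_main_args_router().await`
    Ok(())
}

fn run_router_blocking() -> Result<(), String> {
    // Give the async entry point its own thread and runtime so the caller stays synchronous.
    let handle = thread::spawn(|| {
        let rt = Runtime::new().expect("failed to build tokio runtime");
        rt.block_on(fake_router())
    });
    handle.join().expect("router thread panicked")
}

fn main() {
    match run_router_blocking() {
        Ok(()) => println!("router exited successfully"),
        Err(e) => println!("router exited with error: {e:?}"),
    }
}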

View File

@@ -1,9 +1,8 @@
from .tgi import *
import threading
from tgi import rust_launcher, rust_sleep, fully_packaged
from tgi import rust_router, rust_launcher, rust_launcher_cli
import asyncio
from dataclasses import dataclass, asdict
import sys
from text_generation_server.cli import app
# add the rust_launcher coroutine to the __all__ list
@@ -17,6 +16,14 @@ def text_generation_server_cli_main():
app()
def text_generation_router_cli_main():
rust_router()
def text_generation_launcher_cli_main():
rust_launcher_cli()
@dataclass
class Args:
model_id = "google/gemma-2b-it"
@@ -81,7 +88,7 @@ class TGI(object):
print(args)
args = Args(**args)
try:
await fully_packaged(
await rust_launcher(
args.model_id,
args.revision,
args.validation_workers,