parent
91d7267534
commit
fbeb1c4475
|
@ -950,13 +950,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
|
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"dirs 5.0.1",
|
"dirs 5.0.1",
|
||||||
|
"futures",
|
||||||
"indicatif",
|
"indicatif",
|
||||||
"log",
|
"log",
|
||||||
"native-tls",
|
"native-tls",
|
||||||
|
"num_cpus",
|
||||||
"rand",
|
"rand",
|
||||||
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
"ureq",
|
"ureq",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ axum-tracing-opentelemetry = "0.14.1"
|
||||||
text-generation-client = { path = "client" }
|
text-generation-client = { path = "client" }
|
||||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||||
futures = "0.3.28"
|
futures = "0.3.28"
|
||||||
|
hf-hub = { version = "0.3.0", features = ["tokio"] }
|
||||||
metrics = "0.21.1"
|
metrics = "0.21.1"
|
||||||
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
|
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
|
||||||
nohash-hasher = "0.2.0"
|
nohash-hasher = "0.2.0"
|
||||||
|
@ -41,7 +42,6 @@ tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||||
utoipa = { version = "3.5.0", features = ["axum_extras"] }
|
utoipa = { version = "3.5.0", features = ["axum_extras"] }
|
||||||
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
|
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
|
||||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||||
hf-hub = "0.3.1"
|
|
||||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
|
|
@ -1,19 +1,22 @@
|
||||||
/// Text Generation Inference webserver entrypoint
|
|
||||||
use axum::http::HeaderValue;
|
use axum::http::HeaderValue;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
||||||
|
use hf_hub::{Repo, RepoType};
|
||||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||||
use opentelemetry::sdk::trace;
|
use opentelemetry::sdk::trace;
|
||||||
use opentelemetry::sdk::trace::Sampler;
|
use opentelemetry::sdk::trace::Sampler;
|
||||||
use opentelemetry::sdk::Resource;
|
use opentelemetry::sdk::Resource;
|
||||||
use opentelemetry::{global, KeyValue};
|
use opentelemetry::{global, KeyValue};
|
||||||
use opentelemetry_otlp::WithExportConfig;
|
use opentelemetry_otlp::WithExportConfig;
|
||||||
|
/// Text Generation Inference webserver entrypoint
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::BufReader;
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::time::Duration;
|
|
||||||
use text_generation_client::{ClientError, ShardedClient};
|
use text_generation_client::{ClientError, ShardedClient};
|
||||||
use text_generation_router::{server, HubModelInfo};
|
use text_generation_router::{server, HubModelInfo};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
use tokenizers::Tokenizer;
|
||||||
use tower_http::cors::AllowOrigin;
|
use tower_http::cors::AllowOrigin;
|
||||||
use tracing_subscriber::layer::SubscriberExt;
|
use tracing_subscriber::layer::SubscriberExt;
|
||||||
use tracing_subscriber::util::SubscriberInitExt;
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
|
@ -69,7 +72,8 @@ struct Args {
|
||||||
ngrok_edge: Option<String>,
|
ngrok_edge: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), RouterError> {
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), RouterError> {
|
||||||
// Get args
|
// Get args
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
// Pattern match configuration
|
// Pattern match configuration
|
||||||
|
@ -98,6 +102,9 @@ fn main() -> Result<(), RouterError> {
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
|
// Launch Tokio runtime
|
||||||
|
init_logging(otlp_endpoint, json_output);
|
||||||
|
|
||||||
// Validate args
|
// Validate args
|
||||||
if max_input_length >= max_total_tokens {
|
if max_input_length >= max_total_tokens {
|
||||||
return Err(RouterError::ArgumentValidation(
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
@ -141,146 +148,158 @@ fn main() -> Result<(), RouterError> {
|
||||||
// This will only be used to validate payloads
|
// This will only be used to validate payloads
|
||||||
let local_path = Path::new(&tokenizer_name);
|
let local_path = Path::new(&tokenizer_name);
|
||||||
let local_model = local_path.exists() && local_path.is_dir();
|
let local_model = local_path.exists() && local_path.is_dir();
|
||||||
let tokenizer = if local_model {
|
|
||||||
// Load local tokenizer
|
let (tokenizer, model_info) = if local_model {
|
||||||
Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
|
// Get Model info
|
||||||
} else {
|
let model_info = HubModelInfo {
|
||||||
// Download and instantiate tokenizer
|
model_id: tokenizer_name.clone(),
|
||||||
// We need to download it outside of the Tokio runtime
|
sha: None,
|
||||||
let params = FromPretrainedParameters {
|
pipeline_tag: None,
|
||||||
revision: revision.clone().unwrap_or("main".to_string()),
|
|
||||||
auth_token: authorization_token.clone(),
|
|
||||||
..Default::default()
|
|
||||||
};
|
};
|
||||||
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
|
|
||||||
|
// Load local tokenizer
|
||||||
|
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
||||||
|
|
||||||
|
(tokenizer, model_info)
|
||||||
|
} else {
|
||||||
|
let mut builder = ApiBuilder::new()
|
||||||
|
.with_progress(false)
|
||||||
|
.with_token(authorization_token);
|
||||||
|
|
||||||
|
if let Some(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE").ok() {
|
||||||
|
builder = builder.with_cache_dir(cache_dir.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
if revision.is_none() {
|
||||||
|
tracing::warn!("`--revision` is not set");
|
||||||
|
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
||||||
|
}
|
||||||
|
|
||||||
|
let api = builder.build().unwrap();
|
||||||
|
let api_repo = api.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.clone(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.clone().unwrap_or("main".to_string()),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Get Model info
|
||||||
|
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
|
||||||
|
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||||
|
HubModelInfo {
|
||||||
|
model_id: tokenizer_name.to_string(),
|
||||||
|
sha: None,
|
||||||
|
pipeline_tag: None,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let tokenizer = match api_repo.get("tokenizer.json").await {
|
||||||
|
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
|
||||||
|
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||||
|
};
|
||||||
|
|
||||||
|
(tokenizer, model_info)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Launch Tokio runtime
|
if tokenizer.is_none() {
|
||||||
tokio::runtime::Builder::new_multi_thread()
|
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
||||||
.enable_all()
|
tracing::warn!("Rust input length validation and truncation is disabled");
|
||||||
.build()?
|
}
|
||||||
.block_on(async {
|
|
||||||
init_logging(otlp_endpoint, json_output);
|
|
||||||
|
|
||||||
if tokenizer.is_none() {
|
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||||
|
let compat_return_full_text = match &model_info.pipeline_tag {
|
||||||
|
None => {
|
||||||
|
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
||||||
|
};
|
||||||
|
|
||||||
|
// Instantiate sharded client from the master unix socket
|
||||||
|
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||||
|
.await
|
||||||
|
.map_err(RouterError::Connection)?;
|
||||||
|
// Clear the cache; useful if the webserver rebooted
|
||||||
|
sharded_client
|
||||||
|
.clear_cache(None)
|
||||||
|
.await
|
||||||
|
.map_err(RouterError::Cache)?;
|
||||||
|
// Get info from the shard
|
||||||
|
let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
|
||||||
|
|
||||||
|
// Warmup model
|
||||||
|
tracing::info!("Warming up model");
|
||||||
|
let max_supported_batch_total_tokens = match sharded_client
|
||||||
|
.warmup(
|
||||||
|
max_input_length as u32,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_total_tokens as u32,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(RouterError::Warmup)?
|
||||||
|
{
|
||||||
|
// Older models do not support automatic max-batch-total-tokens
|
||||||
|
None => {
|
||||||
|
let max_batch_total_tokens = max_batch_total_tokens
|
||||||
|
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
||||||
|
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||||
|
max_batch_total_tokens
|
||||||
|
}
|
||||||
|
// Flash attention models return their max supported total tokens
|
||||||
|
Some(max_supported_batch_total_tokens) => {
|
||||||
|
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||||
|
if max_batch_total_tokens.is_some() {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
"Could not find a fast tokenizer implementation for {tokenizer_name}"
|
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||||
|
Attention models."
|
||||||
);
|
);
|
||||||
tracing::warn!("Rust input length validation and truncation is disabled");
|
tracing::warn!(
|
||||||
|
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get Model info
|
max_supported_batch_total_tokens
|
||||||
let model_info = match local_model {
|
}
|
||||||
true => HubModelInfo {
|
};
|
||||||
model_id: tokenizer_name.clone(),
|
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
|
||||||
sha: None,
|
tracing::info!("Connected");
|
||||||
pipeline_tag: None,
|
|
||||||
},
|
|
||||||
false => get_model_info(&tokenizer_name, revision, authorization_token)
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|| {
|
|
||||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
|
||||||
HubModelInfo {
|
|
||||||
model_id: tokenizer_name.to_string(),
|
|
||||||
sha: None,
|
|
||||||
pipeline_tag: None,
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
let addr = match hostname.parse() {
|
||||||
let compat_return_full_text = match &model_info.pipeline_tag {
|
Ok(ip) => SocketAddr::new(ip, port),
|
||||||
None => {
|
Err(_) => {
|
||||||
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
|
||||||
false
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
|
||||||
}
|
}
|
||||||
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
};
|
||||||
};
|
|
||||||
|
|
||||||
// Instantiate sharded client from the master unix socket
|
// Run server
|
||||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
server::run(
|
||||||
.await
|
model_info,
|
||||||
.map_err(RouterError::Connection)?;
|
shard_info,
|
||||||
// Clear the cache; useful if the webserver rebooted
|
compat_return_full_text,
|
||||||
sharded_client
|
max_concurrent_requests,
|
||||||
.clear_cache(None)
|
max_best_of,
|
||||||
.await
|
max_stop_sequences,
|
||||||
.map_err(RouterError::Cache)?;
|
max_top_n_tokens,
|
||||||
// Get info from the shard
|
max_input_length,
|
||||||
let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
|
max_total_tokens,
|
||||||
|
waiting_served_ratio,
|
||||||
// Warmup model
|
max_batch_prefill_tokens,
|
||||||
tracing::info!("Warming up model");
|
max_supported_batch_total_tokens,
|
||||||
let max_supported_batch_total_tokens = match sharded_client
|
max_waiting_tokens,
|
||||||
.warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
|
sharded_client,
|
||||||
.await
|
tokenizer,
|
||||||
.map_err(RouterError::Warmup)?
|
validation_workers,
|
||||||
{
|
addr,
|
||||||
// Older models do not support automatic max-batch-total-tokens
|
cors_allow_origin,
|
||||||
None => {
|
ngrok,
|
||||||
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
|
ngrok_authtoken,
|
||||||
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
|
ngrok_edge,
|
||||||
);
|
)
|
||||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
.await?;
|
||||||
max_batch_total_tokens
|
Ok(())
|
||||||
}
|
|
||||||
// Flash attention models return their max supported total tokens
|
|
||||||
Some(max_supported_batch_total_tokens) => {
|
|
||||||
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
|
||||||
if max_batch_total_tokens.is_some() {
|
|
||||||
tracing::warn!(
|
|
||||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
|
||||||
Attention models."
|
|
||||||
);
|
|
||||||
tracing::warn!(
|
|
||||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
|
|
||||||
max_supported_batch_total_tokens
|
|
||||||
}
|
|
||||||
};
|
|
||||||
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
|
|
||||||
tracing::info!("Connected");
|
|
||||||
|
|
||||||
let addr = match hostname.parse() {
|
|
||||||
Ok(ip) => SocketAddr::new(ip, port),
|
|
||||||
Err(_) => {
|
|
||||||
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
|
|
||||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Run server
|
|
||||||
server::run(
|
|
||||||
model_info,
|
|
||||||
shard_info,
|
|
||||||
compat_return_full_text,
|
|
||||||
max_concurrent_requests,
|
|
||||||
max_best_of,
|
|
||||||
max_stop_sequences,
|
|
||||||
max_top_n_tokens,
|
|
||||||
max_input_length,
|
|
||||||
max_total_tokens,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_supported_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
sharded_client,
|
|
||||||
tokenizer,
|
|
||||||
validation_workers,
|
|
||||||
addr,
|
|
||||||
cors_allow_origin,
|
|
||||||
ngrok,
|
|
||||||
ngrok_authtoken,
|
|
||||||
ngrok_edge,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
||||||
|
@ -339,30 +358,8 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// get model info from the Huggingface Hub
|
/// get model info from the Huggingface Hub
|
||||||
pub async fn get_model_info(
|
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
|
||||||
model_id: &str,
|
let response = api.info_request().send().await.ok()?;
|
||||||
revision: Option<String>,
|
|
||||||
token: Option<String>,
|
|
||||||
) -> Option<HubModelInfo> {
|
|
||||||
let revision = match revision {
|
|
||||||
None => {
|
|
||||||
tracing::warn!("`--revision` is not set");
|
|
||||||
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
|
||||||
"main".to_string()
|
|
||||||
}
|
|
||||||
Some(revision) => revision,
|
|
||||||
};
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
// Poor man's urlencode
|
|
||||||
let revision = revision.replace('/', "%2F");
|
|
||||||
let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
|
|
||||||
let mut builder = client.get(url).timeout(Duration::from_secs(5));
|
|
||||||
if let Some(token) = token {
|
|
||||||
builder = builder.bearer_auth(token);
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = builder.send().await.ok()?;
|
|
||||||
|
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
let hub_model_info: HubModelInfo =
|
let hub_model_info: HubModelInfo =
|
||||||
|
@ -379,6 +376,31 @@ pub async fn get_model_info(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// get base tokenizer
|
||||||
|
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
|
||||||
|
let config_filename = api_repo.get("config.json").await.ok()?;
|
||||||
|
|
||||||
|
// Open the file in read-only mode with buffer.
|
||||||
|
let file = File::open(config_filename).ok()?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
|
||||||
|
// Read the JSON contents of the file as an instance of `User`.
|
||||||
|
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
|
||||||
|
|
||||||
|
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
|
||||||
|
let api_base_repo = api.repo(Repo::with_revision(
|
||||||
|
base_model_id.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
"main".to_string(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
|
||||||
|
Tokenizer::from_file(tokenizer_filename).ok()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
enum RouterError {
|
enum RouterError {
|
||||||
#[error("Argument validation error: {0}")]
|
#[error("Argument validation error: {0}")]
|
||||||
|
|
Loading…
Reference in New Issue