This commit is contained in:
OlivierDehaene 2024-06-26 14:12:01 +02:00
parent 93e0a7de8b
commit 230f2a415a
14 changed files with 430 additions and 424 deletions

View File

@ -1,21 +1,17 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic /// Batching and inference logic
use crate::queue::{Entry, Queue}; use crate::queue::{Entry, Queue};
use text_generation_router::infer::{ use async_trait::async_trait;
GeneratedText, InferError, InferStreamResponse, Backend, use nohash_hasher::IntMap;
}; use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest; use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, PrefillToken, Token}; use text_generation_router::{FinishReason, PrefillToken, Token};
use nohash_hasher::IntMap;
use std::sync::{
Arc,
};
use crate::client::{Batch, CachedBatch, Generation, ShardedClient, ClientError, Health};
use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify}; use tokio::sync::{mpsc, Notify};
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span}; use tracing::{info_span, instrument, Instrument, Span};
use async_trait::async_trait;
pub struct BackendV3 { pub struct BackendV3 {
/// Request queue /// Request queue
@ -94,9 +90,7 @@ impl Backend for BackendV3 {
self.batching_task_notifier.notify_one(); self.batching_task_notifier.notify_one();
// Return stream // Return stream
Ok( Ok(UnboundedReceiverStream::new(response_rx))
UnboundedReceiverStream::new(response_rx),
)
} }
async fn health(&self, current_health: bool) -> bool { async fn health(&self, current_health: bool) -> bool {
@ -193,8 +187,7 @@ pub(crate) async fn batching_task(
}); });
// Generate one token for this new batch to have the attention past in cache // Generate one token for this new batch to have the attention past in cache
let new_cached_batch = let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
prefill(&mut client, new_batch, &mut new_entries)
.instrument(span) .instrument(span)
.await; .await;
// Reset waiting counter // Reset waiting counter
@ -480,8 +473,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
impl From<crate::client::GeneratedText> for GeneratedText { impl From<crate::client::GeneratedText> for GeneratedText {
fn from(value: crate::client::GeneratedText) -> Self { fn from(value: crate::client::GeneratedText) -> Self {
let v3_finish_reason = let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
crate::client::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v3_finish_reason { let finish_reason = match v3_finish_reason {
crate::client::FinishReason::Length => FinishReason::Length, crate::client::FinishReason::Length => FinishReason::Length,
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,

View File

@ -1,5 +1,4 @@
/// Single shard Client /// Single shard Client
use crate::client::{pb, Chunk}; use crate::client::{pb, Chunk};
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64}; use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use base64::engine::general_purpose::STANDARD; use base64::engine::general_purpose::STANDARD;
@ -20,6 +19,7 @@ pub struct Client {
impl Client { impl Client {
/// Returns a client connected to the given url /// Returns a client connected to the given url
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> { pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?; let channel = Channel::builder(uri).connect().await?;

View File

@ -1,6 +1,5 @@
//! Text Generation gRPC client library //! Text Generation gRPC client library
use async_trait::async_trait; use async_trait::async_trait;
use thiserror::Error; use thiserror::Error;
use tonic::transport; use tonic::transport;

View File

@ -1,17 +1,17 @@
use crate::client::{ClientError, Result};
/// Multi shard Client /// Multi shard Client
use crate::client::{Health, ShardInfo}; use crate::client::{Health, ShardInfo};
use crate::client::{ClientError, Result};
use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
use crate::client::client::{DecodeTimings, PrefillTimings}; use crate::client::client::{DecodeTimings, PrefillTimings};
use crate::client::{ use crate::client::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
}; };
use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client /// Text Generation Inference gRPC multi client
@ -35,6 +35,7 @@ impl ShardedClient {
} }
/// Returns a client connected to the given uri /// Returns a client connected to the given uri
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> { pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?; let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await Self::from_master_client(master_client).await

View File

@ -1,13 +1,13 @@
mod block_allocator;
mod queue;
mod backend; mod backend;
mod block_allocator;
mod client; mod client;
mod queue;
use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3;
use serde::Serialize; use serde::Serialize;
use thiserror::Error; use thiserror::Error;
use utoipa::ToSchema; use utoipa::ToSchema;
pub(crate) use backend::BackendV3;
use crate::client::{ShardedClient, ClientError};
#[derive(Clone, Debug, Serialize, ToSchema)] #[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo { pub struct BackendInfo {
@ -31,7 +31,8 @@ pub struct BackendInfo {
} }
pub async fn connect_backend( pub async fn connect_backend(
max_input_tokens: usize, max_total_tokens: usize, max_input_tokens: usize,
max_total_tokens: usize,
master_shard_uds_path: String, master_shard_uds_path: String,
waiting_served_ratio: f32, waiting_served_ratio: f32,
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,

View File

@ -1,26 +1,7 @@
use axum::http::HeaderValue;
use clap::Parser; use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; use text_generation_router::server;
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
use thiserror::Error;
use tokenizers::Tokenizer;
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
use text_generation_router_v3::{connect_backend, V3Error}; use text_generation_router_v3::{connect_backend, V3Error};
use thiserror::Error;
/// App Configuration /// App Configuration
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -121,7 +102,7 @@ async fn main() -> Result<(), RouterError> {
} = args; } = args;
// Launch Tokio runtime // Launch Tokio runtime
init_logging(otlp_endpoint, otlp_service_name, json_output); text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args // Validate args
if max_input_tokens >= max_total_tokens { if max_input_tokens >= max_total_tokens {
@ -148,218 +129,37 @@ async fn main() -> Result<(), RouterError> {
} }
} }
// CORS allowed origins let (backend, _backend_info) = connect_backend(
// map to go inside the option and then map to parse from String to HeaderValue max_input_tokens,
// Finally, convert to AllowOrigin max_total_tokens,
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| { master_shard_uds_path,
AllowOrigin::list( waiting_served_ratio,
cors_allow_origin max_batch_prefill_tokens,
.iter() max_batch_total_tokens,
.map(|origin| origin.parse::<HeaderValue>().unwrap()), max_waiting_tokens,
max_batch_size,
) )
}); .await?;
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
// This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = Cache::default();
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
Type::None
};
// Load tokenizer and model info
let (
tokenizer_filename,
config_filename,
tokenizer_config_filename,
processor_config_filename,
model_info,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
processor_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
let tokenizer: Option<Tokenizer> =
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
});
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();
tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation is disabled");
}
// if pipeline-tag == text-generation we default to return_full_text = true
let compat_return_full_text = match &model_info.pipeline_tag {
None => {
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
true
}
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
};
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
std::env::var("AIP_HTTP_PORT")
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
.unwrap_or(port)
} else {
port
};
let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => {
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
}
};
let (backend, backend_info) = connect_backend(max_input_tokens, max_total_tokens, master_shard_uds_path, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size).await?;
// Run server // Run server
server::run( server::run(
backend, backend,
model_info,
compat_return_full_text,
max_concurrent_requests, max_concurrent_requests,
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens, max_top_n_tokens,
max_input_tokens, max_input_tokens,
max_total_tokens, max_total_tokens,
tokenizer,
config,
validation_workers, validation_workers,
addr, tokenizer_name,
tokenizer_config_path,
revision,
hostname,
port,
cors_allow_origin, cors_allow_origin,
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
tokenizer_config,
processor_config,
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
@ -368,140 +168,6 @@ async fn main() -> Result<(), RouterError> {
Ok(()) Ok(())
} }
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_ansi(ansi)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer.json().flatten_event(true).boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
otlp_service_name,
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
init_tracing_opentelemetry::init_propagator().unwrap();
};
}
// Filter events with LOG_LEVEL
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs to be spammed with tokio level informations
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}
/// get model info from the Huggingface Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
let response = api.info_request().send().await.ok()?;
if response.status().is_success() {
let hub_model_info: HubModelInfo =
serde_json::from_str(&response.text().await.ok()?).ok()?;
if let Some(sha) = &hub_model_info.sha {
tracing::info!(
"Serving revision {sha} of model {}",
hub_model_info.model_id
);
}
Some(hub_model_info)
} else {
None
}
}
/// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of `User`.
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
let api_base_repo = api.repo(Repo::with_revision(
base_model_id.to_string(),
RepoType::Model,
"main".to_string(),
));
api_base_repo.get("tokenizer.json").await.ok()
} else {
None
}
}
/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(tokenizer_config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
.map_err(|e| {
tracing::warn!("Unable to parse tokenizer config: {}", e);
e
})
.ok()?;
Some(tokenizer_config)
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
enum RouterError { enum RouterError {
#[error("Argument validation error: {0}")] #[error("Argument validation error: {0}")]

View File

@ -1,14 +1,17 @@
use crate::block_allocator::{BlockAllocation, BlockAllocator}; use crate::block_allocator::{BlockAllocation, BlockAllocator};
use text_generation_router::infer::InferError; use crate::client;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::collections::VecDeque;
use crate::client::{ use crate::client::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
}; };
use crate::client as client; use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
ValidStoppingParameters,
};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant; use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span}; use tracing::{info_span, instrument, Instrument, Span};
@ -335,10 +338,21 @@ impl State {
id, id,
prefill_logprobs: entry.request.decoder_input_details, prefill_logprobs: entry.request.decoder_input_details,
input_chunks: Some(client::Input { input_chunks: Some(client::Input {
chunks: entry.request.inputs.clone().into_iter().map(|c| client::InputChunk { chunk: Some(match c { chunks: entry
.request
.inputs
.clone()
.into_iter()
.map(|c| client::InputChunk {
chunk: Some(match c {
Chunk::Text(text) => client::Chunk::Text(text), Chunk::Text(text) => client::Chunk::Text(text),
Chunk::Image(image) => client::Chunk::Image(client::Image { data: image.data, mimetype: image.mimetype }) Chunk::Image(image) => client::Chunk::Image(client::Image {
})}).collect() data: image.data,
mimetype: image.mimetype,
}),
}),
})
.collect(),
}), }),
inputs: entry.request.inputs.chunks_to_string(), inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate, truncate: entry.request.truncate,

View File

@ -74,7 +74,6 @@ impl ChatTemplate {
} }
} }
// tests // tests
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {

View File

@ -3,23 +3,23 @@ mod chat_template;
pub mod tool_grammar; pub mod tool_grammar;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::GrammarType;
use crate::{ use crate::{
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
HubTokenizerConfig, Message, PrefillToken, Token, Message, PrefillToken, Token,
}; };
use crate::{GrammarType}; use async_trait::async_trait;
use chat_template::ChatTemplate;
use futures::future::try_join_all; use futures::future::try_join_all;
use minijinja::{ErrorKind}; use minijinja::ErrorKind;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::instrument; use tracing::instrument;
use chat_template::ChatTemplate;
use async_trait::async_trait;
#[async_trait] #[async_trait]
pub trait Backend { pub trait Backend {
@ -86,7 +86,8 @@ impl Infer {
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) async fn generate_stream<'a>( pub(crate) async fn generate_stream<'a>(
&'a self, &'a self,
request: GenerateRequest) -> Result<GenerateStreamResponse, InferError> { request: GenerateRequest,
) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore // Limit concurrent requests by acquiring a permit from the semaphore
let permit = self let permit = self
.clone() .clone()
@ -106,9 +107,7 @@ impl Infer {
})?; })?;
let input_length = valid_request.input_length; let input_length = valid_request.input_length;
let generation_stream = self let generation_stream = self.backend.schedule(valid_request)?;
.backend
.schedule(valid_request)?;
Ok((permit, input_length, generation_stream)) Ok((permit, input_length, generation_stream))
} }

View File

@ -1,7 +1,7 @@
/// Batching and inference logic /// Batching and inference logic
use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::v2::queue::{Entry, Queue};
use crate::infer::{ use crate::infer::{
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Backend, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
}; };
use crate::validation::ValidGenerateRequest; use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token}; use crate::{FinishReason, PrefillToken, Token};

View File

@ -6,6 +6,7 @@ pub mod validation;
#[cfg(feature = "kserve")] #[cfg(feature = "kserve")]
mod kserve; mod kserve;
pub mod logging;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::warn; use tracing::warn;
@ -158,7 +159,6 @@ pub struct Info {
#[schema(example = "32")] #[schema(example = "32")]
pub max_client_batch_size: usize, pub max_client_batch_size: usize,
/// Router Info /// Router Info
#[schema(example = "text-generation-router")] #[schema(example = "text-generation-router")]
pub router: &'static str, pub router: &'static str,

81
router/src/logging.rs Normal file
View File

@ -0,0 +1,81 @@
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
pub fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_ansi(ansi)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer.json().flatten_event(true).boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
otlp_service_name,
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
init_tracing_opentelemetry::init_propagator().unwrap();
};
}
// Filter events with LOG_LEVEL
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs to be spammed with tokio level informations
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}

View File

@ -1,7 +1,7 @@
/// HTTP Server logic /// HTTP Server logic
use crate::config::Config; use crate::config::Config;
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, Backend};
use crate::infer::tool_grammar::ToolGrammar; use crate::infer::tool_grammar::ToolGrammar;
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
#[cfg(feature = "kserve")] #[cfg(feature = "kserve")]
use crate::kserve::{ use crate::kserve::{
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
@ -23,7 +23,7 @@ use crate::{
use crate::{FunctionDefinition, ToolCall, ToolType}; use crate::{FunctionDefinition, ToolCall, ToolType};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use axum::routing::{get, post}; use axum::routing::{get, post};
@ -33,10 +33,15 @@ use futures::stream::StreamExt;
use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::Stream; use futures::Stream;
use futures::TryStreamExt; use futures::TryStreamExt;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value; use serde_json::Value;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use thiserror::Error; use thiserror::Error;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::select; use tokio::select;
@ -1395,24 +1400,22 @@ pub(crate) struct ComputeType(String);
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
backend: impl Backend + Send + Sync + 'static, backend: impl Backend + Send + Sync + 'static,
model_info: HubModelInfo,
compat_return_full_text: bool,
max_concurrent_requests: usize, max_concurrent_requests: usize,
max_best_of: usize, max_best_of: usize,
max_stop_sequences: usize, max_stop_sequences: usize,
max_top_n_tokens: u32, max_top_n_tokens: u32,
max_input_tokens: usize, max_input_tokens: usize,
max_total_tokens: usize, max_total_tokens: usize,
tokenizer: Option<Tokenizer>,
config: Option<Config>,
validation_workers: usize, validation_workers: usize,
addr: SocketAddr, tokenizer_name: String,
allow_origin: Option<AllowOrigin>, tokenizer_config_path: Option<String>,
revision: Option<String>,
hostname: String,
port: u16,
cors_allow_origin: Option<Vec<String>>,
ngrok: bool, ngrok: bool,
_ngrok_authtoken: Option<String>, _ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>, _ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig,
processor_config: HubProcessorConfig,
messages_api_enabled: bool, messages_api_enabled: bool,
grammar_support: bool, grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
@ -1485,6 +1488,195 @@ pub async fn run(
)] )]
struct ApiDoc; struct ApiDoc;
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
let allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
AllowOrigin::list(
cors_allow_origin
.iter()
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
)
});
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
// This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = Cache::default();
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
Type::None
};
// Load tokenizer and model info
let (
tokenizer_filename,
config_filename,
tokenizer_config_filename,
processor_config_filename,
model_info,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
processor_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
let tokenizer: Option<Tokenizer> =
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
});
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();
tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation is disabled");
}
// if pipeline-tag == text-generation we default to return_full_text = true
let compat_return_full_text = match &model_info.pipeline_tag {
None => {
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
true
}
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
};
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
std::env::var("AIP_HTTP_PORT")
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
.unwrap_or(port)
} else {
port
};
let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => {
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
}
};
// Create state // Create state
let validation = Validation::new( let validation = Validation::new(
validation_workers, validation_workers,
@ -1750,6 +1942,68 @@ pub async fn run(
Ok(()) Ok(())
} }
/// get model info from the Huggingface Hub
pub async fn get_hub_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
let response = api.info_request().send().await.ok()?;
if response.status().is_success() {
let hub_model_info: HubModelInfo =
serde_json::from_str(&response.text().await.ok()?).ok()?;
if let Some(sha) = &hub_model_info.sha {
tracing::info!(
"Serving revision {sha} of model {}",
hub_model_info.model_id
);
}
Some(hub_model_info)
} else {
None
}
}
/// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of `User`.
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
let api_base_repo = api.repo(Repo::with_revision(
base_model_id.to_string(),
RepoType::Model,
"main".to_string(),
));
api_base_repo.get("tokenizer.json").await.ok()
} else {
None
}
}
/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(tokenizer_config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
.map_err(|e| {
tracing::warn!("Unable to parse tokenizer config: {}", e);
e
})
.ok()?;
Some(tokenizer_config)
}
/// Shutdown signal handler /// Shutdown signal handler
async fn shutdown_signal() { async fn shutdown_signal() {
let ctrl_c = async { let ctrl_c = async {

View File

@ -638,7 +638,7 @@ pub struct Image {
#[derive(Debug, Clone, Eq, PartialEq)] #[derive(Debug, Clone, Eq, PartialEq)]
pub enum Chunk { pub enum Chunk {
Text(String), Text(String),
Image(Image) Image(Image),
} }
/// Convert input chunks to a stringly-typed input for backwards /// Convert input chunks to a stringly-typed input for backwards