refacto
This commit is contained in:
parent
93e0a7de8b
commit
230f2a415a
|
@ -1,21 +1,17 @@
|
||||||
|
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
use crate::queue::{Entry, Queue};
|
use crate::queue::{Entry, Queue};
|
||||||
use text_generation_router::infer::{
|
use async_trait::async_trait;
|
||||||
GeneratedText, InferError, InferStreamResponse, Backend,
|
use nohash_hasher::IntMap;
|
||||||
};
|
use std::sync::Arc;
|
||||||
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
use text_generation_router::validation::ValidGenerateRequest;
|
use text_generation_router::validation::ValidGenerateRequest;
|
||||||
use text_generation_router::{FinishReason, PrefillToken, Token};
|
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||||
use nohash_hasher::IntMap;
|
|
||||||
use std::sync::{
|
|
||||||
Arc,
|
|
||||||
};
|
|
||||||
use crate::client::{Batch, CachedBatch, Generation, ShardedClient, ClientError, Health};
|
|
||||||
use tokio::sync::mpsc::error::SendError;
|
use tokio::sync::mpsc::error::SendError;
|
||||||
use tokio::sync::{mpsc, Notify};
|
use tokio::sync::{mpsc, Notify};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{info_span, instrument, Instrument, Span};
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
pub struct BackendV3 {
|
pub struct BackendV3 {
|
||||||
/// Request queue
|
/// Request queue
|
||||||
|
@ -94,9 +90,7 @@ impl Backend for BackendV3 {
|
||||||
self.batching_task_notifier.notify_one();
|
self.batching_task_notifier.notify_one();
|
||||||
|
|
||||||
// Return stream
|
// Return stream
|
||||||
Ok(
|
Ok(UnboundedReceiverStream::new(response_rx))
|
||||||
UnboundedReceiverStream::new(response_rx),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health(&self, current_health: bool) -> bool {
|
async fn health(&self, current_health: bool) -> bool {
|
||||||
|
@ -193,10 +187,9 @@ pub(crate) async fn batching_task(
|
||||||
});
|
});
|
||||||
|
|
||||||
// Generate one token for this new batch to have the attention past in cache
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
let new_cached_batch =
|
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||||
prefill(&mut client, new_batch, &mut new_entries)
|
.instrument(span)
|
||||||
.instrument(span)
|
.await;
|
||||||
.await;
|
|
||||||
// Reset waiting counter
|
// Reset waiting counter
|
||||||
waiting_tokens = 1;
|
waiting_tokens = 1;
|
||||||
// Extend current batch with the new batch
|
// Extend current batch with the new batch
|
||||||
|
@ -480,8 +473,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||||
|
|
||||||
impl From<crate::client::GeneratedText> for GeneratedText {
|
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||||
fn from(value: crate::client::GeneratedText) -> Self {
|
fn from(value: crate::client::GeneratedText) -> Self {
|
||||||
let v3_finish_reason =
|
let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||||
crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
|
||||||
let finish_reason = match v3_finish_reason {
|
let finish_reason = match v3_finish_reason {
|
||||||
crate::client::FinishReason::Length => FinishReason::Length,
|
crate::client::FinishReason::Length => FinishReason::Length,
|
||||||
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
/// Single shard Client
|
/// Single shard Client
|
||||||
|
|
||||||
use crate::client::{pb, Chunk};
|
use crate::client::{pb, Chunk};
|
||||||
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||||
use base64::engine::general_purpose::STANDARD;
|
use base64::engine::general_purpose::STANDARD;
|
||||||
|
@ -20,6 +19,7 @@ pub struct Client {
|
||||||
|
|
||||||
impl Client {
|
impl Client {
|
||||||
/// Returns a client connected to the given url
|
/// Returns a client connected to the given url
|
||||||
|
#[allow(dead_code)]
|
||||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||||
let channel = Channel::builder(uri).connect().await?;
|
let channel = Channel::builder(uri).connect().await?;
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
//! Text Generation gRPC client library
|
//! Text Generation gRPC client library
|
||||||
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tonic::transport;
|
use tonic::transport;
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
|
use crate::client::{ClientError, Result};
|
||||||
/// Multi shard Client
|
/// Multi shard Client
|
||||||
use crate::client::{Health, ShardInfo};
|
use crate::client::{Health, ShardInfo};
|
||||||
use crate::client::{ClientError, Result};
|
|
||||||
|
|
||||||
use crate::client::{Chunk, InfoResponse, Input};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use futures::future::join_all;
|
|
||||||
use tonic::transport::Uri;
|
|
||||||
use tracing::instrument;
|
|
||||||
use crate::client::client::{DecodeTimings, PrefillTimings};
|
use crate::client::client::{DecodeTimings, PrefillTimings};
|
||||||
use crate::client::{
|
use crate::client::{
|
||||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
|
use crate::client::{Chunk, InfoResponse, Input};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::future::join_all;
|
||||||
|
use tonic::transport::Uri;
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
/// Text Generation Inference gRPC multi client
|
/// Text Generation Inference gRPC multi client
|
||||||
|
@ -35,6 +35,7 @@ impl ShardedClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a client connected to the given uri
|
/// Returns a client connected to the given uri
|
||||||
|
#[allow(dead_code)]
|
||||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||||
let master_client = Client::connect(uri).await?;
|
let master_client = Client::connect(uri).await?;
|
||||||
Self::from_master_client(master_client).await
|
Self::from_master_client(master_client).await
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
mod block_allocator;
|
|
||||||
mod queue;
|
|
||||||
mod backend;
|
mod backend;
|
||||||
|
mod block_allocator;
|
||||||
mod client;
|
mod client;
|
||||||
|
mod queue;
|
||||||
|
|
||||||
|
use crate::client::{ClientError, ShardedClient};
|
||||||
|
pub(crate) use backend::BackendV3;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
pub(crate) use backend::BackendV3;
|
|
||||||
use crate::client::{ShardedClient, ClientError};
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
pub struct BackendInfo {
|
pub struct BackendInfo {
|
||||||
|
@ -31,7 +31,8 @@ pub struct BackendInfo {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn connect_backend(
|
pub async fn connect_backend(
|
||||||
max_input_tokens: usize, max_total_tokens: usize,
|
max_input_tokens: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
master_shard_uds_path: String,
|
master_shard_uds_path: String,
|
||||||
waiting_served_ratio: f32,
|
waiting_served_ratio: f32,
|
||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
|
@ -137,4 +138,4 @@ pub enum V3Error {
|
||||||
Warmup(ClientError),
|
Warmup(ClientError),
|
||||||
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||||
NotEnoughMemory(usize),
|
NotEnoughMemory(usize),
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,26 +1,7 @@
|
||||||
use axum::http::HeaderValue;
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
use text_generation_router::server;
|
||||||
use hf_hub::{Cache, Repo, RepoType};
|
|
||||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
|
||||||
use opentelemetry::sdk::trace;
|
|
||||||
use opentelemetry::sdk::trace::Sampler;
|
|
||||||
use opentelemetry::sdk::Resource;
|
|
||||||
use opentelemetry::{global, KeyValue};
|
|
||||||
use opentelemetry_otlp::WithExportConfig;
|
|
||||||
use std::fs::File;
|
|
||||||
use std::io::BufReader;
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use text_generation_router::config::Config;
|
|
||||||
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
use tower_http::cors::AllowOrigin;
|
|
||||||
use tracing_subscriber::layer::SubscriberExt;
|
|
||||||
use tracing_subscriber::util::SubscriberInitExt;
|
|
||||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
|
||||||
use text_generation_router_v3::{connect_backend, V3Error};
|
use text_generation_router_v3::{connect_backend, V3Error};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
/// App Configuration
|
/// App Configuration
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
|
@ -121,7 +102,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Launch Tokio runtime
|
// Launch Tokio runtime
|
||||||
init_logging(otlp_endpoint, otlp_service_name, json_output);
|
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||||
|
|
||||||
// Validate args
|
// Validate args
|
||||||
if max_input_tokens >= max_total_tokens {
|
if max_input_tokens >= max_total_tokens {
|
||||||
|
@ -148,218 +129,37 @@ async fn main() -> Result<(), RouterError> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CORS allowed origins
|
let (backend, _backend_info) = connect_backend(
|
||||||
// map to go inside the option and then map to parse from String to HeaderValue
|
max_input_tokens,
|
||||||
// Finally, convert to AllowOrigin
|
max_total_tokens,
|
||||||
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
|
master_shard_uds_path,
|
||||||
AllowOrigin::list(
|
waiting_served_ratio,
|
||||||
cors_allow_origin
|
max_batch_prefill_tokens,
|
||||||
.iter()
|
max_batch_total_tokens,
|
||||||
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
|
max_waiting_tokens,
|
||||||
)
|
max_batch_size,
|
||||||
});
|
)
|
||||||
|
.await?;
|
||||||
// Parse Huggingface hub token
|
|
||||||
let authorization_token = std::env::var("HF_TOKEN")
|
|
||||||
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
|
|
||||||
.ok();
|
|
||||||
|
|
||||||
// Tokenizer instance
|
|
||||||
// This will only be used to validate payloads
|
|
||||||
let local_path = Path::new(&tokenizer_name);
|
|
||||||
|
|
||||||
// Shared API builder initialization
|
|
||||||
let api_builder = || {
|
|
||||||
let mut builder = ApiBuilder::new()
|
|
||||||
.with_progress(false)
|
|
||||||
.with_token(authorization_token);
|
|
||||||
|
|
||||||
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
|
|
||||||
builder = builder.with_cache_dir(cache_dir.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
builder
|
|
||||||
};
|
|
||||||
|
|
||||||
// Decide if we need to use the API based on the revision and local path
|
|
||||||
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
|
||||||
|
|
||||||
// Initialize API if needed
|
|
||||||
#[derive(Clone)]
|
|
||||||
enum Type {
|
|
||||||
Api(Api),
|
|
||||||
Cache(Cache),
|
|
||||||
None,
|
|
||||||
}
|
|
||||||
let api = if use_api {
|
|
||||||
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
|
||||||
let cache = Cache::default();
|
|
||||||
tracing::warn!("Offline mode active using cache defaults");
|
|
||||||
Type::Cache(cache)
|
|
||||||
} else {
|
|
||||||
tracing::info!("Using the Hugging Face API");
|
|
||||||
match api_builder().build() {
|
|
||||||
Ok(api) => Type::Api(api),
|
|
||||||
Err(_) => {
|
|
||||||
tracing::warn!("Unable to build the Hugging Face API");
|
|
||||||
Type::None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Type::None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Load tokenizer and model info
|
|
||||||
let (
|
|
||||||
tokenizer_filename,
|
|
||||||
config_filename,
|
|
||||||
tokenizer_config_filename,
|
|
||||||
processor_config_filename,
|
|
||||||
model_info,
|
|
||||||
) = match api {
|
|
||||||
Type::None => (
|
|
||||||
Some(local_path.join("tokenizer.json")),
|
|
||||||
Some(local_path.join("config.json")),
|
|
||||||
Some(local_path.join("tokenizer_config.json")),
|
|
||||||
Some(local_path.join("processor_config.json")),
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
Type::Api(api) => {
|
|
||||||
let api_repo = api.repo(Repo::with_revision(
|
|
||||||
tokenizer_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
|
||||||
));
|
|
||||||
|
|
||||||
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
|
||||||
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
|
||||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
|
||||||
};
|
|
||||||
let config_filename = api_repo.get("config.json").await.ok();
|
|
||||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
|
||||||
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
|
||||||
|
|
||||||
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
|
|
||||||
Some(model_info)
|
|
||||||
} else {
|
|
||||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
|
||||||
None
|
|
||||||
};
|
|
||||||
(
|
|
||||||
tokenizer_filename,
|
|
||||||
config_filename,
|
|
||||||
tokenizer_config_filename,
|
|
||||||
processor_config_filename,
|
|
||||||
model_info,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
Type::Cache(cache) => {
|
|
||||||
let repo = cache.repo(Repo::with_revision(
|
|
||||||
tokenizer_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
|
||||||
));
|
|
||||||
(
|
|
||||||
repo.get("tokenizer.json"),
|
|
||||||
repo.get("config.json"),
|
|
||||||
repo.get("tokenizer_config.json"),
|
|
||||||
repo.get("processor_config.json"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let tokenizer: Option<Tokenizer> =
|
|
||||||
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
|
|
||||||
let config: Option<Config> = config_filename.and_then(|filename| {
|
|
||||||
std::fs::read_to_string(filename)
|
|
||||||
.ok()
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|c| {
|
|
||||||
let config: Result<Config, _> = serde_json::from_str(c);
|
|
||||||
if let Err(err) = &config {
|
|
||||||
tracing::warn!("Could not parse config {err:?}");
|
|
||||||
}
|
|
||||||
config.ok()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
|
|
||||||
model_id: tokenizer_name.to_string(),
|
|
||||||
sha: None,
|
|
||||||
pipeline_tag: None,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
|
||||||
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
|
||||||
{
|
|
||||||
HubTokenizerConfig::from_file(filename)
|
|
||||||
} else {
|
|
||||||
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
|
||||||
};
|
|
||||||
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
|
||||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
|
||||||
HubTokenizerConfig::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
let processor_config = processor_config_filename
|
|
||||||
.and_then(HubProcessorConfig::from_file)
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
tracing::info!("Using config {config:?}");
|
|
||||||
if tokenizer.is_none() {
|
|
||||||
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
|
||||||
tracing::warn!("Rust input length validation and truncation is disabled");
|
|
||||||
}
|
|
||||||
|
|
||||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
|
||||||
let compat_return_full_text = match &model_info.pipeline_tag {
|
|
||||||
None => {
|
|
||||||
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
|
||||||
true
|
|
||||||
}
|
|
||||||
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
|
||||||
};
|
|
||||||
|
|
||||||
// Determine the server port based on the feature and environment variable.
|
|
||||||
let port = if cfg!(feature = "google") {
|
|
||||||
std::env::var("AIP_HTTP_PORT")
|
|
||||||
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
|
|
||||||
.unwrap_or(port)
|
|
||||||
} else {
|
|
||||||
port
|
|
||||||
};
|
|
||||||
|
|
||||||
let addr = match hostname.parse() {
|
|
||||||
Ok(ip) => SocketAddr::new(ip, port),
|
|
||||||
Err(_) => {
|
|
||||||
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
|
|
||||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let (backend, backend_info) = connect_backend(max_input_tokens, max_total_tokens, master_shard_uds_path, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size).await?;
|
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
server::run(
|
server::run(
|
||||||
backend,
|
backend,
|
||||||
model_info,
|
|
||||||
compat_return_full_text,
|
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
max_best_of,
|
max_best_of,
|
||||||
max_stop_sequences,
|
max_stop_sequences,
|
||||||
max_top_n_tokens,
|
max_top_n_tokens,
|
||||||
max_input_tokens,
|
max_input_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
tokenizer,
|
|
||||||
config,
|
|
||||||
validation_workers,
|
validation_workers,
|
||||||
addr,
|
tokenizer_name,
|
||||||
|
tokenizer_config_path,
|
||||||
|
revision,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
ngrok,
|
ngrok,
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
tokenizer_config,
|
|
||||||
processor_config,
|
|
||||||
messages_api_enabled,
|
messages_api_enabled,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
|
@ -368,140 +168,6 @@ async fn main() -> Result<(), RouterError> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
|
||||||
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
|
||||||
/// - otlp_service_name service name to appear in APM
|
|
||||||
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
|
||||||
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
|
||||||
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
|
|
||||||
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
|
|
||||||
let mut layers = Vec::new();
|
|
||||||
|
|
||||||
// STDOUT/STDERR layer
|
|
||||||
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
|
|
||||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
|
||||||
.with_file(true)
|
|
||||||
.with_ansi(ansi)
|
|
||||||
.with_line_number(true);
|
|
||||||
|
|
||||||
let fmt_layer = match json_output {
|
|
||||||
true => fmt_layer.json().flatten_event(true).boxed(),
|
|
||||||
false => fmt_layer.boxed(),
|
|
||||||
};
|
|
||||||
layers.push(fmt_layer);
|
|
||||||
|
|
||||||
// OpenTelemetry tracing layer
|
|
||||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
|
||||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
|
||||||
|
|
||||||
let tracer = opentelemetry_otlp::new_pipeline()
|
|
||||||
.tracing()
|
|
||||||
.with_exporter(
|
|
||||||
opentelemetry_otlp::new_exporter()
|
|
||||||
.tonic()
|
|
||||||
.with_endpoint(otlp_endpoint),
|
|
||||||
)
|
|
||||||
.with_trace_config(
|
|
||||||
trace::config()
|
|
||||||
.with_resource(Resource::new(vec![KeyValue::new(
|
|
||||||
"service.name",
|
|
||||||
otlp_service_name,
|
|
||||||
)]))
|
|
||||||
.with_sampler(Sampler::AlwaysOn),
|
|
||||||
)
|
|
||||||
.install_batch(opentelemetry::runtime::Tokio);
|
|
||||||
|
|
||||||
if let Ok(tracer) = tracer {
|
|
||||||
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
|
||||||
init_tracing_opentelemetry::init_propagator().unwrap();
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter events with LOG_LEVEL
|
|
||||||
let varname = "LOG_LEVEL";
|
|
||||||
let env_filter = if let Ok(log_level) = std::env::var(varname) {
|
|
||||||
// Override to avoid simple logs to be spammed with tokio level informations
|
|
||||||
let log_level = match &log_level[..] {
|
|
||||||
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
|
|
||||||
"info" => "text_generation_launcher=info,text_generation_router=info",
|
|
||||||
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
|
|
||||||
log_level => log_level,
|
|
||||||
};
|
|
||||||
EnvFilter::builder()
|
|
||||||
.with_default_directive(LevelFilter::INFO.into())
|
|
||||||
.parse_lossy(log_level)
|
|
||||||
} else {
|
|
||||||
EnvFilter::new("info")
|
|
||||||
};
|
|
||||||
|
|
||||||
tracing_subscriber::registry()
|
|
||||||
.with(env_filter)
|
|
||||||
.with(layers)
|
|
||||||
.init();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get model info from the Huggingface Hub
|
|
||||||
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
|
|
||||||
let response = api.info_request().send().await.ok()?;
|
|
||||||
|
|
||||||
if response.status().is_success() {
|
|
||||||
let hub_model_info: HubModelInfo =
|
|
||||||
serde_json::from_str(&response.text().await.ok()?).ok()?;
|
|
||||||
if let Some(sha) = &hub_model_info.sha {
|
|
||||||
tracing::info!(
|
|
||||||
"Serving revision {sha} of model {}",
|
|
||||||
hub_model_info.model_id
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Some(hub_model_info)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get base tokenizer
|
|
||||||
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
|
|
||||||
let config_filename = api_repo.get("config.json").await.ok()?;
|
|
||||||
|
|
||||||
// Open the file in read-only mode with buffer.
|
|
||||||
let file = File::open(config_filename).ok()?;
|
|
||||||
let reader = BufReader::new(file);
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of `User`.
|
|
||||||
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
|
|
||||||
|
|
||||||
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
|
|
||||||
let api_base_repo = api.repo(Repo::with_revision(
|
|
||||||
base_model_id.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
"main".to_string(),
|
|
||||||
));
|
|
||||||
|
|
||||||
api_base_repo.get("tokenizer.json").await.ok()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get tokenizer_config from the Huggingface Hub
|
|
||||||
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
|
|
||||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
|
|
||||||
|
|
||||||
// Open the file in read-only mode with buffer.
|
|
||||||
let file = File::open(tokenizer_config_filename).ok()?;
|
|
||||||
let reader = BufReader::new(file);
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
|
||||||
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
|
|
||||||
.map_err(|e| {
|
|
||||||
tracing::warn!("Unable to parse tokenizer config: {}", e);
|
|
||||||
e
|
|
||||||
})
|
|
||||||
.ok()?;
|
|
||||||
|
|
||||||
Some(tokenizer_config)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
enum RouterError {
|
enum RouterError {
|
||||||
#[error("Argument validation error: {0}")]
|
#[error("Argument validation error: {0}")]
|
||||||
|
|
|
@ -1,14 +1,17 @@
|
||||||
use crate::block_allocator::{BlockAllocation, BlockAllocator};
|
use crate::block_allocator::{BlockAllocation, BlockAllocator};
|
||||||
use text_generation_router::infer::InferError;
|
use crate::client;
|
||||||
use text_generation_router::infer::InferStreamResponse;
|
|
||||||
use text_generation_router::validation::{Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters};
|
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
|
||||||
use std::cmp::{max, min};
|
|
||||||
use std::collections::VecDeque;
|
|
||||||
use crate::client::{
|
use crate::client::{
|
||||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
use crate::client as client;
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
|
use std::cmp::{max, min};
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use text_generation_router::infer::InferError;
|
||||||
|
use text_generation_router::infer::InferStreamResponse;
|
||||||
|
use text_generation_router::validation::{
|
||||||
|
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
|
||||||
|
ValidStoppingParameters,
|
||||||
|
};
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tracing::{info_span, instrument, Instrument, Span};
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
|
@ -335,10 +338,21 @@ impl State {
|
||||||
id,
|
id,
|
||||||
prefill_logprobs: entry.request.decoder_input_details,
|
prefill_logprobs: entry.request.decoder_input_details,
|
||||||
input_chunks: Some(client::Input {
|
input_chunks: Some(client::Input {
|
||||||
chunks: entry.request.inputs.clone().into_iter().map(|c| client::InputChunk { chunk: Some(match c {
|
chunks: entry
|
||||||
Chunk::Text(text) => client::Chunk::Text(text),
|
.request
|
||||||
Chunk::Image(image) => client::Chunk::Image(client::Image { data: image.data, mimetype: image.mimetype })
|
.inputs
|
||||||
})}).collect()
|
.clone()
|
||||||
|
.into_iter()
|
||||||
|
.map(|c| client::InputChunk {
|
||||||
|
chunk: Some(match c {
|
||||||
|
Chunk::Text(text) => client::Chunk::Text(text),
|
||||||
|
Chunk::Image(image) => client::Chunk::Image(client::Image {
|
||||||
|
data: image.data,
|
||||||
|
mimetype: image.mimetype,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
}),
|
}),
|
||||||
inputs: entry.request.inputs.chunks_to_string(),
|
inputs: entry.request.inputs.chunks_to_string(),
|
||||||
truncate: entry.request.truncate,
|
truncate: entry.request.truncate,
|
||||||
|
|
|
@ -74,7 +74,6 @@ impl ChatTemplate {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// tests
|
// tests
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
|
@ -3,23 +3,23 @@ mod chat_template;
|
||||||
pub mod tool_grammar;
|
pub mod tool_grammar;
|
||||||
|
|
||||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||||
|
use crate::GrammarType;
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||||
HubTokenizerConfig, Message, PrefillToken, Token,
|
Message, PrefillToken, Token,
|
||||||
};
|
};
|
||||||
use crate::{GrammarType};
|
use async_trait::async_trait;
|
||||||
|
use chat_template::ChatTemplate;
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
use minijinja::{ErrorKind};
|
use minijinja::ErrorKind;
|
||||||
use std::sync::Arc;
|
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
use chat_template::ChatTemplate;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Backend {
|
pub trait Backend {
|
||||||
|
@ -86,7 +86,8 @@ impl Infer {
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub(crate) async fn generate_stream<'a>(
|
pub(crate) async fn generate_stream<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
request: GenerateRequest) -> Result<GenerateStreamResponse, InferError> {
|
request: GenerateRequest,
|
||||||
|
) -> Result<GenerateStreamResponse, InferError> {
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
let permit = self
|
let permit = self
|
||||||
.clone()
|
.clone()
|
||||||
|
@ -106,9 +107,7 @@ impl Infer {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let input_length = valid_request.input_length;
|
let input_length = valid_request.input_length;
|
||||||
let generation_stream = self
|
let generation_stream = self.backend.schedule(valid_request)?;
|
||||||
.backend
|
|
||||||
.schedule(valid_request)?;
|
|
||||||
|
|
||||||
Ok((permit, input_length, generation_stream))
|
Ok((permit, input_length, generation_stream))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
use crate::infer::v2::queue::{Entry, Queue};
|
use crate::infer::v2::queue::{Entry, Queue};
|
||||||
use crate::infer::{
|
use crate::infer::{
|
||||||
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Backend,
|
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
|
||||||
};
|
};
|
||||||
use crate::validation::ValidGenerateRequest;
|
use crate::validation::ValidGenerateRequest;
|
||||||
use crate::{FinishReason, PrefillToken, Token};
|
use crate::{FinishReason, PrefillToken, Token};
|
||||||
|
|
|
@ -6,6 +6,7 @@ pub mod validation;
|
||||||
|
|
||||||
#[cfg(feature = "kserve")]
|
#[cfg(feature = "kserve")]
|
||||||
mod kserve;
|
mod kserve;
|
||||||
|
pub mod logging;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
|
@ -158,7 +159,6 @@ pub struct Info {
|
||||||
#[schema(example = "32")]
|
#[schema(example = "32")]
|
||||||
pub max_client_batch_size: usize,
|
pub max_client_batch_size: usize,
|
||||||
|
|
||||||
|
|
||||||
/// Router Info
|
/// Router Info
|
||||||
#[schema(example = "text-generation-router")]
|
#[schema(example = "text-generation-router")]
|
||||||
pub router: &'static str,
|
pub router: &'static str,
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||||
|
use opentelemetry::sdk::trace;
|
||||||
|
use opentelemetry::sdk::trace::Sampler;
|
||||||
|
use opentelemetry::sdk::Resource;
|
||||||
|
use opentelemetry::{global, KeyValue};
|
||||||
|
use opentelemetry_otlp::WithExportConfig;
|
||||||
|
use tracing_subscriber::layer::SubscriberExt;
|
||||||
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
|
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
||||||
|
|
||||||
|
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
||||||
|
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
||||||
|
/// - otlp_service_name service name to appear in APM
|
||||||
|
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
||||||
|
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
||||||
|
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
|
||||||
|
pub fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
|
||||||
|
let mut layers = Vec::new();
|
||||||
|
|
||||||
|
// STDOUT/STDERR layer
|
||||||
|
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
|
||||||
|
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||||
|
.with_file(true)
|
||||||
|
.with_ansi(ansi)
|
||||||
|
.with_line_number(true);
|
||||||
|
|
||||||
|
let fmt_layer = match json_output {
|
||||||
|
true => fmt_layer.json().flatten_event(true).boxed(),
|
||||||
|
false => fmt_layer.boxed(),
|
||||||
|
};
|
||||||
|
layers.push(fmt_layer);
|
||||||
|
|
||||||
|
// OpenTelemetry tracing layer
|
||||||
|
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||||
|
global::set_text_map_propagator(TraceContextPropagator::new());
|
||||||
|
|
||||||
|
let tracer = opentelemetry_otlp::new_pipeline()
|
||||||
|
.tracing()
|
||||||
|
.with_exporter(
|
||||||
|
opentelemetry_otlp::new_exporter()
|
||||||
|
.tonic()
|
||||||
|
.with_endpoint(otlp_endpoint),
|
||||||
|
)
|
||||||
|
.with_trace_config(
|
||||||
|
trace::config()
|
||||||
|
.with_resource(Resource::new(vec![KeyValue::new(
|
||||||
|
"service.name",
|
||||||
|
otlp_service_name,
|
||||||
|
)]))
|
||||||
|
.with_sampler(Sampler::AlwaysOn),
|
||||||
|
)
|
||||||
|
.install_batch(opentelemetry::runtime::Tokio);
|
||||||
|
|
||||||
|
if let Ok(tracer) = tracer {
|
||||||
|
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
||||||
|
init_tracing_opentelemetry::init_propagator().unwrap();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter events with LOG_LEVEL
|
||||||
|
let varname = "LOG_LEVEL";
|
||||||
|
let env_filter = if let Ok(log_level) = std::env::var(varname) {
|
||||||
|
// Override to avoid simple logs to be spammed with tokio level informations
|
||||||
|
let log_level = match &log_level[..] {
|
||||||
|
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
|
||||||
|
"info" => "text_generation_launcher=info,text_generation_router=info",
|
||||||
|
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
|
||||||
|
log_level => log_level,
|
||||||
|
};
|
||||||
|
EnvFilter::builder()
|
||||||
|
.with_default_directive(LevelFilter::INFO.into())
|
||||||
|
.parse_lossy(log_level)
|
||||||
|
} else {
|
||||||
|
EnvFilter::new("info")
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(env_filter)
|
||||||
|
.with(layers)
|
||||||
|
.init();
|
||||||
|
}
|
|
@ -1,7 +1,7 @@
|
||||||
/// HTTP Server logic
|
/// HTTP Server logic
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, Backend};
|
|
||||||
use crate::infer::tool_grammar::ToolGrammar;
|
use crate::infer::tool_grammar::ToolGrammar;
|
||||||
|
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
|
||||||
#[cfg(feature = "kserve")]
|
#[cfg(feature = "kserve")]
|
||||||
use crate::kserve::{
|
use crate::kserve::{
|
||||||
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
|
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
|
||||||
|
@ -23,7 +23,7 @@ use crate::{
|
||||||
use crate::{FunctionDefinition, ToolCall, ToolType};
|
use crate::{FunctionDefinition, ToolCall, ToolType};
|
||||||
use async_stream::__private::AsyncStream;
|
use async_stream::__private::AsyncStream;
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
use axum::response::{IntoResponse, Response};
|
use axum::response::{IntoResponse, Response};
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
|
@ -33,10 +33,15 @@ use futures::stream::StreamExt;
|
||||||
use futures::stream::{FuturesOrdered, FuturesUnordered};
|
use futures::stream::{FuturesOrdered, FuturesUnordered};
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use futures::TryStreamExt;
|
use futures::TryStreamExt;
|
||||||
|
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
||||||
|
use hf_hub::{Cache, Repo, RepoType};
|
||||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::net::SocketAddr;
|
use std::fs::File;
|
||||||
|
use std::io::BufReader;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
|
@ -1395,24 +1400,22 @@ pub(crate) struct ComputeType(String);
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
backend: impl Backend + Send + Sync + 'static,
|
backend: impl Backend + Send + Sync + 'static,
|
||||||
model_info: HubModelInfo,
|
|
||||||
compat_return_full_text: bool,
|
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
max_best_of: usize,
|
max_best_of: usize,
|
||||||
max_stop_sequences: usize,
|
max_stop_sequences: usize,
|
||||||
max_top_n_tokens: u32,
|
max_top_n_tokens: u32,
|
||||||
max_input_tokens: usize,
|
max_input_tokens: usize,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
tokenizer: Option<Tokenizer>,
|
|
||||||
config: Option<Config>,
|
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
addr: SocketAddr,
|
tokenizer_name: String,
|
||||||
allow_origin: Option<AllowOrigin>,
|
tokenizer_config_path: Option<String>,
|
||||||
|
revision: Option<String>,
|
||||||
|
hostname: String,
|
||||||
|
port: u16,
|
||||||
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
ngrok: bool,
|
ngrok: bool,
|
||||||
_ngrok_authtoken: Option<String>,
|
_ngrok_authtoken: Option<String>,
|
||||||
_ngrok_edge: Option<String>,
|
_ngrok_edge: Option<String>,
|
||||||
tokenizer_config: HubTokenizerConfig,
|
|
||||||
processor_config: HubProcessorConfig,
|
|
||||||
messages_api_enabled: bool,
|
messages_api_enabled: bool,
|
||||||
grammar_support: bool,
|
grammar_support: bool,
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
|
@ -1485,6 +1488,195 @@ pub async fn run(
|
||||||
)]
|
)]
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
||||||
|
// CORS allowed origins
|
||||||
|
// map to go inside the option and then map to parse from String to HeaderValue
|
||||||
|
// Finally, convert to AllowOrigin
|
||||||
|
let allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
|
||||||
|
AllowOrigin::list(
|
||||||
|
cors_allow_origin
|
||||||
|
.iter()
|
||||||
|
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Parse Huggingface hub token
|
||||||
|
let authorization_token = std::env::var("HF_TOKEN")
|
||||||
|
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
|
||||||
|
.ok();
|
||||||
|
|
||||||
|
// Tokenizer instance
|
||||||
|
// This will only be used to validate payloads
|
||||||
|
let local_path = Path::new(&tokenizer_name);
|
||||||
|
|
||||||
|
// Shared API builder initialization
|
||||||
|
let api_builder = || {
|
||||||
|
let mut builder = ApiBuilder::new()
|
||||||
|
.with_progress(false)
|
||||||
|
.with_token(authorization_token);
|
||||||
|
|
||||||
|
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
|
||||||
|
builder = builder.with_cache_dir(cache_dir.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
builder
|
||||||
|
};
|
||||||
|
|
||||||
|
// Decide if we need to use the API based on the revision and local path
|
||||||
|
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
||||||
|
|
||||||
|
// Initialize API if needed
|
||||||
|
#[derive(Clone)]
|
||||||
|
enum Type {
|
||||||
|
Api(Api),
|
||||||
|
Cache(Cache),
|
||||||
|
None,
|
||||||
|
}
|
||||||
|
let api = if use_api {
|
||||||
|
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
||||||
|
let cache = Cache::default();
|
||||||
|
tracing::warn!("Offline mode active using cache defaults");
|
||||||
|
Type::Cache(cache)
|
||||||
|
} else {
|
||||||
|
tracing::info!("Using the Hugging Face API");
|
||||||
|
match api_builder().build() {
|
||||||
|
Ok(api) => Type::Api(api),
|
||||||
|
Err(_) => {
|
||||||
|
tracing::warn!("Unable to build the Hugging Face API");
|
||||||
|
Type::None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Type::None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Load tokenizer and model info
|
||||||
|
let (
|
||||||
|
tokenizer_filename,
|
||||||
|
config_filename,
|
||||||
|
tokenizer_config_filename,
|
||||||
|
processor_config_filename,
|
||||||
|
model_info,
|
||||||
|
) = match api {
|
||||||
|
Type::None => (
|
||||||
|
Some(local_path.join("tokenizer.json")),
|
||||||
|
Some(local_path.join("config.json")),
|
||||||
|
Some(local_path.join("tokenizer_config.json")),
|
||||||
|
Some(local_path.join("processor_config.json")),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
Type::Api(api) => {
|
||||||
|
let api_repo = api.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||||
|
));
|
||||||
|
|
||||||
|
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
||||||
|
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
||||||
|
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||||
|
};
|
||||||
|
let config_filename = api_repo.get("config.json").await.ok();
|
||||||
|
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||||
|
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
||||||
|
|
||||||
|
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
|
||||||
|
Some(model_info)
|
||||||
|
} else {
|
||||||
|
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||||
|
None
|
||||||
|
};
|
||||||
|
(
|
||||||
|
tokenizer_filename,
|
||||||
|
config_filename,
|
||||||
|
tokenizer_config_filename,
|
||||||
|
processor_config_filename,
|
||||||
|
model_info,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Type::Cache(cache) => {
|
||||||
|
let repo = cache.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||||
|
));
|
||||||
|
(
|
||||||
|
repo.get("tokenizer.json"),
|
||||||
|
repo.get("config.json"),
|
||||||
|
repo.get("tokenizer_config.json"),
|
||||||
|
repo.get("processor_config.json"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let tokenizer: Option<Tokenizer> =
|
||||||
|
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
|
||||||
|
let config: Option<Config> = config_filename.and_then(|filename| {
|
||||||
|
std::fs::read_to_string(filename)
|
||||||
|
.ok()
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|c| {
|
||||||
|
let config: Result<Config, _> = serde_json::from_str(c);
|
||||||
|
if let Err(err) = &config {
|
||||||
|
tracing::warn!("Could not parse config {err:?}");
|
||||||
|
}
|
||||||
|
config.ok()
|
||||||
|
})
|
||||||
|
});
|
||||||
|
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
|
||||||
|
model_id: tokenizer_name.to_string(),
|
||||||
|
sha: None,
|
||||||
|
pipeline_tag: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||||
|
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||||
|
{
|
||||||
|
HubTokenizerConfig::from_file(filename)
|
||||||
|
} else {
|
||||||
|
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
||||||
|
};
|
||||||
|
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
||||||
|
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||||
|
HubTokenizerConfig::default()
|
||||||
|
});
|
||||||
|
|
||||||
|
let processor_config = processor_config_filename
|
||||||
|
.and_then(HubProcessorConfig::from_file)
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
tracing::info!("Using config {config:?}");
|
||||||
|
if tokenizer.is_none() {
|
||||||
|
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
||||||
|
tracing::warn!("Rust input length validation and truncation is disabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||||
|
let compat_return_full_text = match &model_info.pipeline_tag {
|
||||||
|
None => {
|
||||||
|
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
||||||
|
true
|
||||||
|
}
|
||||||
|
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
||||||
|
};
|
||||||
|
|
||||||
|
// Determine the server port based on the feature and environment variable.
|
||||||
|
let port = if cfg!(feature = "google") {
|
||||||
|
std::env::var("AIP_HTTP_PORT")
|
||||||
|
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
|
||||||
|
.unwrap_or(port)
|
||||||
|
} else {
|
||||||
|
port
|
||||||
|
};
|
||||||
|
|
||||||
|
let addr = match hostname.parse() {
|
||||||
|
Ok(ip) => SocketAddr::new(ip, port),
|
||||||
|
Err(_) => {
|
||||||
|
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Create state
|
// Create state
|
||||||
let validation = Validation::new(
|
let validation = Validation::new(
|
||||||
validation_workers,
|
validation_workers,
|
||||||
|
@ -1551,8 +1743,8 @@ pub async fn run(
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
|
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
|
// .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
|
||||||
// .unwrap();
|
// .unwrap();
|
||||||
let prom_handle = builder
|
let prom_handle = builder
|
||||||
.install_recorder()
|
.install_recorder()
|
||||||
.expect("failed to install metrics recorder");
|
.expect("failed to install metrics recorder");
|
||||||
|
@ -1750,6 +1942,68 @@ pub async fn run(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// get model info from the Huggingface Hub
|
||||||
|
pub async fn get_hub_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
|
||||||
|
let response = api.info_request().send().await.ok()?;
|
||||||
|
|
||||||
|
if response.status().is_success() {
|
||||||
|
let hub_model_info: HubModelInfo =
|
||||||
|
serde_json::from_str(&response.text().await.ok()?).ok()?;
|
||||||
|
if let Some(sha) = &hub_model_info.sha {
|
||||||
|
tracing::info!(
|
||||||
|
"Serving revision {sha} of model {}",
|
||||||
|
hub_model_info.model_id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Some(hub_model_info)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// get base tokenizer
|
||||||
|
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
|
||||||
|
let config_filename = api_repo.get("config.json").await.ok()?;
|
||||||
|
|
||||||
|
// Open the file in read-only mode with buffer.
|
||||||
|
let file = File::open(config_filename).ok()?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
|
||||||
|
// Read the JSON contents of the file as an instance of `User`.
|
||||||
|
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
|
||||||
|
|
||||||
|
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
|
||||||
|
let api_base_repo = api.repo(Repo::with_revision(
|
||||||
|
base_model_id.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
"main".to_string(),
|
||||||
|
));
|
||||||
|
|
||||||
|
api_base_repo.get("tokenizer.json").await.ok()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// get tokenizer_config from the Huggingface Hub
|
||||||
|
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
|
||||||
|
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
|
||||||
|
|
||||||
|
// Open the file in read-only mode with buffer.
|
||||||
|
let file = File::open(tokenizer_config_filename).ok()?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
|
||||||
|
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||||
|
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::warn!("Unable to parse tokenizer config: {}", e);
|
||||||
|
e
|
||||||
|
})
|
||||||
|
.ok()?;
|
||||||
|
|
||||||
|
Some(tokenizer_config)
|
||||||
|
}
|
||||||
|
|
||||||
/// Shutdown signal handler
|
/// Shutdown signal handler
|
||||||
async fn shutdown_signal() {
|
async fn shutdown_signal() {
|
||||||
let ctrl_c = async {
|
let ctrl_c = async {
|
||||||
|
|
|
@ -638,7 +638,7 @@ pub struct Image {
|
||||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||||
pub enum Chunk {
|
pub enum Chunk {
|
||||||
Text(String),
|
Text(String),
|
||||||
Image(Image)
|
Image(Image),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert input chunks to a stringly-typed input for backwards
|
/// Convert input chunks to a stringly-typed input for backwards
|
||||||
|
|
Loading…
Reference in New Issue