This commit is contained in:
Nicolas Patry 2024-07-30 12:22:24 +02:00
parent ddbbf6b50c
commit bc0a33e1c9
No known key found for this signature in database
GPG Key ID: E939E8CC91A1C674
17 changed files with 5305 additions and 185 deletions

2
.gitignore vendored
View File

@ -3,6 +3,8 @@ target
router/tokenizer.json
*__pycache__*
backends/v3/src/client/pb
# ROCm auto-generated files
*.hip
server/exllamav2_kernels/exllamav2_kernels/hip/

4963
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,12 @@
[workspace]
members = [
# "benchmark",
"backends/v3",
# "backends/client",
"backends/grpc-metadata",
"launcher"
, "backends/trtllm"]
# "benchmark",
"backends/v3",
# "backends/client",
"backends/grpc-metadata",
# "backends/trtllm",
"launcher"
]
resolver = "2"
[workspace.package]
@ -18,6 +19,8 @@ homepage = "https://github.com/huggingface/text-generation-inference"
base64 = "0.22.0"
tokenizers = { version = "0.19.1", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
[profile.release]
incremental = true

View File

@ -2,8 +2,8 @@ use std::future::Future;
use std::path::Path;
use std::pin::{pin, Pin};
use std::str::FromStr;
use std::sync::{Arc, OnceLock};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
@ -13,15 +13,15 @@ use log::{error, warn};
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::RwLock;
use tokio::time::{Instant, sleep};
use tokio_stream::{Stream, StreamExt};
use tokio::time::{sleep, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{instrument, Level, span};
use tokio_stream::{Stream, StreamExt};
use tracing::{instrument, span, Level};
use text_generation_router::{FinishReason, Token};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
use text_generation_router::validation::ValidationError::UnsupportedModality;
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};

View File

@ -24,8 +24,8 @@ grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1"
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"

View File

@ -148,8 +148,8 @@ pub(crate) async fn batching_task(
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
@ -170,9 +170,11 @@ pub(crate) async fn batching_task(
{
// Tracking metrics
if min_size.is_some() {
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
@ -218,8 +220,8 @@ pub(crate) async fn batching_task(
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size", 0.0);
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
@ -232,7 +234,7 @@ async fn prefill(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
@ -243,18 +245,22 @@ async fn prefill(
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
@ -268,7 +274,7 @@ async fn decode(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
@ -280,13 +286,18 @@ async fn decode(
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@ -295,7 +306,7 @@ async fn decode(
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
@ -353,7 +364,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
@ -369,7 +380,7 @@ fn send_responses(
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}
@ -395,7 +406,7 @@ fn send_responses(
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
@ -460,7 +471,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.

View File

@ -8,10 +8,10 @@ use tonic::Status;
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod grpc_client;
mod sharded_client;
pub use client::Client;
pub use grpc_client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,

View File

@ -2,7 +2,7 @@ use crate::client::{ClientError, Result};
/// Multi shard Client
use crate::client::{Health, ShardInfo};
use crate::client::client::{DecodeTimings, PrefillTimings};
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,

View File

@ -30,6 +30,7 @@ pub struct BackendInfo {
pub max_batch_size: Option<usize>,
}
#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
max_input_tokens: usize,
max_total_tokens: usize,

View File

@ -1,4 +1,4 @@
use clap::Parser;
use clap::{Parser, Subcommand};
use text_generation_router::server;
use text_generation_router_v3::{connect_backend, V3Error};
use thiserror::Error;
@ -7,6 +7,9 @@ use thiserror::Error;
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
@ -44,6 +47,8 @@ struct Args {
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
@ -65,12 +70,18 @@ struct Args {
max_client_batch_size: usize,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
command,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
@ -89,6 +100,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_config_path,
revision,
validation_workers,
api_key,
json_output,
otlp_endpoint,
otlp_service_name,
@ -101,8 +113,19 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
} = args;
let print_schema_command = match command {
Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
text_generation_router::logging::init_logging(
otlp_endpoint,
otlp_service_name,
json_output,
);
false
}
};
// Launch Tokio runtime
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
@ -151,6 +174,7 @@ async fn main() -> Result<(), RouterError> {
max_input_tokens,
max_total_tokens,
validation_workers,
api_key,
tokenizer_name,
tokenizer_config_path,
revision,
@ -163,6 +187,7 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
print_schema_command,
)
.await?;
Ok(())

View File

@ -17,8 +17,8 @@ futures = "0.3.28"
hf-hub = { workspace = true }
itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.23.0"
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
@ -48,6 +48,7 @@ base64 = { workspace = true }
sysinfo = "0.30.13"
uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] }
csv = "1.3.0"
ureq = "=2.9"
[build-dependencies]

View File

@ -1,5 +1,7 @@
use crate::infer::InferError;
use crate::{ChatTemplateInputs, GrammarType, Message, MessageChunk, Text, TextMessage};
use crate::{
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
};
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
@ -19,8 +21,8 @@ pub(crate) struct ChatTemplate {
impl ChatTemplate {
pub(crate) fn new(
template: String,
bos_token: Option<String>,
eos_token: Option<String>,
bos_token: Option<TokenizerConfigToken>,
eos_token: Option<TokenizerConfigToken>,
) -> Self {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
@ -38,8 +40,8 @@ impl ChatTemplate {
Self {
template,
bos_token,
eos_token,
bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template,
}
}
@ -52,9 +54,9 @@ impl ChatTemplate {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text(Text {
last_message.content.push(MessageChunk::Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
}));
});
}
}
}

View File

@ -1,122 +1,135 @@
use crate::infer::InferError;
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolType, Tools};
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
pub(crate) struct ToolGrammar {}
impl ToolGrammar {
// find a tool by name
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
tools
.iter()
.find(|tool| tool.function.name == name)
.cloned()
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
}
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: Option<ToolType>,
tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> {
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
// let tool_prompt = tool_prompt.unwrap_or_default();
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()]
}
ToolType::OneOf => req_tools.to_owned(),
};
// if no tools are provided, we return None
let tools = match tools {
Some(tools) if !tools.is_empty() => tools,
_ => return Ok(None),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": "notify_error"
}),
);
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
let functions: HashMap<String, Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![Self::find_tool_by_name(&tools, &name)?]
}
ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
};
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
)])
.collect();
);
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
return Ok(Some(tools));
}
// Err(InferError::ToolError("No tools provided".to_string()))
Ok(None)
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
Ok(Some(tools))
}
}

View File

@ -35,9 +35,9 @@ use futures::stream::StreamExt;
use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::Stream;
use futures::TryStreamExt;
use http::header::AUTHORIZATION;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::convert::Infallible;
@ -46,6 +46,7 @@ use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use thiserror::Error;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::Tokenizer;
use tokio::select;
use tokio::signal;
@ -1406,15 +1407,16 @@ pub async fn run(
max_input_tokens: usize,
max_total_tokens: usize,
validation_workers: usize,
addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
api_key: Option<String>,
tokenizer_name: String,
tokenizer_config_path: Option<String>,
revision: Option<String>,
hostname: String,
port: u16,
cors_allow_origin: Option<Vec<String>>,
ngrok: bool,
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig,
preprocessor_config: Option<HubPreprocessorConfig>,
processor_config: HubProcessorConfig,
messages_api_enabled: bool,
grammar_support: bool,
max_client_batch_size: usize,
@ -1507,6 +1509,16 @@ pub async fn run(
println!("{}", api_doc);
std::process::exit(0);
}
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
let allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
AllowOrigin::list(
cors_allow_origin
.iter()
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
)
});
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
@ -1564,6 +1576,7 @@ pub async fn run(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
) = match api {
@ -1571,6 +1584,7 @@ pub async fn run(
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
@ -1587,6 +1601,7 @@ pub async fn run(
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
@ -1599,6 +1614,7 @@ pub async fn run(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
)
@ -1613,13 +1629,40 @@ pub async fn run(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
let tokenizer: Option<Tokenizer> =
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
}
}
}
tokenizer
});
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
@ -1638,22 +1681,13 @@ pub async fn run(
pipeline_tag: None,
});
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();
let preprocessor_config: Option<HubPreprocessorConfig> =
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
@ -2107,3 +2141,74 @@ pub enum WebServerError {
#[error("Axum error: {0}")]
Axum(#[from] axum::BoxError),
}
/// Create a post_processor for the LlamaTokenizer
fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos.as_str())
.expect("Should have found the bos token id");
special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos.as_str()));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos.as_str())
.expect("Should have found the eos token id");
special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos.as_str()));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos.as_str()));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos.as_str()));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}

View File

@ -5,13 +5,12 @@ use crate::{
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
};
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
use image::{ImageFormat, ImageReader};
use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use std::iter;
use text_generation_client::{Chunk, Image, InputChunk};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc;
@ -181,11 +180,7 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize);
}
Ok((
vec![Chunk::Text(inputs).into()],
input_length,
max_new_tokens,
))
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
}
}
@ -589,7 +584,7 @@ fn prepare_input(
tokenizer: &Tokenizer,
config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>,
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
@ -601,16 +596,16 @@ fn prepare_input(
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
input_chunks.push(Chunk::Image(Image { data, mimetype }));
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
start = chunk_end;
}
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
input_chunks.push(Chunk::Text(inputs[start..].to_string()));
tokenizer_query.push_str(&inputs[start..]);
}
@ -618,7 +613,7 @@ fn prepare_input(
(tokenizer_query, input_chunks)
}
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
_ => (inputs.clone(), vec![Chunk::Text(inputs)]),
};
// Get the number of tokens in the input
@ -784,8 +779,7 @@ pub enum ValidationError {
#[error("Could not fetch image: {0}")]
FailedFetchImage(#[from] reqwest::Error),
#[error("{0} modality is not supported")]
UnsupportedModality(&'static str)
UnsupportedModality(&'static str),
}
#[cfg(test)]