OlivierDehaene 2024-06-26 14:00:03 +02:00
parent b562680be4
commit 93e0a7de8b
28 changed files with 254 additions and 392 deletions

Cargo.lock (generated)

@ -194,17 +194,6 @@ dependencies = [
"v_frame",
]
[[package]]
name = "average"
version = "0.14.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c309b1c7fca12ebeec3ecba29ea917b3a4cb458ccf504df68bb4d8a0ca565a00"
dependencies = [
"easy-cast",
"float-ord",
"num-traits",
]
[[package]]
name = "avif-serialize"
version = "0.8.1"
@ -503,12 +492,6 @@ dependencies = [
"thiserror",
]
[[package]]
name = "cassowary"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]]
name = "cc"
version = "1.0.98"
@ -570,7 +553,7 @@ version = "4.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.66",
@ -675,31 +658,6 @@ version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
[[package]]
name = "crossterm"
version = "0.27.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df"
dependencies = [
"bitflags 2.5.0",
"crossterm_winapi",
"libc",
"mio",
"parking_lot",
"signal-hook",
"signal-hook-mio",
"winapi",
]
[[package]]
name = "crossterm_winapi"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b"
dependencies = [
"winapi",
]
[[package]]
name = "crunchy"
version = "0.2.2"
@ -832,15 +790,6 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "easy-cast"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10936778145f3bea71fd9bf61332cce28c28e96a380714f7ab34838b80733fd6"
dependencies = [
"libm",
]
[[package]]
name = "either"
version = "1.12.0"
@ -944,12 +893,6 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "float-ord"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
[[package]]
name = "float_eq"
version = "1.0.1"
@ -1208,12 +1151,6 @@ version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "heck"
version = "0.5.0"
@ -1498,12 +1435,6 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "indoc"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]]
name = "init-tracing-opentelemetry"
version = "0.14.1"
@ -1674,12 +1605,6 @@ dependencies = [
"once_cell",
]
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "libredox"
version = "0.1.3"
@ -1896,7 +1821,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
dependencies = [
"libc",
"log",
"wasi",
"windows-sys 0.48.0",
]
@ -2148,7 +2072,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]]
@ -2412,17 +2335,6 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "papergrid"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2ccbe15f2b6db62f9a9871642746427e297b0ceb85f9a7f1ee5ff47d184d0c8"
dependencies = [
"bytecount",
"fnv",
"unicode-width",
]
[[package]]
name = "parking_lot"
version = "0.12.3"
@ -2626,7 +2538,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes",
"heck 0.5.0",
"heck",
"itertools 0.12.1",
"log",
"multimap",
@ -2745,23 +2657,6 @@ dependencies = [
"getrandom",
]
[[package]]
name = "ratatui"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e2e4cd95294a85c3b4446e63ef054eea43e0205b1fd60120c16b74ff7ff96ad"
dependencies = [
"bitflags 2.5.0",
"cassowary",
"crossterm",
"indoc",
"itertools 0.11.0",
"paste",
"strum",
"unicode-segmentation",
"unicode-width",
]
[[package]]
name = "rav1e"
version = "0.7.1"
@ -3269,27 +3164,6 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "signal-hook"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
dependencies = [
"libc",
"signal-hook-registry",
]
[[package]]
name = "signal-hook-mio"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af"
dependencies = [
"libc",
"mio",
"signal-hook",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.2"
@ -3387,28 +3261,6 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.25.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0"
dependencies = [
"heck 0.4.1",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.66",
]
[[package]]
name = "subtle"
version = "2.5.0"
@ -3491,36 +3343,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349"
dependencies = [
"cfg-expr",
"heck 0.5.0",
"heck",
"pkg-config",
"toml",
"version-compare",
]
[[package]]
name = "tabled"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfe9c3632da101aba5131ed63f9eed38665f8b3c68703a6bb18124835c1a5d22"
dependencies = [
"papergrid",
"tabled_derive",
"unicode-width",
]
[[package]]
name = "tabled_derive"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99f688a08b54f4f02f0a3c382aefdb7884d3d69609f785bd253dc033243e3fe4"
dependencies = [
"heck 0.4.1",
"proc-macro-error",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "target-lexicon"
version = "0.12.14"
@ -3539,45 +3367,6 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "text-generation-benchmark"
version = "2.0.5-dev0"
dependencies = [
"average",
"clap",
"crossterm",
"float-ord",
"hf-hub",
"ratatui",
"serde",
"serde_json",
"tabled",
"text-generation-client",
"thiserror",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "text-generation-client"
version = "2.0.5-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
"futures",
"grpc-metadata",
"prost 0.12.6",
"prost-build",
"thiserror",
"tokio",
"tonic 0.10.2",
"tonic-build",
"tower",
"tracing",
]
[[package]]
name = "text-generation-launcher"
version = "2.0.5-dev0"
@ -3627,7 +3416,6 @@ dependencies = [
"reqwest",
"serde",
"serde_json",
"text-generation-client",
"thiserror",
"tokenizers",
"tokio",
@ -3641,6 +3429,54 @@ dependencies = [
"vergen",
]
[[package]]
name = "text-generation-router-v3"
version = "2.0.5-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.5",
"axum-tracing-opentelemetry",
"base64 0.22.1",
"clap",
"futures",
"futures-util",
"grpc-metadata",
"hf-hub",
"image",
"init-tracing-opentelemetry",
"jsonschema",
"metrics",
"metrics-exporter-prometheus",
"minijinja",
"minijinja-contrib",
"nohash-hasher",
"once_cell",
"opentelemetry 0.20.0",
"opentelemetry-otlp",
"prost 0.12.6",
"prost-build",
"rand",
"regex",
"reqwest",
"serde",
"serde_json",
"text-generation-router",
"thiserror",
"tokenizers",
"tokio",
"tokio-stream",
"tonic 0.10.2",
"tonic-build",
"tower",
"tower-http",
"tracing",
"tracing-opentelemetry 0.21.0",
"tracing-subscriber",
"utoipa",
"utoipa-swagger-ui",
]
[[package]]
name = "thiserror"
version = "1.0.61"

(file path not shown)

@ -1,9 +1,9 @@
[workspace]
members = [
"benchmark",
"router",
"router/client",
"router/grpc-metadata",
# "benchmark",
"backends/v3",
# "backends/client",
"backends/grpc-metadata",
"launcher"
]
resolver = "2"

backends/v3/Cargo.toml (new file)

@ -0,0 +1,66 @@
[package]
name = "text-generation-router-v3"
description = "Text Generation Webserver"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-router"
path = "src/main.rs"
[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
text-generation-router = { path = "../../router" }
clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1"
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
thiserror = "1.0.48"
tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = { version = "2.0.2" }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = { workspace = true }
prost = "^0.12"
tonic = "^0.10"
tower = "^0.4"
[build-dependencies]
tonic-build = "0.10.1"
prost-build = "0.12.1"
[features]
default = ["ngrok"]
ngrok = ["text-generation-router/ngrok"]
google = ["text-generation-router/google"]
kserve = ["text-generation-router/kserve"]
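
Because these features merely forward to the router's optional endpoints, enabling them happens on this package; for example (hypothetical invocation, not part of this commit):

cargo build -p text-generation-router-v3 --features google,kserve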

backends/v3/build.rs (new file)

@ -0,0 +1,19 @@
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/");
fs::create_dir_all("src/client/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/client/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
}
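
For context, a sketch of what this build script emits (file names follow prost's package-based naming plus the out_dir/include_file arguments above; assumed, not shown in this commit):

src/client/pb/
├── mod.rs           # produced by include_file("mod.rs"); declares the generated modules
└── generate.v3.rs   # messages and client stubs generated from proto/v3/generate.proto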

(file path not shown)

@ -1,16 +1,15 @@
/// Batching and inference logic
use crate::infer::v3::queue::{Entry, Queue};
use crate::infer::{
use crate::queue::{Entry, Queue};
use text_generation_router::infer::{
GeneratedText, InferError, InferStreamResponse, Backend,
};
use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, PrefillToken, Token};
use nohash_hasher::IntMap;
use std::sync::{
Arc,
};
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
use text_generation_client::{ClientError, Health};
use crate::client::{Batch, CachedBatch, Generation, ShardedClient, ClientError, Health};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
@ -18,7 +17,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
use async_trait::async_trait;
pub(crate) struct BackendV3 {
pub struct BackendV3 {
/// Request queue
queue: Queue,
/// Notify batcher on queue appends
@ -78,7 +77,6 @@ impl Backend for BackendV3 {
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();
let input_length = request.input_length;
// Append the request to the queue
self.queue.append(Entry {
@ -480,14 +478,14 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
});
}
impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
fn from(value: text_generation_client::v3::GeneratedText) -> Self {
impl From<crate::client::GeneratedText> for GeneratedText {
fn from(value: crate::client::GeneratedText) -> Self {
let v3_finish_reason =
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
crate::client::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v3_finish_reason {
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
crate::client::FinishReason::Length => FinishReason::Length,
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
};
Self {

(file path not shown)

@ -1,6 +1,7 @@
use crate::v3::{pb, Chunk};
use crate::{ClientError, Result, WARMUP_IMAGE_BASE64};
/// Single shard Client
use crate::client::{pb, Chunk};
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;

(file path not shown)

@ -1,15 +1,24 @@
//! Text Generation gRPC client library
pub mod v2;
pub mod v3;
use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD, Engine};
use thiserror::Error;
use tonic::transport;
use tonic::Status;
pub use v3::{Chunk, Image, Input, InputChunk};
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
#[async_trait]
pub trait Health {
@ -63,29 +72,6 @@ impl From<Chunk> for InputChunk {
}
}
/// Convert input chunks to a stringly-typed input for backwards
/// compat for backends that haven't implemented chunked inputs.
pub trait ChunksToString {
/// Convert chunks to string.
fn chunks_to_string(&self) -> String;
}
impl ChunksToString for Vec<InputChunk> {
fn chunks_to_string(&self) -> String {
let mut output = String::new();
self.iter().for_each(|c| match &c.chunk {
Some(Chunk::Text(text)) => output.push_str(text),
Some(Chunk::Image(Image { data, mimetype })) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
// We don't create empty chunks, so this should be unreachable.
None => unreachable!("Chunks should never be empty"),
});
output
}
}
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
pub type Result<T> = std::result::Result<T, ClientError>;
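
To illustrate the flattened client module (a usage sketch, not part of this commit): the `From<Chunk> for InputChunk` impl above lets call sites assemble protobuf inputs directly from the re-exported types.

use crate::client::{Chunk, Image, Input, InputChunk};

// Sketch: build a mixed text/image input from the re-exports above.
fn example_input(png_bytes: Vec<u8>) -> Input {
    let chunks: Vec<InputChunk> = vec![
        Chunk::Text("Describe this image: ".to_string()).into(),
        Chunk::Image(Image {
            data: png_bytes,
            mimetype: "image/png".to_string(),
        })
        .into(),
    ];
    Input { chunks }
}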

(file path not shown)

@ -1,14 +1,14 @@
/// Multi shard Client
use crate::{v3, Health, ShardInfo};
use crate::{ClientError, Result};
use crate::client::{Health, ShardInfo};
use crate::client::{ClientError, Result};
use crate::v3::{Chunk, InfoResponse, Input};
use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
use v3::{
use crate::client::client::{DecodeTimings, PrefillTimings};
use crate::client::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};

(file path not shown)

@ -1,14 +1,13 @@
mod block_allocator;
mod queue;
mod backend;
mod client;
use futures_util::TryFutureExt;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
pub(crate) use backend::BackendV3;
use text_generation_client::ClientError;
use text_generation_client::v3::ShardedClient;
use crate::client::{ShardedClient, ClientError};
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
@ -127,7 +126,7 @@ pub async fn connect_backend(
}
#[derive(Debug, Error)]
pub(crate) enum V3Error {
pub enum V3Error {
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to connect to the Python model shards: {0}")]

(file path not shown)

@ -20,6 +20,7 @@ use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
use text_generation_router_v3::{connect_backend, V3Error};
/// App Configuration
#[derive(Parser, Debug)]
@ -336,9 +337,11 @@ async fn main() -> Result<(), RouterError> {
}
};
let (backend, backend_info) = connect_backend(max_input_tokens, max_total_tokens, master_shard_uds_path, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size).await?;
// Run server
server::run(
master_shard_uds_path,
backend,
model_info,
compat_return_full_text,
max_concurrent_requests,
@ -347,11 +350,6 @@ async fn main() -> Result<(), RouterError> {
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
tokenizer,
config,
validation_workers,
@ -508,6 +506,8 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("Backend failed: {0}")]
Backend(#[from] V3Error),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]

(file path not shown)

@ -1,17 +1,14 @@
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::{
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use crate::block_allocator::{BlockAllocation, BlockAllocator};
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::collections::VecDeque;
use text_generation_client::v3::{
use crate::client::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use crate::client as client;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};
@ -337,8 +334,11 @@ impl State {
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
input_chunks: Some(Input {
chunks: entry.request.inputs.clone(),
input_chunks: Some(client::Input {
chunks: entry.request.inputs.clone().into_iter().map(|c| client::InputChunk { chunk: Some(match c {
Chunk::Text(text) => client::Chunk::Text(text),
Chunk::Image(image) => client::Chunk::Image(client::Image { data: image.data, mimetype: image.mimetype })
})}).collect()
}),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,
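
The inline chunk mapping above could equally be factored into a conversion impl; a sketch of that alternative (not part of this commit), using the same types shown in this diff:

use text_generation_router::validation::{Chunk, Image};

// Sketch: the closure in the `Request` construction above, as a From impl.
impl From<Chunk> for crate::client::InputChunk {
    fn from(chunk: Chunk) -> Self {
        crate::client::InputChunk {
            chunk: Some(match chunk {
                Chunk::Text(text) => crate::client::Chunk::Text(text),
                Chunk::Image(Image { data, mimetype }) => {
                    crate::client::Chunk::Image(crate::client::Image { data, mimetype })
                }
            }),
        }
    }
}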

(file path not shown)

@ -21,7 +21,7 @@ float-ord = "0.3.2"
serde = {version = "1.0.188", features = ["derive"]}
serde_json = "1.0"
tabled = "0.14.0"
text-generation-client = { path = "../router/client" }
text-generation-client = { path = "../backends/client" }
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }

(file path not shown)

@ -7,19 +7,11 @@ edition.workspace = true
authors.workspace = true
homepage.workspace = true
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-router"
path = "src/main.rs"
[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
text-generation-client = { path = "client" }
clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28"
hf-hub = { workspace = true }

(file path not shown)

@ -1,35 +0,0 @@
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/");
fs::create_dir_all("src/v2/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/v2/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
.map_err(|e| match e.kind(){
std::io::ErrorKind::NotFound => {panic!("`protoc` not found, install libprotoc")},
std::io::ErrorKind::Other => {panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases")},
e => {e}
}).unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
fs::create_dir_all("src/v3/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/v3/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
}

(file path not shown)

@ -1 +0,0 @@
*

(file path not shown)

@ -1,13 +0,0 @@
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;

(file path not shown)

@ -1 +0,0 @@
*

(file path not shown)

@ -1,5 +1,4 @@
// pub(crate) mod v2;
pub(crate) mod v3;
mod chat_template;
pub mod tool_grammar;
@ -23,7 +22,7 @@ use chat_template::ChatTemplate;
use async_trait::async_trait;
#[async_trait]
pub(crate) trait Backend {
pub trait Backend {
fn schedule(
&self,
request: ValidGenerateRequest,
@ -286,15 +285,15 @@ pub(crate) type GenerateStreamResponse = (
);
#[derive(Debug)]
pub(crate) struct GeneratedText {
pub(crate) text: String,
pub(crate) generated_tokens: u32,
pub(crate) finish_reason: FinishReason,
pub(crate) seed: Option<u64>,
pub struct GeneratedText {
pub text: String,
pub generated_tokens: u32,
pub finish_reason: FinishReason,
pub seed: Option<u64>,
}
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
pub enum InferStreamResponse {
// Optional first message
Prefill(Vec<PrefillToken>),
// Intermediate messages
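
With `Backend` now public, alternative backends can be implemented outside the router crate; a minimal sketch (assuming `schedule` is the only required method visible in this diff — the real trait may require more, such as a health check):

use async_trait::async_trait;
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

struct NoopBackend;

#[async_trait]
impl Backend for NoopBackend {
    fn schedule(
        &self,
        _request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        // Mirror BackendV3 above: return the receiver half of an MPSC channel;
        // a real backend keeps the sender and streams responses as tokens arrive.
        let (_tx, rx) = mpsc::unbounded_channel();
        Ok(UnboundedReceiverStream::new(rx))
    }
}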

(file path not shown)

@ -1,8 +1,8 @@
/// Text Generation Inference Webserver
pub mod config;
mod infer;
pub mod infer;
pub mod server;
mod validation;
pub mod validation;
#[cfg(feature = "kserve")]
mod kserve;
@ -1055,23 +1055,23 @@ impl From<CompatGenerateRequest> for GenerateRequest {
#[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken {
#[schema(example = 0)]
id: u32,
pub id: u32,
#[schema(example = "test")]
text: String,
pub text: String,
#[schema(nullable = true, example = - 0.34)]
logprob: f32,
pub logprob: f32,
}
#[derive(Debug, Serialize, ToSchema, Clone)]
pub struct Token {
#[schema(example = 0)]
id: u32,
pub id: u32,
#[schema(example = "test")]
text: String,
pub text: String,
#[schema(nullable = true, example = - 0.34)]
logprob: f32,
pub logprob: f32,
#[schema(example = "false")]
special: bool,
pub special: bool,
}
#[derive(Debug, Serialize, ToSchema)]
@ -1089,7 +1089,7 @@ pub struct SimpleToken {
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")]
pub(crate) enum FinishReason {
pub enum FinishReason {
#[schema(rename = "length")]
Length,
#[serde(rename = "eos_token")]

(file path not shown)

@ -1,6 +1,5 @@
/// HTTP Server logic
use crate::config::Config;
use crate::infer::v3::{connect_backend, V3Error};
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, Backend};
use crate::infer::tool_grammar::ToolGrammar;
#[cfg(feature = "kserve")]
@ -38,9 +37,6 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use text_generation_client::{v2, v3, ClientError, ShardInfo};
use thiserror::Error;
use tokenizers::Tokenizer;
use tokio::select;
@ -1398,7 +1394,7 @@ pub(crate) struct ComputeType(String);
/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
master_shard_uds_path: String,
backend: impl Backend + Send + Sync + 'static,
model_info: HubModelInfo,
compat_return_full_text: bool,
max_concurrent_requests: usize,
@ -1407,11 +1403,6 @@ pub async fn run(
max_top_n_tokens: u32,
max_input_tokens: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
tokenizer: Option<Tokenizer>,
config: Option<Config>,
validation_workers: usize,
@ -1495,11 +1486,6 @@ pub async fn run(
struct ApiDoc;
// Create state
// Open connection, get model info and warmup
let (backend, backend_info) = connect_backend(max_input_tokens, max_total_tokens, master_shard_uds_path, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size).await?;
// tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
let validation = Validation::new(
validation_workers,
tokenizer,
@ -1827,8 +1813,6 @@ impl From<InferError> for Event {
#[derive(Debug, Error)]
pub enum WebServerError {
#[error("Backend error: {0}")]
Backend(#[from] V3Error),
#[error("Axum error: {0}")]
Axum(#[from] axum::BoxError),
}

(file path not shown)

@ -8,7 +8,6 @@ use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use text_generation_client::{Chunk, Image, InputChunk};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc;
@ -86,7 +85,7 @@ impl Validation {
&self,
inputs: String,
truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.sender {
// Create response channel
@ -112,7 +111,7 @@ impl Validation {
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel
@ -514,7 +513,7 @@ fn prepare_input(
_truncate: Option<usize>,
tokenizer: &Tokenizer,
config: &Option<Config>,
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(Config::LlavaNext(config)) => {
@ -626,18 +625,51 @@ fn prepare_input(
type TokenizerRequest = (
(String, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
Span,
);
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Image {
pub data: Vec<u8>,
pub mimetype: String,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Chunk {
Text(String),
Image(Image)
}
/// Convert input chunks to a stringly-typed input for backwards
/// compat for backends that haven't implemented chunked inputs.
pub trait ChunksToString {
/// Convert chunks to string.
fn chunks_to_string(&self) -> String;
}
impl ChunksToString for Vec<Chunk> {
fn chunks_to_string(&self) -> String {
let mut output = String::new();
self.iter().for_each(|c| match &c {
Chunk::Text(text) => output.push_str(text),
Chunk::Image(Image { data, mimetype }) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
});
output
}
}
#[derive(Debug, Clone)]
pub(crate) enum ValidGrammar {
pub enum ValidGrammar {
Json(String),
Regex(String),
}
#[derive(Debug, Clone)]
pub(crate) struct ValidParameters {
pub struct ValidParameters {
/// / exponential scaling output probability distribution
pub temperature: f32,
/// / restricting to the k highest probability elements
@ -661,7 +693,7 @@ pub(crate) struct ValidParameters {
}
#[derive(Debug, Clone)]
pub(crate) struct ValidStoppingParameters {
pub struct ValidStoppingParameters {
/// / Maximum number of generated tokens
pub max_new_tokens: u32,
/// / Optional stopping sequences
@ -672,8 +704,8 @@ pub(crate) struct ValidStoppingParameters {
}
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
pub inputs: Vec<InputChunk>,
pub struct ValidGenerateRequest {
pub inputs: Vec<Chunk>,
pub input_length: u32,
pub truncate: u32,
pub decoder_input_details: bool,
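
Finally, a usage sketch for the relocated `ChunksToString` helper (illustrative, not part of this commit):

use text_generation_router::validation::{Chunk, ChunksToString, Image};

fn legacy_inputs() -> String {
    let chunks = vec![
        Chunk::Text("Describe this image: ".to_string()),
        Chunk::Image(Image {
            data: vec![0u8; 4],
            mimetype: "image/png".to_string(),
        }),
    ];
    // Yields "Describe this image: ![](data:image/png;base64,AAAAAA==)"
    chunks.chunks_to_string()
}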