diff --git a/Cargo.lock b/Cargo.lock index 0ec85025..ce911ce7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -194,17 +194,6 @@ dependencies = [ "v_frame", ] -[[package]] -name = "average" -version = "0.14.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c309b1c7fca12ebeec3ecba29ea917b3a4cb458ccf504df68bb4d8a0ca565a00" -dependencies = [ - "easy-cast", - "float-ord", - "num-traits", -] - [[package]] name = "avif-serialize" version = "0.8.1" @@ -503,12 +492,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "cassowary" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" - [[package]] name = "cc" version = "1.0.98" @@ -570,7 +553,7 @@ version = "4.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.66", @@ -675,31 +658,6 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" -[[package]] -name = "crossterm" -version = "0.27.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" -dependencies = [ - "bitflags 2.5.0", - "crossterm_winapi", - "libc", - "mio", - "parking_lot", - "signal-hook", - "signal-hook-mio", - "winapi", -] - -[[package]] -name = "crossterm_winapi" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" -dependencies = [ - "winapi", -] - [[package]] name = "crunchy" version = "0.2.2" @@ -832,15 +790,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "easy-cast" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10936778145f3bea71fd9bf61332cce28c28e96a380714f7ab34838b80733fd6" -dependencies = [ - "libm", -] - [[package]] name = "either" version = "1.12.0" @@ -944,12 +893,6 @@ dependencies = [ "miniz_oxide", ] -[[package]] -name = "float-ord" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d" - [[package]] name = "float_eq" version = "1.0.1" @@ -1208,12 +1151,6 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -1498,12 +1435,6 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "indoc" -version = "2.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" - [[package]] name = "init-tracing-opentelemetry" version = "0.14.1" @@ -1674,12 +1605,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "libm" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" - [[package]] name = "libredox" version = "0.1.3" @@ -1896,7 +1821,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" dependencies = [ "libc", - "log", "wasi", "windows-sys 0.48.0", ] @@ -2148,7 +2072,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", - "libm", ] [[package]] @@ -2412,17 +2335,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "papergrid" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2ccbe15f2b6db62f9a9871642746427e297b0ceb85f9a7f1ee5ff47d184d0c8" -dependencies = [ - "bytecount", - "fnv", - "unicode-width", -] - [[package]] name = "parking_lot" version = "0.12.3" @@ -2626,7 +2538,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck 0.5.0", + "heck", "itertools 0.12.1", "log", "multimap", @@ -2745,23 +2657,6 @@ dependencies = [ "getrandom", ] -[[package]] -name = "ratatui" -version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e2e4cd95294a85c3b4446e63ef054eea43e0205b1fd60120c16b74ff7ff96ad" -dependencies = [ - "bitflags 2.5.0", - "cassowary", - "crossterm", - "indoc", - "itertools 0.11.0", - "paste", - "strum", - "unicode-segmentation", - "unicode-width", -] - [[package]] name = "rav1e" version = "0.7.1" @@ -3269,27 +3164,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "signal-hook" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" -dependencies = [ - "libc", - "signal-hook-registry", -] - -[[package]] -name = "signal-hook-mio" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af" -dependencies = [ - "libc", - "mio", - "signal-hook", -] - [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -3387,28 +3261,6 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" -[[package]] -name = "strum" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" -dependencies = [ - "strum_macros", -] - -[[package]] -name = "strum_macros" -version = "0.25.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.66", -] - [[package]] name = "subtle" version = "2.5.0" @@ -3491,36 +3343,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349" dependencies = [ "cfg-expr", - "heck 0.5.0", + "heck", "pkg-config", "toml", "version-compare", ] -[[package]] -name = "tabled" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfe9c3632da101aba5131ed63f9eed38665f8b3c68703a6bb18124835c1a5d22" -dependencies = [ - "papergrid", - "tabled_derive", - "unicode-width", -] - -[[package]] -name = "tabled_derive" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99f688a08b54f4f02f0a3c382aefdb7884d3d69609f785bd253dc033243e3fe4" -dependencies = [ - "heck 0.4.1", - "proc-macro-error", - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "target-lexicon" version = "0.12.14" @@ -3539,45 +3367,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "text-generation-benchmark" -version = "2.0.5-dev0" -dependencies = [ - "average", - "clap", - "crossterm", - "float-ord", - "hf-hub", - "ratatui", - "serde", - "serde_json", - "tabled", - "text-generation-client", - "thiserror", - "tokenizers", - "tokio", - "tracing", - "tracing-subscriber", -] - -[[package]] -name = "text-generation-client" -version = "2.0.5-dev0" -dependencies = [ - "async-trait", - "base64 0.22.1", - "futures", - "grpc-metadata", - "prost 0.12.6", - "prost-build", - "thiserror", - "tokio", - "tonic 0.10.2", - "tonic-build", - "tower", - "tracing", -] - [[package]] name = "text-generation-launcher" version = "2.0.5-dev0" @@ -3627,7 +3416,6 @@ dependencies = [ "reqwest", "serde", "serde_json", - "text-generation-client", "thiserror", "tokenizers", "tokio", @@ -3641,6 +3429,54 @@ dependencies = [ "vergen", ] +[[package]] +name = "text-generation-router-v3" +version = "2.0.5-dev0" +dependencies = [ + "async-stream", + "async-trait", + "axum 0.7.5", + "axum-tracing-opentelemetry", + "base64 0.22.1", + "clap", + "futures", + "futures-util", + "grpc-metadata", + "hf-hub", + "image", + "init-tracing-opentelemetry", + "jsonschema", + "metrics", + "metrics-exporter-prometheus", + "minijinja", + "minijinja-contrib", + "nohash-hasher", + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry-otlp", + "prost 0.12.6", + "prost-build", + "rand", + "regex", + "reqwest", + "serde", + "serde_json", + "text-generation-router", + "thiserror", + "tokenizers", + "tokio", + "tokio-stream", + "tonic 0.10.2", + "tonic-build", + "tower", + "tower-http", + "tracing", + "tracing-opentelemetry 0.21.0", + "tracing-subscriber", + "utoipa", + "utoipa-swagger-ui", +] + [[package]] name = "thiserror" version = "1.0.61" diff --git a/Cargo.toml b/Cargo.toml index bc2da5a1..28ded514 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,9 @@ [workspace] members = [ - "benchmark", - "router", - "router/client", - "router/grpc-metadata", +# "benchmark", + "backends/v3", +# "backends/client", + "backends/grpc-metadata", "launcher" ] resolver = "2" diff --git a/router/client/Cargo.toml b/backends/client/Cargo.toml similarity index 100% rename from router/client/Cargo.toml rename to backends/client/Cargo.toml diff --git a/router/client/src/v2/client.rs b/backends/client/src/v2/client.rs similarity index 100% rename from router/client/src/v2/client.rs rename to backends/client/src/v2/client.rs diff --git a/router/client/src/v2/mod.rs b/backends/client/src/v2/mod.rs similarity index 100% rename from router/client/src/v2/mod.rs rename to backends/client/src/v2/mod.rs diff --git a/router/client/src/v2/sharded_client.rs b/backends/client/src/v2/sharded_client.rs similarity index 100% rename from router/client/src/v2/sharded_client.rs rename to backends/client/src/v2/sharded_client.rs diff --git a/router/grpc-metadata/Cargo.toml b/backends/grpc-metadata/Cargo.toml similarity index 100% rename from router/grpc-metadata/Cargo.toml rename to backends/grpc-metadata/Cargo.toml diff --git a/router/grpc-metadata/src/lib.rs b/backends/grpc-metadata/src/lib.rs similarity index 100% rename from router/grpc-metadata/src/lib.rs rename to backends/grpc-metadata/src/lib.rs diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml new file mode 100644 index 00000000..161e1050 --- /dev/null +++ b/backends/v3/Cargo.toml @@ -0,0 +1,66 @@ +[package] +name = "text-generation-router-v3" +description = "Text Generation Webserver" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "text-generation-router" +path = "src/main.rs" + +[dependencies] +async-trait = "0.1.74" +async-stream = "0.3.5" +axum = { version = "0.7", features = ["json"] } +axum-tracing-opentelemetry = "0.16" +text-generation-router = { path = "../../router" } +clap = { version = "4.4.5", features = ["derive", "env"] } +grpc-metadata = { path = "../grpc-metadata" } +futures = "0.3.28" +hf-hub = { workspace = true } +jsonschema = { version = "0.17.1", features = ["draft202012"] } +metrics = "0.21.1" +metrics-exporter-prometheus = { version = "0.12.1", features = [] } +nohash-hasher = "0.2.0" +opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } +opentelemetry-otlp = "0.13.0" +rand = "0.8.5" +reqwest = { version = "0.11.20", features = [] } +serde = "1.0.188" +serde_json = "1.0.107" +thiserror = "1.0.48" +tokenizers = { workspace = true} +tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.14" +tower-http = { version = "0.5.1", features = ["cors"] } +tracing = "0.1.37" +tracing-opentelemetry = "0.21.0" +tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } +init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +minijinja = { version = "2.0.2" } +minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } +futures-util = "0.3.30" +regex = "1.10.3" +once_cell = "1.19.0" +image = "0.25.1" +base64 = { workspace = true } +prost = "^0.12" +tonic = "^0.10" +tower = "^0.4" + +[build-dependencies] +tonic-build = "0.10.1" +prost-build = "0.12.1" + +[features] +default = ["ngrok"] +ngrok = ["text-generation-router/ngrok"] +google = ["text-generation-router/google"] +kserve = ["text-generation-router/kserve"] diff --git a/backends/v3/build.rs b/backends/v3/build.rs new file mode 100644 index 00000000..6d702d14 --- /dev/null +++ b/backends/v3/build.rs @@ -0,0 +1,19 @@ +use std::fs; + +fn main() -> Result<(), Box> { + println!("cargo:rerun-if-changed=../../proto/"); + + fs::create_dir_all("src/client/pb").unwrap_or(()); + let mut config = prost_build::Config::new(); + config.protoc_arg("--experimental_allow_proto3_optional"); + + tonic_build::configure() + .build_client(true) + .build_server(false) + .out_dir("src/client/pb") + .include_file("mod.rs") + .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + + Ok(()) +} diff --git a/router/src/infer/v3/backend.rs b/backends/v3/src/backend.rs similarity index 95% rename from router/src/infer/v3/backend.rs rename to backends/v3/src/backend.rs index 6cba6e52..bfe587f4 100644 --- a/router/src/infer/v3/backend.rs +++ b/backends/v3/src/backend.rs @@ -1,16 +1,15 @@ /// Batching and inference logic -use crate::infer::v3::queue::{Entry, Queue}; -use crate::infer::{ +use crate::queue::{Entry, Queue}; +use text_generation_router::infer::{ GeneratedText, InferError, InferStreamResponse, Backend, }; -use crate::validation::ValidGenerateRequest; -use crate::{FinishReason, PrefillToken, Token}; +use text_generation_router::validation::ValidGenerateRequest; +use text_generation_router::{FinishReason, PrefillToken, Token}; use nohash_hasher::IntMap; use std::sync::{ Arc, }; -use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient}; -use text_generation_client::{ClientError, Health}; +use crate::client::{Batch, CachedBatch, Generation, ShardedClient, ClientError, Health}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; @@ -18,7 +17,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{info_span, instrument, Instrument, Span}; use async_trait::async_trait; -pub(crate) struct BackendV3 { +pub struct BackendV3 { /// Request queue queue: Queue, /// Notify batcher on queue appends @@ -78,7 +77,6 @@ impl Backend for BackendV3 { ) -> Result>, InferError> { // MPSC channel to communicate with the background batching task let (response_tx, response_rx) = mpsc::unbounded_channel(); - let input_length = request.input_length; // Append the request to the queue self.queue.append(Entry { @@ -480,14 +478,14 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { }); } -impl From for GeneratedText { - fn from(value: text_generation_client::v3::GeneratedText) -> Self { +impl From for GeneratedText { + fn from(value: crate::client::GeneratedText) -> Self { let v3_finish_reason = - text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap(); + crate::client::FinishReason::try_from(value.finish_reason).unwrap(); let finish_reason = match v3_finish_reason { - text_generation_client::v3::FinishReason::Length => FinishReason::Length, - text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken, - text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence, + crate::client::FinishReason::Length => FinishReason::Length, + crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + crate::client::FinishReason::StopSequence => FinishReason::StopSequence, }; Self { diff --git a/router/src/infer/v3/block_allocator.rs b/backends/v3/src/block_allocator.rs similarity index 100% rename from router/src/infer/v3/block_allocator.rs rename to backends/v3/src/block_allocator.rs diff --git a/router/client/src/v3/client.rs b/backends/v3/src/client/client.rs similarity index 99% rename from router/client/src/v3/client.rs rename to backends/v3/src/client/client.rs index a996b14f..242a82d9 100644 --- a/router/client/src/v3/client.rs +++ b/backends/v3/src/client/client.rs @@ -1,6 +1,7 @@ -use crate::v3::{pb, Chunk}; -use crate::{ClientError, Result, WARMUP_IMAGE_BASE64}; /// Single shard Client + +use crate::client::{pb, Chunk}; +use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64}; use base64::engine::general_purpose::STANDARD; use base64::Engine; use grpc_metadata::InjectTelemetryContext; diff --git a/router/client/src/lib.rs b/backends/v3/src/client/mod.rs similarity index 69% rename from router/client/src/lib.rs rename to backends/v3/src/client/mod.rs index 45bee10c..4099ff87 100644 --- a/router/client/src/lib.rs +++ b/backends/v3/src/client/mod.rs @@ -1,15 +1,24 @@ //! Text Generation gRPC client library -pub mod v2; -pub mod v3; use async_trait::async_trait; -use base64::{engine::general_purpose::STANDARD, Engine}; use thiserror::Error; use tonic::transport; use tonic::Status; -pub use v3::{Chunk, Image, Input, InputChunk}; +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod client; +mod sharded_client; + +pub use client::Client; +pub use pb::generate::v3::{ + input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, + StoppingCriteriaParameters, +}; +pub use sharded_client::ShardedClient; #[async_trait] pub trait Health { @@ -63,29 +72,6 @@ impl From for InputChunk { } } -/// Convert input chunks to a stringly-typed input for backwards -/// compat for backends that haven't implemented chunked inputs. -pub trait ChunksToString { - /// Convert chunks to string. - fn chunks_to_string(&self) -> String; -} - -impl ChunksToString for Vec { - fn chunks_to_string(&self) -> String { - let mut output = String::new(); - self.iter().for_each(|c| match &c.chunk { - Some(Chunk::Text(text)) => output.push_str(text), - Some(Chunk::Image(Image { data, mimetype })) => { - let encoded = STANDARD.encode(data); - output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded)) - } - // We don't create empty chunks, so this should be unreachable. - None => unreachable!("Chunks should never be empty"), - }); - output - } -} - static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; pub type Result = std::result::Result; diff --git a/router/client/src/v3/sharded_client.rs b/backends/v3/src/client/sharded_client.rs similarity index 97% rename from router/client/src/v3/sharded_client.rs rename to backends/v3/src/client/sharded_client.rs index ae8a899b..32365648 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -1,14 +1,14 @@ /// Multi shard Client -use crate::{v3, Health, ShardInfo}; -use crate::{ClientError, Result}; +use crate::client::{Health, ShardInfo}; +use crate::client::{ClientError, Result}; -use crate::v3::{Chunk, InfoResponse, Input}; +use crate::client::{Chunk, InfoResponse, Input}; use async_trait::async_trait; use futures::future::join_all; use tonic::transport::Uri; use tracing::instrument; -use v3::client::{DecodeTimings, PrefillTimings}; -use v3::{ +use crate::client::client::{DecodeTimings, PrefillTimings}; +use crate::client::{ Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; diff --git a/router/src/infer/v3/mod.rs b/backends/v3/src/lib.rs similarity index 96% rename from router/src/infer/v3/mod.rs rename to backends/v3/src/lib.rs index 0f7f4fdb..cd4b3b0a 100644 --- a/router/src/infer/v3/mod.rs +++ b/backends/v3/src/lib.rs @@ -1,14 +1,13 @@ mod block_allocator; mod queue; mod backend; +mod client; -use futures_util::TryFutureExt; use serde::Serialize; use thiserror::Error; use utoipa::ToSchema; pub(crate) use backend::BackendV3; -use text_generation_client::ClientError; -use text_generation_client::v3::ShardedClient; +use crate::client::{ShardedClient, ClientError}; #[derive(Clone, Debug, Serialize, ToSchema)] pub struct BackendInfo { @@ -127,7 +126,7 @@ pub async fn connect_backend( } #[derive(Debug, Error)] -pub(crate) enum V3Error { +pub enum V3Error { #[error("Unable to clear the Python model shards cache: {0}")] Cache(ClientError), #[error("Unable to connect to the Python model shards: {0}")] diff --git a/router/src/main.rs b/backends/v3/src/main.rs similarity index 97% rename from router/src/main.rs rename to backends/v3/src/main.rs index a7caec2e..e6971ac8 100644 --- a/router/src/main.rs +++ b/backends/v3/src/main.rs @@ -20,6 +20,7 @@ use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; +use text_generation_router_v3::{connect_backend, V3Error}; /// App Configuration #[derive(Parser, Debug)] @@ -336,9 +337,11 @@ async fn main() -> Result<(), RouterError> { } }; + let (backend, backend_info) = connect_backend(max_input_tokens, max_total_tokens, master_shard_uds_path, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size).await?; + // Run server server::run( - master_shard_uds_path, + backend, model_info, compat_return_full_text, max_concurrent_requests, @@ -347,11 +350,6 @@ async fn main() -> Result<(), RouterError> { max_top_n_tokens, max_input_tokens, max_total_tokens, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, tokenizer, config, validation_workers, @@ -508,6 +506,8 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option client::Chunk::Text(text), + Chunk::Image(image) => client::Chunk::Image(client::Image { data: image.data, mimetype: image.mimetype }) + })}).collect() }), inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml index 756460e0..f82659c9 100644 --- a/benchmark/Cargo.toml +++ b/benchmark/Cargo.toml @@ -21,7 +21,7 @@ float-ord = "0.3.2" serde = {version = "1.0.188", features = ["derive"]} serde_json = "1.0" tabled = "0.14.0" -text-generation-client = { path = "../router/client" } +text-generation-client = { path = "../backends/client" } thiserror = "1.0.48" tokenizers = { workspace = true } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] } diff --git a/router/Cargo.toml b/router/Cargo.toml index bf96ab91..8e7c8bdc 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -7,19 +7,11 @@ edition.workspace = true authors.workspace = true homepage.workspace = true -[lib] -path = "src/lib.rs" - -[[bin]] -name = "text-generation-router" -path = "src/main.rs" - [dependencies] async-trait = "0.1.74" async-stream = "0.3.5" axum = { version = "0.7", features = ["json"] } axum-tracing-opentelemetry = "0.16" -text-generation-client = { path = "client" } clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" hf-hub = { workspace = true } diff --git a/router/client/build.rs b/router/client/build.rs deleted file mode 100644 index 210cd603..00000000 --- a/router/client/build.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::fs; - -fn main() -> Result<(), Box> { - println!("cargo:rerun-if-changed=../../proto/"); - - fs::create_dir_all("src/v2/pb").unwrap_or(()); - let mut config = prost_build::Config::new(); - config.protoc_arg("--experimental_allow_proto3_optional"); - - tonic_build::configure() - .build_client(true) - .build_server(false) - .out_dir("src/v2/pb") - .include_file("mod.rs") - .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"]) - .map_err(|e| match e.kind(){ - std::io::ErrorKind::NotFound => {panic!("`protoc` not found, install libprotoc")}, - std::io::ErrorKind::Other => {panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases")}, - e => {e} - }).unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); - - fs::create_dir_all("src/v3/pb").unwrap_or(()); - let mut config = prost_build::Config::new(); - config.protoc_arg("--experimental_allow_proto3_optional"); - - tonic_build::configure() - .build_client(true) - .build_server(false) - .out_dir("src/v3/pb") - .include_file("mod.rs") - .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"]) - .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); - - Ok(()) -} diff --git a/router/client/src/v2/pb/.gitignore b/router/client/src/v2/pb/.gitignore deleted file mode 100644 index 72e8ffc0..00000000 --- a/router/client/src/v2/pb/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs deleted file mode 100644 index 4a1296a2..00000000 --- a/router/client/src/v3/mod.rs +++ /dev/null @@ -1,13 +0,0 @@ -#[allow(clippy::derive_partial_eq_without_eq)] -mod pb; - -mod client; -mod sharded_client; - -pub use client::Client; -pub use pb::generate::v3::{ - input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, - HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, - StoppingCriteriaParameters, Tokens, -}; -pub use sharded_client::ShardedClient; diff --git a/router/client/src/v3/pb/.gitignore b/router/client/src/v3/pb/.gitignore deleted file mode 100644 index 72e8ffc0..00000000 --- a/router/client/src/v3/pb/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index a1aedadf..4cef5def 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -1,5 +1,4 @@ // pub(crate) mod v2; -pub(crate) mod v3; mod chat_template; pub mod tool_grammar; @@ -23,7 +22,7 @@ use chat_template::ChatTemplate; use async_trait::async_trait; #[async_trait] -pub(crate) trait Backend { +pub trait Backend { fn schedule( &self, request: ValidGenerateRequest, @@ -286,15 +285,15 @@ pub(crate) type GenerateStreamResponse = ( ); #[derive(Debug)] -pub(crate) struct GeneratedText { - pub(crate) text: String, - pub(crate) generated_tokens: u32, - pub(crate) finish_reason: FinishReason, - pub(crate) seed: Option, +pub struct GeneratedText { + pub text: String, + pub generated_tokens: u32, + pub finish_reason: FinishReason, + pub seed: Option, } #[derive(Debug)] -pub(crate) enum InferStreamResponse { +pub enum InferStreamResponse { // Optional first message Prefill(Vec), // Intermediate messages diff --git a/router/src/lib.rs b/router/src/lib.rs index 80ff23d5..d62f5257 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,8 +1,8 @@ /// Text Generation Inference Webserver pub mod config; -mod infer; +pub mod infer; pub mod server; -mod validation; +pub mod validation; #[cfg(feature = "kserve")] mod kserve; @@ -1055,23 +1055,23 @@ impl From for GenerateRequest { #[derive(Debug, Serialize, ToSchema)] pub struct PrefillToken { #[schema(example = 0)] - id: u32, + pub id: u32, #[schema(example = "test")] - text: String, + pub text: String, #[schema(nullable = true, example = - 0.34)] - logprob: f32, + pub logprob: f32, } #[derive(Debug, Serialize, ToSchema, Clone)] pub struct Token { #[schema(example = 0)] - id: u32, + pub id: u32, #[schema(example = "test")] - text: String, + pub text: String, #[schema(nullable = true, example = - 0.34)] - logprob: f32, + pub logprob: f32, #[schema(example = "false")] - special: bool, + pub special: bool, } #[derive(Debug, Serialize, ToSchema)] @@ -1089,7 +1089,7 @@ pub struct SimpleToken { #[derive(Debug, Serialize, ToSchema)] #[serde(rename_all(serialize = "snake_case"))] #[schema(example = "Length")] -pub(crate) enum FinishReason { +pub enum FinishReason { #[schema(rename = "length")] Length, #[serde(rename = "eos_token")] diff --git a/router/src/server.rs b/router/src/server.rs index 3b53fff4..a7cfd74e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,6 +1,5 @@ /// HTTP Server logic use crate::config::Config; -use crate::infer::v3::{connect_backend, V3Error}; use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, Backend}; use crate::infer::tool_grammar::ToolGrammar; #[cfg(feature = "kserve")] @@ -38,9 +37,6 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; use std::convert::Infallible; use std::net::SocketAddr; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; -use text_generation_client::{v2, v3, ClientError, ShardInfo}; use thiserror::Error; use tokenizers::Tokenizer; use tokio::select; @@ -1398,7 +1394,7 @@ pub(crate) struct ComputeType(String); /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( - master_shard_uds_path: String, + backend: impl Backend + Send + Sync + 'static, model_info: HubModelInfo, compat_return_full_text: bool, max_concurrent_requests: usize, @@ -1407,11 +1403,6 @@ pub async fn run( max_top_n_tokens: u32, max_input_tokens: usize, max_total_tokens: usize, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: Option, - max_waiting_tokens: usize, - max_batch_size: Option, tokenizer: Option, config: Option, validation_workers: usize, @@ -1495,11 +1486,6 @@ pub async fn run( struct ApiDoc; // Create state - - // Open connection, get model info and warmup - let (backend, backend_info) = connect_backend(max_input_tokens, max_total_tokens, master_shard_uds_path, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size).await?; - // tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); - let validation = Validation::new( validation_workers, tokenizer, @@ -1827,8 +1813,6 @@ impl From for Event { #[derive(Debug, Error)] pub enum WebServerError { - #[error("Backend error: {0}")] - Backend(#[from] V3Error), #[error("Axum error: {0}")] Axum(#[from] axum::BoxError), } diff --git a/router/src/validation.rs b/router/src/validation.rs index e2bf5a5d..cf9e107c 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -8,7 +8,6 @@ use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; -use text_generation_client::{Chunk, Image, InputChunk}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::mpsc; @@ -86,7 +85,7 @@ impl Validation { &self, inputs: String, truncate: Option, - ) -> Result)>, ValidationError> { + ) -> Result)>, ValidationError> { // If we have a fast tokenizer if let Some(sender) = &self.sender { // Create response channel @@ -112,7 +111,7 @@ impl Validation { inputs: String, truncate: Option, max_new_tokens: Option, - ) -> Result<(Vec, usize, u32), ValidationError> { + ) -> Result<(Vec, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { // Create response channel @@ -514,7 +513,7 @@ fn prepare_input( _truncate: Option, tokenizer: &Tokenizer, config: &Option, -) -> Result<(tokenizers::Encoding, Vec), ValidationError> { +) -> Result<(tokenizers::Encoding, Vec), ValidationError> { static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); let (tokenizer_query, input_chunks) = match config { Some(Config::LlavaNext(config)) => { @@ -626,18 +625,51 @@ fn prepare_input( type TokenizerRequest = ( (String, Option), - oneshot::Sender), ValidationError>>, + oneshot::Sender), ValidationError>>, Span, ); +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Image { + pub data: Vec, + pub mimetype: String, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum Chunk { + Text(String), + Image(Image) +} + +/// Convert input chunks to a stringly-typed input for backwards +/// compat for backends that haven't implemented chunked inputs. +pub trait ChunksToString { + /// Convert chunks to string. + fn chunks_to_string(&self) -> String; +} + +impl ChunksToString for Vec { + fn chunks_to_string(&self) -> String { + let mut output = String::new(); + self.iter().for_each(|c| match &c { + Chunk::Text(text) => output.push_str(text), + Chunk::Image(Image { data, mimetype }) => { + let encoded = STANDARD.encode(data); + output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded)) + } + }); + output + } +} + #[derive(Debug, Clone)] -pub(crate) enum ValidGrammar { +pub enum ValidGrammar { Json(String), Regex(String), } #[derive(Debug, Clone)] -pub(crate) struct ValidParameters { +pub struct ValidParameters { /// / exponential scaling output probability distribution pub temperature: f32, /// / restricting to the k highest probability elements @@ -661,7 +693,7 @@ pub(crate) struct ValidParameters { } #[derive(Debug, Clone)] -pub(crate) struct ValidStoppingParameters { +pub struct ValidStoppingParameters { /// / Maximum number of generated tokens pub max_new_tokens: u32, /// / Optional stopping sequences @@ -672,8 +704,8 @@ pub(crate) struct ValidStoppingParameters { } #[derive(Debug, Clone)] -pub(crate) struct ValidGenerateRequest { - pub inputs: Vec, +pub struct ValidGenerateRequest { + pub inputs: Vec, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool,