diff --git a/Cargo.lock b/Cargo.lock index 3a5845a7..bb63422e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -180,6 +180,17 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -565,6 +576,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.1.7" @@ -617,6 +634,17 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "2.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "bitflags 1.3.2", + "textwrap", + "unicode-width", +] + [[package]] name = "clap" version = "4.5.11" @@ -735,6 +763,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f" +dependencies = [ + "atty", + "cast", + "clap 2.34.0", + "criterion-plot", + "csv", + "itertools 0.10.5", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_cbor", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -1060,7 +1124,7 @@ checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" dependencies = [ "bit_field", "flume", - "half", + "half 2.4.1", "lebe", "miniz_oxide", "rayon-core", @@ -1367,6 +1431,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" + [[package]] name = "half" version = "2.4.1" @@ -1404,6 +1474,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hermit-abi" version = "0.3.9" @@ -1804,6 +1883,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1844,7 +1932,7 @@ dependencies = [ "anyhow", "base64 0.21.7", "bytecount", - "clap", + "clap 4.5.11", "fancy-regex", "fraction", "getrandom", @@ -2132,7 +2220,7 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", "wasi", "windows-sys 0.52.0", @@ -2400,7 +2488,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", ] @@ -2456,6 +2544,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "openssl" version = "0.10.66" @@ -2783,6 +2877,34 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "plotters" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" + +[[package]] +name = "plotters-svg" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.17.13" @@ -3525,6 +3647,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half 1.8.3", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.204" @@ -3891,7 +4023,7 @@ version = "2.2.1-dev0" dependencies = [ "async-stream", "async-trait", - "clap", + "clap 4.5.11", "cmake", "cxx", "cxx-build", @@ -3912,7 +4044,7 @@ name = "text-generation-benchmark" version = "2.2.1-dev0" dependencies = [ "average", - "clap", + "clap 4.5.11", "crossterm", "float-ord", "hf-hub", @@ -3950,7 +4082,7 @@ dependencies = [ name = "text-generation-launcher" version = "2.2.1-dev0" dependencies = [ - "clap", + "clap 4.5.11", "ctrlc", "float_eq", "hf-hub", @@ -3974,7 +4106,7 @@ dependencies = [ "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap", + "clap 4.5.11", "csv", "futures", "futures-util", @@ -4022,13 +4154,15 @@ dependencies = [ "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap", + "clap 4.5.11", + "criterion", "futures", "futures-util", "grpc-metadata", "hf-hub", "image", "init-tracing-opentelemetry", + "itertools 0.13.0", "jsonschema", "metrics", "metrics-exporter-prometheus", @@ -4062,6 +4196,15 @@ dependencies = [ "utoipa-swagger-ui", ] +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -4136,6 +4279,16 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.8.0" diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 129ceb9c..06a44bec 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -35,8 +35,14 @@ serde = "1.0.188" serde_json = "1.0.107" slotmap = "1.0.7" thiserror = "1.0.48" -tokenizers = { workspace = true} -tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.37" @@ -44,7 +50,9 @@ tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } -init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } minijinja = { version = "2.0.2" } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" @@ -60,8 +68,16 @@ tower = "^0.4" tonic-build = "0.10.1" prost-build = "0.12.1" +[dev-dependencies] +criterion = "0.3" +itertools = "0.13" + [features] default = ["ngrok"] ngrok = ["text-generation-router/ngrok"] google = ["text-generation-router/google"] kserve = ["text-generation-router/kserve"] + +[[bench]] +name = "prefix_cache" +harness = false diff --git a/backends/v3/benches/prefix_cache.rs b/backends/v3/benches/prefix_cache.rs new file mode 100644 index 00000000..d9df45b2 --- /dev/null +++ b/backends/v3/benches/prefix_cache.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::Rng; + +use text_generation_router_v3::block_allocator::Allocator; +use text_generation_router_v3::radix::RadixAllocator; + +fn prefix_cache_benchmark(c: &mut Criterion) { + // let prefixes: Vec> = (0..8192) + // .chunks(256) + // .into_iter() + // .map(|c| c.collect()) + // .collect(); + + let mut cache = RadixAllocator::new(1, 262144, None); + + c.bench_function("Radix allocator", |b| { + b.iter_batched( + || { + //prefixes + // .choose_multiple(&mut rand::thread_rng(), 5) + // .fold(Vec::new(), |mut v, s| { + // v.extend(s); + // v + // }) + + (0..7936) + .map(|_| rand::thread_rng().gen_range(0..1024)) + .collect::>() + }, + |prefill| { + let alloc = cache.allocate( + prefill.len() as u32 + 13, + Some(Arc::new(black_box(prefill))), + ); + if let Some(alloc) = alloc { + cache.free(alloc.blocks.clone(), alloc.allocation_id); + } + }, + criterion::BatchSize::SmallInput, + ); + }); +} + +criterion_group!(benches, prefix_cache_benchmark); +criterion_main!(benches); diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 05c2bd30..c5503b8c 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -4,7 +4,7 @@ use tokio::sync::{mpsc, oneshot}; use crate::radix::RadixAllocator; #[derive(Debug, Clone)] -pub(crate) struct BlockAllocation { +pub struct BlockAllocation { pub allocation_id: u64, pub blocks: Vec, pub slots: Vec, @@ -25,7 +25,7 @@ impl Drop for BlockAllocation { } #[derive(Debug, Clone)] -pub(crate) struct BlockAllocator { +pub struct BlockAllocator { /// Channel to communicate with the background task block_allocator: mpsc::UnboundedSender, } @@ -128,7 +128,7 @@ enum BlockAllocatorCommand { }, } -pub(crate) trait Allocator { +pub trait Allocator { fn allocate( &mut self, tokens: u32, diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index c8fc55f8..77a9a11a 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -1,8 +1,8 @@ mod backend; -mod block_allocator; +pub mod block_allocator; mod client; mod queue; -mod radix; +pub mod radix; use crate::client::{ClientError, ShardedClient}; pub(crate) use backend::BackendV3; diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 13544235..0fb05a98 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -250,7 +250,7 @@ impl State { // Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); - next_batch_span.follows_from(&Span::current()); + next_batch_span.follows_from(Span::current()); let mut batch_requests = Vec::with_capacity(self.entries.len()); let mut batch_entries = diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 0464b9f8..ef963532 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -205,6 +205,11 @@ pub struct RadixTrie { /// call that a real time lookup would require. time: u64, } +impl Default for RadixTrie { + fn default() -> Self { + Self::new() + } +} impl RadixTrie { /// Construct a new radix trie. diff --git a/router/Cargo.toml b/router/Cargo.toml index 1be74546..7773e212 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -27,8 +27,14 @@ reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" thiserror = "1.0.48" -tokenizers = { workspace = true} -tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.40" @@ -37,7 +43,9 @@ tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } -init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } minijinja = { version = "2.0.2" } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" @@ -46,7 +54,11 @@ once_cell = "1.19.0" image = "0.25.1" base64 = { workspace = true } sysinfo = "0.30.13" -uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } +uuid = { version = "1.9.1", default-features = false, features = [ + "v4", + "fast-rng", + "macro-diagnostics", +] } csv = "1.3.0" ureq = "=2.9"