From f230da8d636f30763ea22edb5945336734c6b36e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Tue, 6 Aug 2024 12:36:15 +0000 Subject: [PATCH] Keeping the benchmark somewhere --- backends/v3/Cargo.toml | 22 ++++++++++++-- backends/v3/benches/prefix_cache.rs | 45 +++++++++++++++++++++++++++++ backends/v3/src/block_allocator.rs | 2 +- backends/v3/src/lib.rs | 2 +- router/Cargo.toml | 20 ++++++++++--- 5 files changed, 82 insertions(+), 9 deletions(-) create mode 100644 backends/v3/benches/prefix_cache.rs diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 129ceb9c..06a44bec 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -35,8 +35,14 @@ serde = "1.0.188" serde_json = "1.0.107" slotmap = "1.0.7" thiserror = "1.0.48" -tokenizers = { workspace = true} -tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.37" @@ -44,7 +50,9 @@ tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } -init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } minijinja = { version = "2.0.2" } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" @@ -60,8 +68,16 @@ tower = "^0.4" tonic-build = "0.10.1" prost-build = "0.12.1" +[dev-dependencies] +criterion = "0.3" +itertools = "0.13" + [features] default = ["ngrok"] ngrok = ["text-generation-router/ngrok"] google = ["text-generation-router/google"] kserve = ["text-generation-router/kserve"] + +[[bench]] +name = "prefix_cache" +harness = false diff --git a/backends/v3/benches/prefix_cache.rs b/backends/v3/benches/prefix_cache.rs new file mode 100644 index 00000000..919faf48 --- /dev/null +++ b/backends/v3/benches/prefix_cache.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use itertools::Itertools; +use rand::seq::SliceRandom; +use rand::Rng; + +use text_generation_router_v3::block_allocator::{Allocator, RadixAllocator}; + +fn prefix_cache_benchmark(c: &mut Criterion) { + let prefixes: Vec> = (0..8192) + .chunks(256) + .into_iter() + .map(|c| c.collect()) + .collect(); + + let mut cache = RadixAllocator::new(1, 262144, None); + + c.bench_function("fib 20", |b| { + b.iter_batched( + || { + //prefixes + // .choose_multiple(&mut rand::thread_rng(), 5) + // .fold(Vec::new(), |mut v, s| { + // v.extend(s); + // v + // }) + + (0..7936) + .map(|_| rand::thread_rng().gen_range(0..1024)) + .collect::>() + }, + |prefill| { + let alloc = cache.allocate(prefill.len() as u32 + 13, Some(Arc::new(prefill))); + if let Some(alloc) = alloc { + cache.free(alloc.0, alloc.3); + } + }, + criterion::BatchSize::SmallInput, + ); + }); +} + +criterion_group!(benches, prefix_cache_benchmark); +criterion_main!(benches); diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 37c71653..3653b11c 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -204,7 +204,7 @@ impl Allocator for SimpleAllocator { } } -struct RadixAllocator { +pub struct RadixAllocator { allocation_id: u64, allocations: HashMap, diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index 81e6e4fa..af3f6657 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -1,5 +1,5 @@ mod backend; -mod block_allocator; +pub mod block_allocator; mod client; mod queue; mod radix; diff --git a/router/Cargo.toml b/router/Cargo.toml index 1be74546..7773e212 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -27,8 +27,14 @@ reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" thiserror = "1.0.48" -tokenizers = { workspace = true} -tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.40" @@ -37,7 +43,9 @@ tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } -init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } minijinja = { version = "2.0.2" } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" @@ -46,7 +54,11 @@ once_cell = "1.19.0" image = "0.25.1" base64 = { workspace = true } sysinfo = "0.30.13" -uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } +uuid = { version = "1.9.1", default-features = false, features = [ + "v4", + "fast-rng", + "macro-diagnostics", +] } csv = "1.3.0" ureq = "=2.9"