Keeping the benchmark somewhere (#2401)

Co-authored-by: Daniël de Kok <me@danieldk.eu>
This commit is contained in:
Nicolas Patry 2024-08-12 15:22:02 +02:00 committed by GitHub
parent 8deeaca4ff
commit 136bcc8128
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 255 additions and 22 deletions

171
Cargo.lock generated
View File

@ -180,6 +180,17 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi 0.1.19",
"libc",
"winapi",
]
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.3.0" version = "1.3.0"
@ -565,6 +576,12 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.7" version = "1.1.7"
@ -617,6 +634,17 @@ dependencies = [
"libloading", "libloading",
] ]
[[package]]
name = "clap"
version = "2.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
dependencies = [
"bitflags 1.3.2",
"textwrap",
"unicode-width",
]
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.11" version = "4.5.11"
@ -735,6 +763,42 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "criterion"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f"
dependencies = [
"atty",
"cast",
"clap 2.34.0",
"criterion-plot",
"csv",
"itertools 0.10.5",
"lazy_static",
"num-traits",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_cbor",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]] [[package]]
name = "crossbeam-channel" name = "crossbeam-channel"
version = "0.5.13" version = "0.5.13"
@ -1060,7 +1124,7 @@ checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4"
dependencies = [ dependencies = [
"bit_field", "bit_field",
"flume", "flume",
"half", "half 2.4.1",
"lebe", "lebe",
"miniz_oxide", "miniz_oxide",
"rayon-core", "rayon-core",
@ -1367,6 +1431,12 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "half"
version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
[[package]] [[package]]
name = "half" name = "half"
version = "2.4.1" version = "2.4.1"
@ -1404,6 +1474,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "hermit-abi" name = "hermit-abi"
version = "0.3.9" version = "0.3.9"
@ -1804,6 +1883,15 @@ dependencies = [
"either", "either",
] ]
[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itoa" name = "itoa"
version = "1.0.11" version = "1.0.11"
@ -1844,7 +1932,7 @@ dependencies = [
"anyhow", "anyhow",
"base64 0.21.7", "base64 0.21.7",
"bytecount", "bytecount",
"clap", "clap 4.5.11",
"fancy-regex", "fancy-regex",
"fraction", "fraction",
"getrandom", "getrandom",
@ -2132,7 +2220,7 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4"
dependencies = [ dependencies = [
"hermit-abi", "hermit-abi 0.3.9",
"libc", "libc",
"wasi", "wasi",
"windows-sys 0.52.0", "windows-sys 0.52.0",
@ -2400,7 +2488,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [ dependencies = [
"hermit-abi", "hermit-abi 0.3.9",
"libc", "libc",
] ]
@ -2456,6 +2544,12 @@ dependencies = [
"pkg-config", "pkg-config",
] ]
[[package]]
name = "oorandom"
version = "11.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9"
[[package]] [[package]]
name = "openssl" name = "openssl"
version = "0.10.66" version = "0.10.66"
@ -2783,6 +2877,34 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec"
[[package]]
name = "plotters"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7"
[[package]]
name = "plotters-svg"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705"
dependencies = [
"plotters-backend",
]
[[package]] [[package]]
name = "png" name = "png"
version = "0.17.13" version = "0.17.13"
@ -3525,6 +3647,16 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde_cbor"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
dependencies = [
"half 1.8.3",
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.204" version = "1.0.204"
@ -3891,7 +4023,7 @@ version = "2.2.1-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
"clap", "clap 4.5.11",
"cmake", "cmake",
"cxx", "cxx",
"cxx-build", "cxx-build",
@ -3912,7 +4044,7 @@ name = "text-generation-benchmark"
version = "2.2.1-dev0" version = "2.2.1-dev0"
dependencies = [ dependencies = [
"average", "average",
"clap", "clap 4.5.11",
"crossterm", "crossterm",
"float-ord", "float-ord",
"hf-hub", "hf-hub",
@ -3950,7 +4082,7 @@ dependencies = [
name = "text-generation-launcher" name = "text-generation-launcher"
version = "2.2.1-dev0" version = "2.2.1-dev0"
dependencies = [ dependencies = [
"clap", "clap 4.5.11",
"ctrlc", "ctrlc",
"float_eq", "float_eq",
"hf-hub", "hf-hub",
@ -3974,7 +4106,7 @@ dependencies = [
"axum 0.7.5", "axum 0.7.5",
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"base64 0.22.1", "base64 0.22.1",
"clap", "clap 4.5.11",
"csv", "csv",
"futures", "futures",
"futures-util", "futures-util",
@ -4022,13 +4154,15 @@ dependencies = [
"axum 0.7.5", "axum 0.7.5",
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"base64 0.22.1", "base64 0.22.1",
"clap", "clap 4.5.11",
"criterion",
"futures", "futures",
"futures-util", "futures-util",
"grpc-metadata", "grpc-metadata",
"hf-hub", "hf-hub",
"image", "image",
"init-tracing-opentelemetry", "init-tracing-opentelemetry",
"itertools 0.13.0",
"jsonschema", "jsonschema",
"metrics", "metrics",
"metrics-exporter-prometheus", "metrics-exporter-prometheus",
@ -4062,6 +4196,15 @@ dependencies = [
"utoipa-swagger-ui", "utoipa-swagger-ui",
] ]
[[package]]
name = "textwrap"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
dependencies = [
"unicode-width",
]
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.63" version = "1.0.63"
@ -4136,6 +4279,16 @@ dependencies = [
"time-core", "time-core",
] ]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]] [[package]]
name = "tinyvec" name = "tinyvec"
version = "1.8.0" version = "1.8.0"

View File

@ -36,7 +36,13 @@ serde_json = "1.0.107"
slotmap = "1.0.7" slotmap = "1.0.7"
thiserror = "1.0.48" thiserror = "1.0.48"
tokenizers = { workspace = true } tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.32.0", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
] }
tokio-stream = "0.1.14" tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] } tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37" tracing = "0.1.37"
@ -44,7 +50,9 @@ tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } init-tracing-opentelemetry = { version = "0.14.1", features = [
"opentelemetry-otlp",
] }
minijinja = { version = "2.0.2" } minijinja = { version = "2.0.2" }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
futures-util = "0.3.30" futures-util = "0.3.30"
@ -60,8 +68,16 @@ tower = "^0.4"
tonic-build = "0.10.1" tonic-build = "0.10.1"
prost-build = "0.12.1" prost-build = "0.12.1"
[dev-dependencies]
criterion = "0.3"
itertools = "0.13"
[features] [features]
default = ["ngrok"] default = ["ngrok"]
ngrok = ["text-generation-router/ngrok"] ngrok = ["text-generation-router/ngrok"]
google = ["text-generation-router/google"] google = ["text-generation-router/google"]
kserve = ["text-generation-router/kserve"] kserve = ["text-generation-router/kserve"]
[[bench]]
name = "prefix_cache"
harness = false

View File

@ -0,0 +1,47 @@
use std::sync::Arc;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::Rng;
use text_generation_router_v3::block_allocator::Allocator;
use text_generation_router_v3::radix::RadixAllocator;
fn prefix_cache_benchmark(c: &mut Criterion) {
// let prefixes: Vec<Vec<u32>> = (0..8192)
// .chunks(256)
// .into_iter()
// .map(|c| c.collect())
// .collect();
let mut cache = RadixAllocator::new(1, 262144, None);
c.bench_function("Radix allocator", |b| {
b.iter_batched(
|| {
//prefixes
// .choose_multiple(&mut rand::thread_rng(), 5)
// .fold(Vec::new(), |mut v, s| {
// v.extend(s);
// v
// })
(0..7936)
.map(|_| rand::thread_rng().gen_range(0..1024))
.collect::<Vec<u32>>()
},
|prefill| {
let alloc = cache.allocate(
prefill.len() as u32 + 13,
Some(Arc::new(black_box(prefill))),
);
if let Some(alloc) = alloc {
cache.free(alloc.blocks.clone(), alloc.allocation_id);
}
},
criterion::BatchSize::SmallInput,
);
});
}
criterion_group!(benches, prefix_cache_benchmark);
criterion_main!(benches);

View File

@ -4,7 +4,7 @@ use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator; use crate::radix::RadixAllocator;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct BlockAllocation { pub struct BlockAllocation {
pub allocation_id: u64, pub allocation_id: u64,
pub blocks: Vec<u32>, pub blocks: Vec<u32>,
pub slots: Vec<u32>, pub slots: Vec<u32>,
@ -25,7 +25,7 @@ impl Drop for BlockAllocation {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct BlockAllocator { pub struct BlockAllocator {
/// Channel to communicate with the background task /// Channel to communicate with the background task
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>, block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
} }
@ -128,7 +128,7 @@ enum BlockAllocatorCommand {
}, },
} }
pub(crate) trait Allocator { pub trait Allocator {
fn allocate( fn allocate(
&mut self, &mut self,
tokens: u32, tokens: u32,

View File

@ -1,8 +1,8 @@
mod backend; mod backend;
mod block_allocator; pub mod block_allocator;
mod client; mod client;
mod queue; mod queue;
mod radix; pub mod radix;
use crate::client::{ClientError, ShardedClient}; use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3; pub(crate) use backend::BackendV3;

View File

@ -250,7 +250,7 @@ impl State {
// Create span for this batch to add context to inference calls // Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(&Span::current()); next_batch_span.follows_from(Span::current());
let mut batch_requests = Vec::with_capacity(self.entries.len()); let mut batch_requests = Vec::with_capacity(self.entries.len());
let mut batch_entries = let mut batch_entries =

View File

@ -205,6 +205,11 @@ pub struct RadixTrie {
/// call that a real time lookup would require. /// call that a real time lookup would require.
time: u64, time: u64,
} }
impl Default for RadixTrie {
fn default() -> Self {
Self::new()
}
}
impl RadixTrie { impl RadixTrie {
/// Construct a new radix trie. /// Construct a new radix trie.

View File

@ -28,7 +28,13 @@ serde = "1.0.188"
serde_json = "1.0.107" serde_json = "1.0.107"
thiserror = "1.0.48" thiserror = "1.0.48"
tokenizers = { workspace = true } tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.32.0", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
] }
tokio-stream = "0.1.14" tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] } tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.40" tracing = "0.1.40"
@ -37,7 +43,9 @@ tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true } ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } init-tracing-opentelemetry = { version = "0.14.1", features = [
"opentelemetry-otlp",
] }
minijinja = { version = "2.0.2" } minijinja = { version = "2.0.2" }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
futures-util = "0.3.30" futures-util = "0.3.30"
@ -46,7 +54,11 @@ once_cell = "1.19.0"
image = "0.25.1" image = "0.25.1"
base64 = { workspace = true } base64 = { workspace = true }
sysinfo = "0.30.13" sysinfo = "0.30.13"
uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } uuid = { version = "1.9.1", default-features = false, features = [
"v4",
"fast-rng",
"macro-diagnostics",
] }
csv = "1.3.0" csv = "1.3.0"
ureq = "=2.9" ureq = "=2.9"