diff --git a/Dockerfile b/Dockerfile index 57c19ee2..9b8a2054 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,7 +9,7 @@ WORKDIR /usr/src/router RUN cargo install --path . -FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 +FROM nvidia/cuda:11.6.1-devel-ubuntu18.04 ENV LANG=C.UTF-8 \ LC_ALL=C.UTF-8 \ diff --git a/README.md b/README.md index 78da4db2..6d23d9c5 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,6 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire ## TODO: -- [ ] Add batching args to router CLI - [ ] Add docstrings + comments everywhere as the codebase is fairly complicated - [ ] Add tests - [ ] Add shutdown logic in router and server diff --git a/router/Cargo.lock b/router/Cargo.lock index c912be89..eda42908 100644 --- a/router/Cargo.lock +++ b/router/Cargo.lock @@ -253,6 +253,43 @@ dependencies = [ "vec_map", ] +[[package]] +name = "clap" +version = "4.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bf8832993da70a4c6d13c581f4463c2bdda27b9bf1c5498dc4365543abe6d6f" +dependencies = [ + "atty", + "bitflags", + "clap_derive", + "clap_lex", + "once_cell", + "strsim 0.10.0", + "termcolor", +] + +[[package]] +name = "clap_derive" +version = "4.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c42f169caba89a7d512b5418b09864543eeb4d497416c917d7137863bd2076ad" +dependencies = [ + "heck 0.4.0", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d4198f73e42b4936b35b5bb248d81d2b595ecb170da0bac7655c54eedfa8da8" +dependencies = [ + "os_str_bytes", +] + [[package]] name = "console" version = "0.15.2" @@ -701,6 +738,12 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "heck" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" + [[package]] name = "hermit-abi" version = "0.1.19" @@ -1136,6 +1179,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "os_str_bytes" +version = "6.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff" + [[package]] name = "parking_lot" version = "0.12.1" @@ -1225,6 +1274,30 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.46" @@ -1251,7 +1324,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62941722fb675d463659e49c4f3fe1fe792ff24fe5bbaa9c08cd3b98a1c354f5" dependencies = [ "bytes", - "heck", + "heck 0.3.3", "itertools 0.10.5", "lazy_static", "log", @@ -1601,6 +1674,12 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c" +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "syn" version = "1.0.101" @@ -1643,6 +1722,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "termcolor" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755" +dependencies = [ + "winapi-util", +] + [[package]] name = "terminal_size" version = "0.1.17" @@ -1659,6 +1747,7 @@ version = "0.1.0" dependencies = [ "axum", "bloom-inference-client", + "clap 4.0.15", "futures", "parking_lot", "serde", @@ -1742,7 +1831,7 @@ checksum = "3d7b08ede6742d7a59d58c71da8a6fa21bedc433dca2e855e439274d08df1170" dependencies = [ "aho-corasick", "cached-path", - "clap", + "clap 2.34.0", "derive_builder", "dirs", "esaxx-rs", @@ -2251,6 +2340,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/router/Cargo.toml b/router/Cargo.toml index 95666dcf..37f319e9 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -2,6 +2,8 @@ name = "text-generation-router" version = "0.1.0" edition = "2021" +authors = ["Olivier Dehaene"] +description = "Text Generation Webserver" [lib] path = "src/lib.rs" @@ -13,6 +15,7 @@ path = "src/main.rs" [dependencies] axum = { version = "0.5.16", features = ["json", "serde_json"] } bloom-inference-client = { path = "client" } +clap = { version = "4.0.15", features = ["derive", "env"] } futures = "0.3.24" parking_lot = "0.12.1" serde = "1.0.145" diff --git a/router/src/batcher.rs b/router/src/batcher.rs index 94c72cc5..ebd81730 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -27,7 +27,7 @@ impl From for (StatusCode, String) { } #[derive(Clone)] -pub(crate) struct Batcher { +pub struct Batcher { db: Db, shared: Arc, } @@ -37,13 +37,13 @@ struct Shared { } impl Batcher { - pub(crate) fn new(client: ShardedClient) -> Self { + pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self { let db = Db::new(); let shared = Arc::new(Shared { batching_task: Notify::new(), }); - tokio::spawn(batching_task(client, db.clone(), shared.clone())); + tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone())); Self { db, shared } } @@ -70,40 +70,46 @@ impl Batcher { } } -async fn batching_task(client: ShardedClient, db: Db, shared: Arc) { +async fn batching_task(max_batch_size: usize, + client: ShardedClient, + db: Db, + shared: Arc) { + let limit_min_batch_size = (max_batch_size / 2) as u32; + loop { shared.batching_task.notified().await; - if let Some(batch) = db.next_batch(32) { + if let Some(batch) = db.next_batch(max_batch_size) { let request_ids = batch.requests.iter().map(|req| req.id).collect(); let mut cached_batch = match batch.size { - size if size > 16 => { + size if size > limit_min_batch_size => { wrap_future(client.generate_until_finished(batch), request_ids, &db).await } _ => wrap_future(client.generate(batch), request_ids, &db).await, }; while let Some(batch) = cached_batch { - let batch_size = batch.size; + let mut current_batch_size = batch.size; let mut request_ids: Vec = batch.requests.iter().map(|req| req.id).collect(); let mut batches = vec![batch]; - if batch_size <= 16 { - if let Some(new_batch) = db.next_batch_minimum_size(16, 48) { + if current_batch_size <= limit_min_batch_size { + if let Some(new_batch) = db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size) { let new_batch_request_ids = new_batch.requests.iter().map(|req| req.id).collect(); let new_cached_batch = wrap_future(client.generate(new_batch), new_batch_request_ids, &db) .await; if let Some(new_cached_batch) = new_cached_batch { + current_batch_size += new_cached_batch.size; request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id)); batches.push(new_cached_batch); } } } - cached_batch = match batch_size { - size if size > 16 => { + cached_batch = match current_batch_size { + size if size > limit_min_batch_size => { wrap_future( client.generate_until_finished_with_cache(batches), request_ids, diff --git a/router/src/lib.rs b/router/src/lib.rs index 09dbdd12..14dc5724 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,8 +1,8 @@ mod batcher; mod db; -pub mod server; mod validation; +pub mod server; -use batcher::Batcher; use db::{Db, Entry}; +use batcher::Batcher; use validation::Validation; diff --git a/router/src/main.rs b/router/src/main.rs index 5169a071..89cd4731 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,10 +1,36 @@ use bloom_inference_client::ShardedClient; -use std::net::SocketAddr; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use text_generation_router::server; use tokenizers::Tokenizer; +use clap::Parser; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(default_value = "32", long, short, env)] + max_batch_size: usize, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(default_value = "/tmp/bloom-inference-0", long, env)] + shard_uds_path: String, + #[clap(default_value = "bigscience/bloom", long, env)] + tokenizer_name: String, +} fn main() -> Result<(), std::io::Error> { - let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap(); + // Get args + let args = Args::parse(); +// Pattern match configuration + let Args { + max_batch_size, + port, + shard_uds_path, + tokenizer_name, + } = args; + + + let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap(); tokio::runtime::Builder::new_multi_thread() .enable_all() @@ -13,7 +39,7 @@ fn main() -> Result<(), std::io::Error> { .block_on(async { tracing_subscriber::fmt::init(); - let sharded_client = ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string()) + let sharded_client = ShardedClient::connect_uds(shard_uds_path) .await .expect("Could not connect to server"); sharded_client @@ -22,9 +48,9 @@ fn main() -> Result<(), std::io::Error> { .expect("Unable to clear cache"); tracing::info!("Connected"); - let addr = SocketAddr::from(([0, 0, 0, 0], 3000)); + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); - server::run(sharded_client, tokenizer, addr).await; + server::run(max_batch_size, sharded_client, tokenizer, addr).await; Ok(()) }) } diff --git a/router/src/server.rs b/router/src/server.rs index b8331da1..0fdfd58b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -64,7 +64,7 @@ pub(crate) struct GenerateRequest { #[instrument(skip(state), fields(time, time_per_token))] async fn liveness(state: Extension) -> Result<(), (StatusCode, String)> { state - .infer + .batcher .infer( 1, GenerateRequest { @@ -97,7 +97,7 @@ async fn generate( }) .await?; - let generated_text = state.infer.infer(input_length, validated_request).await?; + let generated_text = state.batcher.infer(input_length, validated_request).await?; tracing::Span::current().record("time", format!("{:?}", start.elapsed())); tracing::Span::current().record( @@ -114,18 +114,14 @@ async fn generate( #[derive(Clone)] struct ServerState { validation: Validation, - infer: Batcher, + batcher: Batcher, } -pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) { - client.clear_cache().await.expect("Unable to clear cache"); - tracing::info!("Connected"); - - let infer = Batcher::new(client); - +pub async fn run(max_batch_size: usize, client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) { + let batcher = Batcher::new(client, max_batch_size); let validation = Validation::new(tokenizer); - let shared_state = ServerState { validation, infer }; + let shared_state = ServerState { validation, batcher }; let app = Router::new() .route("/generate", post(generate)) diff --git a/router/src/validation.rs b/router/src/validation.rs index c9d391aa..45b108fd 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -14,7 +14,7 @@ pub enum ValidationError { TopK, #[error("Max New Tokens must be < 512")] MaxNewTokens, - #[error("Inputs must have less than 512 tokens. Given: {0}")] + #[error("Inputs must have less than 1000 tokens. Given: {0}")] InputLength(usize), } @@ -30,7 +30,7 @@ type ValidationRequest = ( ); #[derive(Debug, Clone)] -pub(crate) struct Validation { +pub struct Validation { sender: mpsc::Sender, } @@ -81,7 +81,7 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver 512 { + if input_length > 1000 { response_tx .send(Err(ValidationError::InputLength(input_length))) .unwrap_or(());