diff --git a/Cargo.lock b/Cargo.lock index 752c4886..33f5d181 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1087,6 +1087,12 @@ dependencies = [ "libc", ] +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + [[package]] name = "nom" version = "7.1.1" @@ -1826,6 +1832,7 @@ dependencies = [ "axum", "clap 4.0.22", "futures", + "nohash-hasher", "parking_lot", "serde", "serde_json", diff --git a/router/Cargo.toml b/router/Cargo.toml index f99069d3..546f127f 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -17,6 +17,7 @@ axum = { version = "0.5.16", features = ["json", "serde_json"] } text-generation-client = { path = "client" } clap = { version = "4.0.15", features = ["derive", "env"] } futures = "0.3.24" +nohash-hasher = "0.2.0" parking_lot = "0.12.1" serde = "1.0.145" serde_json = "1.0.85" diff --git a/router/src/batcher.rs b/router/src/batcher.rs index ee83d899..624ac82d 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -3,6 +3,7 @@ use crate::{Db, Entry}; use crate::{ErrorResponse, GenerateRequest}; use axum::http::StatusCode; use axum::Json; +use nohash_hasher::IntMap; use std::future::Future; use std::sync::Arc; use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient}; @@ -104,8 +105,8 @@ async fn batching_task( // Get the next batch from the DB // This batch might be smaller than the maximum batch size if there are not enough requests // waiting in the DB - while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) { - let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await; + while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) { + let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await; let mut waiting_tokens = 1; // We loop until we do not receive any cached batch from the inference 
server (== until @@ -113,7 +114,6 @@ async fn batching_task( while let Some(batch) = cached_batch { // Get current batch info let batch_size = batch.size; - let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect(); let mut batches = vec![batch]; // If the current batch is too small, we try to add more requests to it @@ -127,24 +127,23 @@ async fn batching_task( }; // Try to get a new batch - if let Some((new_request_ids, new_batch)) = + if let Some((mut new_entries, new_batch)) = db.next_batch(min_size, max_batch_size - batch_size as usize) { // Generate one token for this new batch to have the attention past in cache let new_cached_batch = - wrap_future(client.generate(new_batch), new_request_ids, &db).await; + wrap_future(client.generate(new_batch), &mut new_entries).await; // Reset waiting counter waiting_tokens = 1; // Extend current batch with the new batch if let Some(new_cached_batch) = new_cached_batch { - request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id)); + entries.extend(new_entries); batches.push(new_cached_batch); } } } - cached_batch = - wrap_future(client.generate_with_cache(batches), request_ids, &db).await; + cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await; waiting_tokens += 1; } } @@ -154,39 +153,36 @@ async fn batching_task( /// Wrap a future inside a match statement to handle errors and send the response to the Batcher async fn wrap_future( future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>, - request_ids: Vec<u64>, - db: &Db, + entries: &mut IntMap<u64, Entry>, ) -> Option<Batch> { match future.await { Ok((generated_texts, next_batch)) => { - send_generated(generated_texts, db); + send_generated(generated_texts, entries); next_batch } // If we have an error, we discard the whole batch Err(err) => { - send_error(err, request_ids, db); + send_error(err, entries); None } } } -/// Send errors to the Batcher for all `request_ids` -fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) { - 
request_ids.into_iter().for_each(|id| { - // We can `expect` here as the request id should always be in the DB - let entry = db.remove(&id).expect("ID not found in db. This is a bug."); +/// Send errors to the Batcher for all `entries` +fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) { + entries.drain().for_each(|(_, entry)| { // unwrap_or is valid here as we don't care if the receiver is gone. entry.response_tx.send(Err(error.clone())).unwrap_or(()); }); } /// Send `generated_text` to the Batcher for all `finished` -fn send_generated(finished: Vec<GeneratedText>, db: &Db) { +fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) { finished.into_iter().for_each(|output| { - // We can `expect` here as the request id should always be in the DB - let entry = db + // We can `expect` here as the request id should always be in the entries + let entry = entries .remove(&output.request.unwrap().id) - .expect("ID not found in db. This is a bug."); + .expect("ID not found in entries. This is a bug."); let response = InferResponse { output_text: output.output_text, diff --git a/router/src/db.rs b/router/src/db.rs index 1d7df627..51de9d05 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -1,6 +1,7 @@ -use crate::InferResponse; /// This code is massively inspired by Tokio mini-redis +use crate::InferResponse; use crate::{GenerateParameters, GenerateRequest}; +use nohash_hasher::{BuildNoHashHasher, IntMap}; use parking_lot::Mutex; use std::collections::BTreeMap; use std::sync::Arc; @@ -112,18 +113,12 @@ impl Db { state.entries.insert(id, entry); } - /// Remove an entry from the database if it exists - pub(crate) fn remove(&self, id: &u64) -> Option<Entry> { - let mut state = self.shared.state.lock(); - state.entries.remove(id) - } - // Get the next batch pub(crate) fn next_batch( &self, min_size: Option<usize>, max_size: usize, - ) -> Option<(Vec<u64>, Batch)> { + ) -> Option<(IntMap<u64, Entry>, Batch)> { // Acquire lock let mut state = self.shared.state.lock(); @@ -135,13 +130,19 @@ impl Db { return None; } 
} - ids.iter().for_each(|id| { - // Set batch_time for each request - state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now()); - }); - // Batch size let size = requests.len(); + + let mut entries = IntMap::with_capacity_and_hasher(size, BuildNoHashHasher::default()); + ids.iter().for_each(|id| { + // Remove entry from db + let mut entry = state.entries.remove(id).unwrap(); + // Set batch_time + entry.batch_time = Some(Instant::now()); + // Insert in entries IntMap + entries.insert(*id, entry); + }); + let batch = Batch { id: state.next_batch_id, requests, @@ -152,7 +153,7 @@ impl Db { // Increment batch id state.next_batch_id += 1; - return Some((ids, batch)); + return Some((entries, batch)); } None }