From 4c693e65245058a4d0ca227ee30b6d8a35d115f1 Mon Sep 17 00:00:00 2001 From: Olivier Dehaene Date: Tue, 11 Oct 2022 16:50:54 +0200 Subject: [PATCH] Refactored gRPC interface Added validation logic --- README.md | 21 +- proto/generate.proto | 113 +++++++---- router/client/src/client.rs | 61 ++++-- router/client/src/lib.rs | 4 +- router/client/src/sharded_client.rs | 64 ++++-- router/src/batcher.rs | 124 ++++++------ router/src/db.rs | 10 + router/src/main.rs | 42 ++-- router/src/server.rs | 42 ++-- router/src/validation.rs | 65 +++++++ server/bloom_inference/cache.py | 37 +--- server/bloom_inference/model.py | 290 +++++++++++++++------------- server/bloom_inference/server.py | 100 ++++++++-- 13 files changed, 612 insertions(+), 361 deletions(-) create mode 100644 router/src/validation.rs diff --git a/README.md b/README.md index fe7bb443..ad8adaa9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,23 @@ -# BLOOM Inference +# Text Generation Inference -A Rust and gRPC server for BLOOM Inference. +A Rust and gRPC server for text generation inference. + +## Load Tests + +See `k6/load_test.js` +We send the default examples with a 1 second delay between each request. + +Stages: +- Ramp up to 50 concurrent requests per second in 1min +- Ramp up from 50 to 100 concurrent requests per second in 2min +- Ramp down to 0 concurrent requests per second in 1min + + +| | avg | min | med | max | p(90) | p(95) | RPS | |------------------------|-----------|-----------|-----------|------------|-----------|-----------|----------| | Original code | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 | | ISO with original code | 8.88s | 959.53ms | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 | | New batching logic | **5.44s** | **1.27s** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** | ## Install diff --git a/proto/generate.proto b/proto/generate.proto index 324b9206..8c5221b4 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -2,21 +2,35 @@ syntax = "proto3"; package generate.v1; -service TextGeneration { +service TextGenerationService { /// Service discovery - rpc ServiceDiscovery(Empty) returns (ServiceDiscoveryResponse) {} + rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} /// Empties batch cache - rpc ClearCache(Empty) returns (Empty); - /// Generate tokens for a batch without cache - rpc Generate(Batch) returns (Response); - /// Generate tokens for a batch with cache - rpc GenerateWithCache(BatchCached) returns (Response); + rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); + /// Generate tokens for a batch + rpc Generate (GenerateRequest) returns (GenerateResponse); + /// Generate tokens for a list of cached batches + rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse); + /// Generate tokens until the text of at least one request of the batch is generated + rpc GenerateUntilFinished (GenerateUntilFinishedRequest) returns (GenerateUntilFinishedResponse); + /// Generate tokens until the text of at least one request of the cached batches is finished + rpc GenerateUntilFinishedWithCache (GenerateUntilFinishedWithCacheRequest) returns (GenerateUntilFinishedWithCacheResponse); } +/// Empty request +message ServiceDiscoveryRequest {} + message ServiceDiscoveryResponse { + /// Other shards urls repeated string urls = 1; } +/// Empty request +message ClearCacheRequest {} + +/// Empty response +message ClearCacheResponse {} + message LogitsWarperParameters { float temperature = 1; uint32 top_k = 2; float top_p = 3; } @@ -29,10
+43,12 @@ message Request { uint64 id = 1; /// The generation context string inputs = 2; + /// The number of tokens inside inputs + uint32 input_length = 3; /// Logits Warper Parameters - LogitsWarperParameters parameters = 3; + LogitsWarperParameters parameters = 4; /// Stopping criteria - uint32 max_new_tokens = 4; + uint32 max_new_tokens = 5; } message Batch { @@ -40,44 +56,63 @@ message Batch { uint64 id = 1; /// Individual requests repeated Request requests = 2; + /// Batch size (==len(requests)) + uint32 size = 3; + /// Length of the longest sequence within the batch (used for padding) + uint32 max_sequence_length = 4; } -message BatchCached { - /// Batch ID - uint64 id = 1; - /// Request ids within cache - repeated uint64 request_ids = 2; - /// Cache IDs - repeated uint64 batch_cached_ids = 3; - /// Batch size (sum of all batch sizes) - uint32 total_batch_size = 4; - /// Max sequence length - uint32 max_sequence_length = 5; -} - -message FinishedGeneration { - /// ID of the original request - uint64 id = 1; +message GeneratedText { + /// Request + Request request = 1; /// Output string output = 2; } -message CacheEntry { - /// Cache ID; same as batch ID - uint64 id = 1; - /// Requests present in cache entry - repeated uint64 request_ids = 2; - /// Sequence length - uint32 sequence_length = 3; +message GenerateRequest { + /// Batch + Batch batch = 1; } -message Response { - /// Finished requests (optional) - repeated FinishedGeneration finished = 1; - /// Cache entry (optional) - optional CacheEntry cache_entry = 2; +message GenerateResponse { + /// Finished requests + repeated GeneratedText generated_texts = 1; + /// Next batch (cached) + optional Batch batch = 2; } +message GenerateWithCacheRequest { + /// Cached batches + repeated Batch batches = 1; +} -// Represent an empty message. -message Empty {} \ No newline at end of file +message GenerateWithCacheResponse { + /// Finished requests + repeated GeneratedText generated_texts = 1; + /// Next batch (cached) + optional Batch batch = 2; +} + +message GenerateUntilFinishedRequest { + /// Batch + Batch batch = 1; +} + +message GenerateUntilFinishedResponse { + /// Finished requests + repeated GeneratedText generated_texts = 1; + /// Next batch (cached) + optional Batch batch = 2; +} + +message GenerateUntilFinishedWithCacheRequest { + /// Cached batches + repeated Batch batches = 1; +} + +message GenerateUntilFinishedWithCacheResponse { + /// Finished requests + repeated GeneratedText generated_texts = 1; + /// Next batch (cached) + optional Batch batch = 2; +} diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 7f68e252..712c13da 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -1,4 +1,4 @@ -use crate::pb::generate::v1::text_generation_client::TextGenerationClient; +use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient; use crate::pb::generate::v1::*; use crate::Result; use std::time::Duration; @@ -9,7 +9,7 @@ use tracing::*; /// BLOOM Inference gRPC client #[derive(Clone)] pub struct Client { - stub: TextGenerationClient>, + stub: TextGenerationServiceClient>, } impl Client { @@ -22,13 +22,13 @@ impl Client { let timeout_channel = Timeout::new(channel, timeout); Self { - stub: TextGenerationClient::new(timeout_channel), + stub: TextGenerationServiceClient::new(timeout_channel), } } /// Returns a client connected to the given unix socket. Requests exceeding timeout will fail. 
pub async fn connect_uds(path: String, timeout: Duration) -> Self { - let channel = Channel::from_shared(format!("http://[::]:50051")) + let channel = Channel::from_shared("http://[::]:50051".to_string()) .unwrap() .connect_with_connector(tower::service_fn(move |_: Uri| { tokio::net::UnixStream::connect(path.clone()) @@ -38,13 +38,13 @@ impl Client { let timeout_channel = Timeout::new(channel, timeout); Self { - stub: TextGenerationClient::new(timeout_channel), + stub: TextGenerationServiceClient::new(timeout_channel), } } #[instrument(skip(self))] pub async fn service_discovery(&mut self) -> Result> { - let request = tonic::Request::new(Empty {}); + let request = tonic::Request::new(ServiceDiscoveryRequest {}); let response = self .stub .service_discovery(request) @@ -64,7 +64,7 @@ impl Client { #[instrument(skip(self))] pub async fn clear_cache(&mut self) -> Result<()> { - let request = tonic::Request::new(Empty {}); + let request = tonic::Request::new(ClearCacheRequest {}); self.stub .clear_cache(request) .instrument(info_span!("clear_cache")) @@ -73,32 +73,59 @@ impl Client { } #[instrument(skip(self))] - pub async fn generate( - &mut self, - request: Batch, - ) -> Result<(Vec, Option)> { - let request = tonic::Request::new(request); + pub async fn generate(&mut self, batch: Batch) -> Result<(Vec, Option)> { + let request = tonic::Request::new(GenerateRequest { batch: Some(batch) }); let response = self .stub .generate(request) .instrument(info_span!("generate")) .await? .into_inner(); - Ok((response.finished, response.cache_entry)) + Ok((response.generated_texts, response.batch)) } #[instrument(skip(self))] pub async fn generate_with_cache( &mut self, - request: BatchCached, - ) -> Result<(Vec, Option)> { - let request = tonic::Request::new(request); + batches: Vec, + ) -> Result<(Vec, Option)> { + let request = tonic::Request::new(GenerateWithCacheRequest { batches }); let response = self .stub .generate_with_cache(request) .instrument(info_span!("generate_with_cache")) .await? .into_inner(); - Ok((response.finished, response.cache_entry)) + Ok((response.generated_texts, response.batch)) + } + + #[instrument(skip(self))] + pub async fn generate_until_finished( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option)> { + let request = tonic::Request::new(GenerateUntilFinishedRequest { batch: Some(batch) }); + let response = self + .stub + .generate_until_finished(request) + .instrument(info_span!("generate_until_finished")) + .await? + .into_inner(); + Ok((response.generated_texts, response.batch)) + } + + #[instrument(skip(self))] + pub async fn generate_until_finished_with_cache( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option)> { + let request = tonic::Request::new(GenerateUntilFinishedWithCacheRequest { batches }); + let response = self + .stub + .generate_until_finished_with_cache(request) + .instrument(info_span!("generate_until_finished_with_cache")) + .await? 
+ .into_inner(); + Ok((response.generated_texts, response.batch)) } } diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index aca9f65b..bb9919e2 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -5,9 +5,7 @@ mod pb; mod sharded_client; pub use client::Client; -pub use pb::generate::v1::{ - Batch, BatchCached, CacheEntry, FinishedGeneration, LogitsWarperParameters, Request, -}; +pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request}; pub use sharded_client::ShardedClient; use thiserror::Error; pub use tonic::transport::Uri; diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 6856a9bc..7af741f0 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,5 +1,5 @@ use crate::Result; -use crate::{Batch, BatchCached, CacheEntry, Client, FinishedGeneration}; +use crate::{Batch, Client, GeneratedText}; use futures::future::join_all; use std::time::Duration; use tokio::sync::{broadcast, mpsc}; @@ -9,11 +9,19 @@ use tonic::transport::Uri; enum Command { Generate( Batch, - mpsc::Sender, Option)>>, + mpsc::Sender, Option)>>, ), GenerateWithCache( - BatchCached, - mpsc::Sender, Option)>>, + Vec, + mpsc::Sender, Option)>>, + ), + GenerateUntilFinished( + Batch, + mpsc::Sender, Option)>>, + ), + GenerateUntilFinishedWithCache( + Vec, + mpsc::Sender, Option)>>, ), ClearCache(mpsc::Sender>), } @@ -25,8 +33,16 @@ async fn client_task(mut client: Client, mut request_subscriber: broadcast::Rece let result = client.generate(batch).await; response_tx.try_send(result).unwrap_or(()); } - Command::GenerateWithCache(batch_cached, response_tx) => { - let result = client.generate_with_cache(batch_cached).await; + Command::GenerateWithCache(batches, response_tx) => { + let result = client.generate_with_cache(batches).await; + response_tx.try_send(result).unwrap_or(()); + } + Command::GenerateUntilFinished(batch, response_tx) => { + let result = client.generate_until_finished(batch).await; + response_tx.try_send(result).unwrap_or(()); + } + Command::GenerateUntilFinishedWithCache(batches, response_tx) => { + let result = client.generate_until_finished_with_cache(batches).await; response_tx.try_send(result).unwrap_or(()); } Command::ClearCache(response_tx) => { @@ -74,10 +90,7 @@ impl ShardedClient { Self::from_master_client(master_client).await } - pub async fn generate( - &self, - batch: Batch, - ) -> Result<(Vec, Option)> { + pub async fn generate(&self, batch: Batch) -> Result<(Vec, Option)> { let (response_tx, mut response_rx) = mpsc::channel(1); self.request_tx .send(Command::Generate(batch, response_tx)) @@ -87,11 +100,36 @@ impl ShardedClient { pub async fn generate_with_cache( &self, - batch_cached: BatchCached, - ) -> Result<(Vec, Option)> { + batches: Vec, + ) -> Result<(Vec, Option)> { let (response_tx, mut response_rx) = mpsc::channel(1); self.request_tx - .send(Command::GenerateWithCache(batch_cached, response_tx)) + .send(Command::GenerateWithCache(batches, response_tx)) + .unwrap(); + response_rx.recv().await.unwrap() + } + + pub async fn generate_until_finished( + &self, + batch: Batch, + ) -> Result<(Vec, Option)> { + let (response_tx, mut response_rx) = mpsc::channel(1); + self.request_tx + .send(Command::GenerateUntilFinished(batch, response_tx)) + .unwrap(); + response_rx.recv().await.unwrap() + } + + pub async fn generate_until_finished_with_cache( + &self, + batches: Vec, + ) -> Result<(Vec, Option)> { + let (response_tx, mut response_rx) = 
mpsc::channel(1); + self.request_tx + .send(Command::GenerateUntilFinishedWithCache( + batches, + response_tx, + )) .unwrap(); response_rx.recv().await.unwrap() } diff --git a/router/src/batcher.rs b/router/src/batcher.rs index 2da47dfc..a044e26c 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -1,10 +1,9 @@ -use crate::Db; -use bloom_inference_client::{ - Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient, -}; -use std::sync::Arc; -use tokio::sync::{Notify, oneshot}; use crate::server::GenerateRequest; +use crate::Db; +use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient}; +use std::future::Future; +use std::sync::Arc; +use tokio::sync::{oneshot, Notify}; const MAX_LENGTH: usize = 128; @@ -32,12 +31,16 @@ impl Batcher { Self { db, shared } } - pub(crate) async fn infer(&self, request: GenerateRequest) -> Result { + pub(crate) async fn infer( + &self, + input_length: usize, + request: GenerateRequest, + ) -> Result { if self.db.len() > MAX_LENGTH { return Err(InferError {}); } let (request_tx, request_rx) = oneshot::channel(); - self.db.append(request, request_tx); + self.db.append(input_length, request, request_tx); self.shared.batching_task.notify_waiters(); match request_rx.await.unwrap() { Ok(output) => Ok(output), @@ -51,76 +54,57 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc) { shared.batching_task.notified().await; if let Some(batch) = db.next_batch(32) { - let mut cache_entry = infer_batch(batch, &client, &db).await; - - loop { - if let Some(entry) = cache_entry { - let mut batch_cached_ids = vec![entry.id]; - let mut total_batch_size = entry.request_ids.len(); - let mut max_sequence_length = entry.sequence_length; - let mut request_ids = entry.request_ids; - - // if total_batch_size <= 16 { - // if let Some(batch) = db.next_batch_minimum_size(16, 48) { - // let other_cache_entry = infer_batch(batch, &client, &db).await; - // - // if let Some(entry) = other_cache_entry { - // batch_cached_ids.push(entry.id); - // total_batch_size += entry.request_ids.len(); - // max_sequence_length = - // max_sequence_length.max(entry.sequence_length); - // request_ids.extend(entry.request_ids.into_iter()); - // } - // } - // } - - let batch_cached = BatchCached { - id: entry.id, - batch_cached_ids, - total_batch_size: total_batch_size as u32, - max_sequence_length, - request_ids, - }; - cache_entry = infer_batch_cached(batch_cached, &client, &db).await; - } else { - break; + let request_ids = batch.requests.iter().map(|req| req.id).collect(); + let mut cached_batch = match batch.size { + size if size > 16 => { + wrap_future(client.generate_until_finished(batch), request_ids, &db).await } + _ => wrap_future(client.generate(batch), request_ids, &db).await, + }; + + while let Some(batch) = cached_batch { + let batch_size = batch.size; + let mut request_ids: Vec = batch.requests.iter().map(|req| req.id).collect(); + let mut batches = vec![batch]; + + if batch_size <= 16 { + if let Some(new_batch) = db.next_batch_minimum_size(16, 48) { + let new_batch_request_ids = + new_batch.requests.iter().map(|req| req.id).collect(); + let new_cached_batch = + wrap_future(client.generate(new_batch), new_batch_request_ids, &db) + .await; + if let Some(new_cached_batch) = new_cached_batch { + request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id)); + batches.push(new_cached_batch); + } + } + } + + cached_batch = match batch_size { + size if size > 16 => { + 
wrap_future(client.generate_until_finished_with_cache(batches), request_ids, &db).await + } + _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await, + }; } } } } -async fn infer_batch_cached( - batch: BatchCached, - client: &ShardedClient, +async fn wrap_future( + future: impl Future, Option), ClientError>>, + request_ids: Vec, db: &Db, -) -> Option { - match client.generate_with_cache(batch.clone()).await { - Ok((finished, cache_entry)) => { - send_finished(finished, db); - cache_entry +) -> Option { + match future.await { + Ok((generated_texts, next_batch)) => { + send_generated(generated_texts, db); + next_batch } Err(err) => { println!("{:?}", err); - send_error(err, batch.request_ids, &db); - None - } - } -} - -async fn infer_batch(batch: Batch, client: &ShardedClient, db: &Db) -> Option { - match client.generate(batch.clone()).await { - Ok((finished, cache_entry)) => { - send_finished(finished, db); - cache_entry - } - Err(err) => { - println!("{:?}", err); - send_error( - err, - batch.requests.into_iter().map(|req| req.id).collect(), - &db, - ); + send_error(err, request_ids, db); None } } @@ -133,9 +117,9 @@ fn send_error(error: ClientError, request_ids: Vec, db: &Db) { }); } -fn send_finished(finished: Vec, db: &Db) { +fn send_generated(finished: Vec, db: &Db) { finished.into_iter().for_each(|output| { - let (_, response_tx) = db.remove(&output.id).unwrap(); + let (_, response_tx) = db.remove(&output.request.unwrap().id).unwrap(); response_tx.send(Ok(output.output)).unwrap_or(()); }); } diff --git a/router/src/db.rs b/router/src/db.rs index b6d218e2..5118b2fc 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -46,6 +46,7 @@ impl Db { pub(crate) fn append( &self, + input_length: usize, request: GenerateRequest, sender: Sender>, ) { @@ -63,6 +64,7 @@ impl Db { let request = Request { id, inputs: request.inputs, + input_length: input_length as u32, parameters, max_new_tokens: request.parameters.max_new_tokens, }; @@ -103,9 +105,13 @@ impl Db { pub(crate) fn next_batch(&self, max_size: usize) -> Option { if let Some((last_id, requests)) = self.next_requests(max_size) { let mut state = self.shared.state.write(); + let size = requests.len(); + let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap(); let batch = Batch { id: state.next_batch_id, requests, + size: size as u32, + max_sequence_length, }; state.next_batch_start_id = last_id + 1; state.next_batch_id += 1; @@ -122,9 +128,13 @@ impl Db { if let Some((last_id, requests)) = self.next_requests(max_size) { if requests.len() >= min_size { let mut state = self.shared.state.write(); + let size = requests.len(); + let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap(); let batch = Batch { id: state.next_batch_id, requests, + size: size as u32, + max_sequence_length, }; state.next_batch_start_id = last_id + 1; state.next_batch_id += 1; diff --git a/router/src/main.rs b/router/src/main.rs index 97ccb571..fe82d059 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,31 +1,45 @@ use bloom_inference_client::ShardedClient; -use poem; use poem::listener::TcpListener; use std::time::Duration; +use tokenizers::Tokenizer; mod server; +mod validation; + +use validation::Validation; mod db; + use db::Db; mod batcher; + use batcher::Batcher; -#[tokio::main] -async fn main() -> Result<(), std::io::Error> { - tracing_subscriber::fmt::init(); +fn main() -> Result<(), std::io::Error> { + let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap(); 
- let sharded_client = - ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string(), Duration::from_secs(5)) + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + tracing_subscriber::fmt::init(); + + let sharded_client = ShardedClient::connect_uds( + "/tmp/bloom-inference-0".to_string(), + Duration::from_secs(5), + ) .await; - sharded_client - .clear_cache() - .await - .expect("Unable to clear cache"); - tracing::info!("Connected"); + sharded_client + .clear_cache() + .await + .expect("Unable to clear cache"); + tracing::info!("Connected"); - let addr = "127.0.0.1:3000".to_string(); - let listener = TcpListener::bind(addr); + let addr = "127.0.0.1:3000".to_string(); + let listener = TcpListener::bind(addr); - server::run(sharded_client, listener).await + server::run(sharded_client, tokenizer, listener).await + }) } diff --git a/router/src/server.rs b/router/src/server.rs index 0daf8df3..14e81709 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,12 +1,13 @@ -use poem::{EndpointExt, handler, post, Route, Server}; +use crate::{Batcher, ShardedClient, Validation}; use poem::http::StatusCode; use poem::listener::TcpListener; use poem::middleware::AddData; use poem::web::{Data, Json}; -use tokio::time::Instant; -use crate::{Batcher, ShardedClient}; -use tracing::instrument; +use poem::{handler, post, EndpointExt, Route, Server}; use serde::Deserialize; +use tokenizers::Tokenizer; +use tokio::time::Instant; +use tracing::instrument; #[derive(Clone, Debug, Deserialize)] pub(crate) struct GenerateParameters { @@ -59,21 +60,24 @@ pub(crate) struct GenerateRequest { pub parameters: GenerateParameters, } - #[handler] -#[instrument(skip(infer), fields(time, time_per_token))] +#[instrument(skip(validation, infer), fields(time, time_per_token))] async fn generate( + validation: Data<&Validation>, infer: Data<&Batcher>, req: Json, ) -> poem::Result> { let start = Instant::now(); - let output = infer - .infer(GenerateRequest { + let (input_length, validated_request) = validation + .validate(GenerateRequest { inputs: req.inputs.clone(), parameters: req.parameters.clone(), }) - .await; + .await + .unwrap(); + + let output = infer.infer(input_length, validated_request).await; match output { Ok(generated_text) => { @@ -92,20 +96,22 @@ async fn generate( } } -pub async fn run(client: ShardedClient, listener: TcpListener) -> Result<(), std::io::Error> { - client - .clear_cache() - .await - .expect("Unable to clear cache"); +pub async fn run( + client: ShardedClient, + tokenizer: Tokenizer, + listener: TcpListener, +) -> Result<(), std::io::Error> { + client.clear_cache().await.expect("Unable to clear cache"); tracing::info!("Connected"); let infer = Batcher::new(client); + let validation = Validation::new(tokenizer); + let app = Route::new() .at("/generate", post(generate)) + .with(AddData::new(validation)) .with(AddData::new(infer)); - Server::new(listener) - .run(app) - .await -} \ No newline at end of file + Server::new(listener).run(app).await +} diff --git a/router/src/validation.rs b/router/src/validation.rs new file mode 100644 index 00000000..6987894d --- /dev/null +++ b/router/src/validation.rs @@ -0,0 +1,65 @@ +use crate::server::GenerateRequest; +use tokenizers::tokenizer::Tokenizer; +use tokio::sync::{mpsc, oneshot}; + +#[derive(Debug)] +pub struct ValidationError {} + +type ValidationRequest = ( + GenerateRequest, + oneshot::Sender>, +); + +#[derive(Debug, Clone)] +pub(crate) struct Validation { + sender: mpsc::Sender, +} + 
+impl Validation { + pub(crate) fn new(tokenizer: Tokenizer) -> Self { + let (validation_sender, validation_receiver) = mpsc::channel(128); + + tokio::spawn(validation_task(tokenizer, validation_receiver)); + + Self { + sender: validation_sender, + } + } + + pub(crate) async fn validate( + &self, + request: GenerateRequest, + ) -> Result<(usize, GenerateRequest), ValidationError> { + let (sender, receiver) = oneshot::channel(); + self.sender.send((request, sender)).await.unwrap(); + receiver.await.unwrap() + } +} + +async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver) { + while let Some((request, response_tx)) = receiver.recv().await { + if request.parameters.temperature < 0.0 { + response_tx.send(Err(ValidationError {})).unwrap_or(()); + continue; + } + if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 { + response_tx.send(Err(ValidationError {})).unwrap_or(()); + continue; + } + if request.parameters.max_new_tokens > 512 { + response_tx.send(Err(ValidationError {})).unwrap_or(()); + continue; + } + + let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap(); + let input_length = inputs.len(); + + if input_length > 512 { + response_tx.send(Err(ValidationError {})).unwrap_or(()); + continue; + } + + response_tx.send(Ok((input_length, request))).unwrap_or(()); + } + println!("drop here"); +} diff --git a/server/bloom_inference/cache.py b/server/bloom_inference/cache.py index 5ef1f1da..6812b306 100644 --- a/server/bloom_inference/cache.py +++ b/server/bloom_inference/cache.py @@ -1,44 +1,19 @@ -import torch - -from dataclasses import dataclass -from typing import Dict, Optional, List - -from bloom_inference.pb import generate_pb2 -from bloom_inference.utils import NextTokenChooser, StoppingCriteria - - -@dataclass -class CacheEntry: - batch_id: int - request_ids: List[int] - input_ids: Dict[str, torch.Tensor] - all_input_ids: List[torch.Tensor] - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - - def __len__(self): - return len(self.request_ids) - - def to_pb(self): - return generate_pb2.CacheEntry( - id=self.batch_id, - request_ids=self.request_ids, - sequence_length=max(len(entry) for entry in self.all_input_ids), - ) +from bloom_inference.model import Batch +from typing import Dict, Optional class Cache: def __init__(self): - self.cache: Dict[str, CacheEntry] = {} + self.cache: Dict[int, Batch] = {} - def pop(self, batch_id: str) -> Optional[CacheEntry]: + def pop(self, batch_id: int) -> Optional[Batch]: return self.cache.pop(batch_id, None) - def set(self, entry: CacheEntry): + def set(self, entry: Batch): if entry is not None: self.cache[entry.batch_id] = entry - def delete(self, batch_id: str): + def delete(self, batch_id: int): del self.cache[batch_id] def clear(self): diff --git a/server/bloom_inference/model.py b/server/bloom_inference/model.py index a762a792..21cf1154 100644 --- a/server/bloom_inference/model.py +++ b/server/bloom_inference/model.py @@ -8,7 +8,6 @@ from typing import List, Tuple, Optional, Dict from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig from transformers.modeling_utils import no_init_weights -from bloom_inference.cache import CacheEntry from bloom_inference.pb import generate_pb2 from bloom_inference.shard_model import shard_model, match_suffix from bloom_inference.utils import ( @@ -24,25 +23,35 @@ torch.manual_seed(0) @dataclass class Batch: batch_id: int - request_ids: List[int] + requests: List[generate_pb2.Request] input_ids: Dict[str, 
torch.Tensor] all_input_ids: List[torch.Tensor] next_token_choosers: List[NextTokenChooser] stopping_criterias: List[StoppingCriteria] + size: int + max_sequence_length: int + + def to_pb(self): + return generate_pb2.Batch( + id=self.batch_id, + requests=self.requests, + size=self.size, + max_sequence_length=self.max_sequence_length, + ) @classmethod - def from_batch_pb( - cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device + def from_pb( + cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device ) -> "Batch": - request_ids = [] inputs = [] next_token_choosers = [] stopping_criterias = [] + input_lengths = [] # Parse batch for r in pb.requests: - request_ids.append(r.id) inputs.append(r.inputs) + input_lengths.append(r.input_length) next_token_choosers.append( NextTokenChooser( temperature=r.parameters.temperature, @@ -54,94 +63,93 @@ class Batch: stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens)) input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device) - all_input_ids = input_ids["input_ids"].unsqueeze(-1) + # Remove padding from all_input_ids + all_input_ids = [ + input_ids.squeeze(0)[-length:].unsqueeze(-1) + for length, input_ids in zip( + input_lengths, input_ids["input_ids"].split(1, dim=0) + ) + ] return cls( - pb.id, - request_ids, - input_ids, - all_input_ids, - next_token_choosers, - stopping_criterias, + batch_id=pb.id, + requests=pb.requests, + input_ids=input_ids, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + size=pb.size, + max_sequence_length=pb.max_sequence_length, ) @classmethod - def from_cache_entry(cls, cache_entry: CacheEntry) -> "Batch": - return cls( - cache_entry.batch_id, - cache_entry.request_ids, - cache_entry.input_ids, - cache_entry.all_input_ids, - cache_entry.next_token_choosers, - cache_entry.stopping_criterias, - ) + def concatenate(cls, batches: List["Batch"]) -> "Batch": + # Used for padding + total_batch_size = sum(batch.size for batch in batches) + max_sequence_length = max(batch.max_sequence_length for batch in batches) - @classmethod - def from_batch_cached_pb(cls, pb: generate_pb2.BatchCached, cache) -> "Batch": - if len(pb.batch_cached_ids) == 1: - cache_entry = cache.pop(pb.batch_cached_ids[0]) - if cache_entry is None: - raise ValueError(f"Batch ID {pb.batch_id} not found in cache") - return cls.from_cache_entry(cache_entry) - - total_batch_size = pb.total_batch_size - max_sequence_length = pb.max_sequence_length + # Batch attributes input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} - request_ids = [] + requests = [] all_input_ids = [] next_token_choosers = [] stopping_criterias = [] + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes start_index = 0 - for i, batch_id in enumerate(pb.batch_cached_ids): - cache_entry = cache.pop(batch_id) - if cache_entry is None: - raise ValueError(f"Batch ID {batch_id} not found in cache") - request_ids.extend(cache_entry.request_ids) - all_input_ids.extend(cache_entry.all_input_ids) - next_token_choosers.extend(cache_entry.next_token_choosers) - stopping_criterias.extend(cache_entry.stopping_criterias) + for i, batch in enumerate(batches): + requests.extend(batch.requests) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) - batch_size = len(cache_entry.request_ids) - end_index = 
start_index + batch_size - sequence_length = max(len(entry) for entry in cache_entry.all_input_ids) + # Slicing end index for this batch + end_index = start_index + batch.size - if input_ids["input_ids"] is None: + # We only concatenate batches that did at least one step + if batch.input_ids["input_ids"].shape[1] > 1: + raise ValueError("Batch input_ids should be of shape (batch_size, 1)") + + # Initialize tensors + if i == 0: input_ids["input_ids"] = torch.empty( (total_batch_size, 1), - dtype=cache_entry.input_ids["input_ids"].dtype, - device=cache_entry.input_ids["input_ids"].device, + dtype=batch.input_ids["input_ids"].dtype, + device=batch.input_ids["input_ids"].device, ) - - input_ids["input_ids"][start_index:end_index] = cache_entry.input_ids[ - "input_ids" - ] - - if input_ids["attention_mask"] is None: input_ids["attention_mask"] = torch.zeros( (total_batch_size, max_sequence_length), - dtype=cache_entry.input_ids["attention_mask"].dtype, - device=cache_entry.input_ids["attention_mask"].device, + dtype=batch.input_ids["attention_mask"].dtype, + device=batch.input_ids["attention_mask"].device, ) - input_ids["attention_mask"][ - start_index:end_index, -sequence_length: - ] = cache_entry.input_ids["attention_mask"][:, -sequence_length:] + # input_ids["input_ids"] is always of shape [batch_size, 1] + # We do not need to pad it + input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"] - for j, past in enumerate(cache_entry.input_ids["past_key_values"]): - # TODO: this could be done without the views by using indices + # We need to slice the attention mask to remove padding from previous steps + input_ids["attention_mask"][ + start_index:end_index, -batch.max_sequence_length : + ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :] + + for j, past in enumerate(batch.input_ids["past_key_values"]): past_keys = past[0] past_values = past[1] _, head_dim, padded_sequence_length = past_keys.shape + # Reshape the tensors to make slicing easier past_keys = past_keys.view( - batch_size, -1, head_dim, padded_sequence_length + batch.size, -1, head_dim, padded_sequence_length ) past_values = past_values.view( - batch_size, -1, padded_sequence_length, head_dim + batch.size, -1, padded_sequence_length, head_dim ) num_heads = past_keys.shape[1] + # Initialize tensors + # This will run only once per layer if j == len(input_ids["past_key_values"]): padded_past_keys = torch.zeros( ( @@ -167,15 +175,17 @@ class Batch: [padded_past_keys, padded_past_values] ) + # We slice the past keys and values to remove the padding from previous batches input_ids["past_key_values"][j][0][ - start_index:end_index, :, :, -(sequence_length - 1): - ] = past_keys[:, :, :, -(sequence_length - 1):] + start_index:end_index, :, :, -(batch.max_sequence_length - 1) : + ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :] input_ids["past_key_values"][j][1][ - start_index:end_index, :, -(sequence_length - 1):, : - ] = past_values[:, :, -(sequence_length - 1):, :] + start_index:end_index, :, -(batch.max_sequence_length - 1) :, : + ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :] - if (i + 1) == len(pb.batch_cached_ids): + # If we are on the last batch, we need to reshape the tensors + if (i + 1) == len(batches): input_ids["past_key_values"][j][0] = input_ids["past_key_values"][ j ][0].view(total_batch_size * num_heads, head_dim, -1) @@ -183,27 +193,27 @@ class Batch: j ][1].view(total_batch_size * num_heads, -1, head_dim) - start_index += batch_size - - assert pb.request_ids == 
request_ids + start_index += batch.size return cls( - pb.id, - request_ids, - input_ids, - all_input_ids, - next_token_choosers, - stopping_criterias, + batch_id=batches[0].batch_id, + requests=requests, + input_ids=input_ids, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + size=total_batch_size, + max_sequence_length=max_sequence_length, ) @dataclass -class FinishedGeneration: - request_id: str +class GeneratedText: + request: generate_pb2.Request output: str - def to_pb(self) -> generate_pb2.FinishedGeneration: - return generate_pb2.FinishedGeneration(id=self.request_id, output=self.output) + def to_pb(self) -> generate_pb2.GeneratedText: + return generate_pb2.GeneratedText(request=self.request, output=self.output) class BLOOM: @@ -229,25 +239,28 @@ class BLOOM: ) def generate_token( - self, batch: Batch - ) -> Tuple[List[FinishedGeneration], Optional[CacheEntry]]: + self, batch: Batch + ) -> Tuple[List[GeneratedText], Optional[Batch]]: with torch.no_grad(): outputs = self.forward(**batch.input_ids) # List of indices to cache - cache_indices = [] - cache_past_indices = [] + next_batch_keep_indices = [] + next_batch_past_keep_indices = [] - # New input_ids for next forward; keep in cache - cache_next_input_ids = [] - cache_all_input_ids = [] + # New input_ids for next forward + next_batch_input_ids = [] + next_batch_all_input_ids = [] + + next_batch_size = 0 + next_batch_max_sequence_length = 0 # Finished requests - finished_generations: List[FinishedGeneration] = [] + generated_texts: List[GeneratedText] = [] # Zipped iterator iterator = zip( - batch.request_ids, + batch.requests, outputs.logits, batch.next_token_choosers, batch.stopping_criterias, @@ -256,11 +269,11 @@ class BLOOM: # For each member of the batch for i, ( - request_id, - logits, - next_token_chooser, - stopping_criteria, - all_tokens, + request, + logits, + next_token_chooser, + stopping_criteria, + all_tokens, ) in enumerate(iterator): # Select next token next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) @@ -274,64 +287,75 @@ class BLOOM: output = self.tokenizer.decode( all_tokens.squeeze(-1), skip_special_tokens=True ) - # Add to the list of finished generations with the original request id - finished_generations.append(FinishedGeneration(request_id, output)) - # must be added to the cache + # Add to the list of finished generations with the original request + generated_texts.append(GeneratedText(request, output)) + # add to the next batch else: - cache_indices.append(i) - cache_past_indices.extend([j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]) - cache_next_input_ids.append(next_token) - cache_all_input_ids.append(all_tokens) + next_batch_keep_indices.append(i) + # past_key_values is of shape [batch_size * num_heads, ...] 
+ # so we need to take into account the `num_heads` stride here + next_batch_past_keep_indices.extend( + [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)] + ) + next_batch_input_ids.append(next_token) + next_batch_all_input_ids.append(all_tokens) + next_batch_size += 1 + next_batch_max_sequence_length = max( + next_batch_max_sequence_length, len(all_tokens) + ) - # No cache is needed, we finished all generations in the batch - if not cache_indices: - return finished_generations, None + # We finished all generations in the batch; there is no next batch + if not next_batch_keep_indices: + return generated_texts, None # If we finished at least one generation - cache_input_ids = {"input_ids": torch.cat(cache_next_input_ids, dim=0)} - if finished_generations: + next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)} + if generated_texts: # Apply indices to attention mask, past key values and other items that need to be cached - cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"][ - cache_indices + next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][ + next_batch_keep_indices ] - cache_input_ids["past_key_values"] = [ - (keys[cache_past_indices], values[cache_past_indices]) + next_batch_input_ids["past_key_values"] = [ + ( + keys[next_batch_past_keep_indices], + values[next_batch_past_keep_indices], + ) for keys, values in outputs["past_key_values"] ] - cache_request_ids = [batch.request_ids[i] for i in cache_indices] - cache_next_token_choosers = [ - batch.next_token_choosers[i] for i in cache_indices + next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] + next_batch_next_token_choosers = [ + batch.next_token_choosers[i] for i in next_batch_keep_indices ] - cache_stopping_criterias = [ - batch.stopping_criterias[i] for i in cache_indices + next_batch_stopping_criterias = [ + batch.stopping_criterias[i] for i in next_batch_keep_indices ] else: - cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"] - cache_input_ids["past_key_values"] = outputs["past_key_values"] - cache_request_ids = batch.request_ids - cache_next_token_choosers = batch.next_token_choosers - cache_stopping_criterias = batch.stopping_criterias + next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"] + next_batch_input_ids["past_key_values"] = outputs["past_key_values"] + next_batch_requests = batch.requests + next_batch_next_token_choosers = batch.next_token_choosers + next_batch_stopping_criterias = batch.stopping_criterias # Update attention_mask with padding as we added a new token to input_ids - cache_input_ids["attention_mask"] = torch.cat( + next_batch_input_ids["attention_mask"] = torch.cat( [ - cache_input_ids["attention_mask"], - torch.ones((cache_input_ids["attention_mask"].shape[0], 1)).to( - cache_input_ids["attention_mask"].device - ), + next_batch_input_ids["attention_mask"], + torch.ones((next_batch_size, 1)).to(self.device), ], dim=1, ) - cache_entry = CacheEntry( - batch.batch_id, - cache_request_ids, - cache_input_ids, - cache_all_input_ids, - cache_next_token_choosers, - cache_stopping_criterias, + next_batch = Batch( + batch_id=batch.batch_id, + requests=next_batch_requests, + input_ids=next_batch_input_ids, + all_input_ids=next_batch_all_input_ids, + next_token_choosers=next_batch_next_token_choosers, + stopping_criterias=next_batch_stopping_criterias, + size=next_batch_size, + max_sequence_length=next_batch_max_sequence_length, ) - return finished_generations, cache_entry + 
return generated_texts, next_batch class BLOOMSharded(BLOOM): diff --git a/server/bloom_inference/server.py b/server/bloom_inference/server.py index c652468e..3a509169 100644 --- a/server/bloom_inference/server.py +++ b/server/bloom_inference/server.py @@ -10,7 +10,7 @@ from bloom_inference.model import BLOOM, Batch, BLOOMSharded from bloom_inference.pb import generate_pb2_grpc, generate_pb2 -class TextGeneration(generate_pb2_grpc.TextGenerationServicer): +class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]): self.cache = cache self.model = model @@ -21,32 +21,90 @@ class TextGeneration(generate_pb2_grpc.TextGenerationServicer): async def ClearCache(self, request, context): self.cache.clear() - return generate_pb2.Empty() + return generate_pb2.ClearCacheResponse() async def Generate(self, request, context): - batch = Batch.from_batch_pb(request, self.model.tokenizer, self.model.device) - finished_generations, cache_entry = self.model.generate_token(batch) - self.cache.set(cache_entry) + batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device) - return generate_pb2.Response( - finished=[ - finished_generation.to_pb() - for finished_generation in finished_generations + generated_texts, next_batch = self.model.generate_token(batch) + self.cache.set(next_batch) + + return generate_pb2.GenerateResponse( + generated_texts=[ + generated_text.to_pb() for generated_text in generated_texts ], - cache_entry=cache_entry.to_pb() if cache_entry else None, + batch=next_batch.to_pb() if next_batch else None, ) async def GenerateWithCache(self, request, context): - batch = Batch.from_batch_cached_pb(request, self.cache) - finished_generations, cache_entry = self.model.generate_token(batch) - self.cache.set(cache_entry) + if len(request.batches) == 0: + raise ValueError("Must provide at least one batch") - return generate_pb2.Response( - finished=[ - finished_generation.to_pb() - for finished_generation in finished_generations + batches = [] + for batch_pb in request.batches: + batch = self.cache.pop(batch_pb.id) + if batch is None: + raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") + batches.append(batch) + + if len(batches) > 1: + batch = Batch.concatenate(batches) + else: + batch = batches[0] + + generated_texts, next_batch = self.model.generate_token(batch) + self.cache.set(next_batch) + + return generate_pb2.GenerateWithCacheResponse( + generated_texts=[ + generated_text.to_pb() for generated_text in generated_texts ], - cache_entry=cache_entry.to_pb() if cache_entry else None, + batch=next_batch.to_pb() if next_batch else None, + ) + + async def GenerateUntilFinished(self, request, context): + batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device) + + generated_texts = [] + while not generated_texts: + generated_texts, next_batch = self.model.generate_token(batch) + batch = next_batch + self.cache.set(next_batch) + + return generate_pb2.GenerateUntilFinishedResponse( + generated_texts=[ + generated_text.to_pb() for generated_text in generated_texts + ], + batch=next_batch.to_pb() if next_batch else None, + ) + + async def GenerateUntilFinishedWithCache(self, request, context): + if len(request.batches) == 0: + raise ValueError("Must provide at least one batch") + + batches = [] + for batch_pb in request.batches: + batch = self.cache.pop(batch_pb.id) + if batch is None: + raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") + 
batches.append(batch) + + if len(batches) > 1: + batch = Batch.concatenate(batches) + else: + batch = batches[0] + + generated_texts = [] + while not generated_texts: + generated_texts, next_batch = self.model.generate_token(batch) + batch = next_batch + self.cache.set(next_batch) + + return generate_pb2.GenerateUntilFinishedWithCacheResponse( + generated_texts=[ + generated_text.to_pb() for generated_text in generated_texts + ], + batch=next_batch.to_pb() if next_batch else None, ) @@ -71,11 +129,11 @@ def serve(model_name, sharded, shard_directory): server_urls = [local_url] server = aio.server() - generate_pb2_grpc.add_TextGenerationServicer_to_server( - TextGeneration(model, Cache(), server_urls), server + generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( + TextGenerationService(model, Cache(), server_urls), server ) SERVICE_NAMES = ( - generate_pb2.DESCRIPTOR.services_by_name["TextGeneration"].full_name, + generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name, reflection.SERVICE_NAME, ) reflection.enable_server_reflection(SERVICE_NAMES, server)
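
A note on `Batch.concatenate` in `server/bloom_inference/model.py`: attention masks are left-padded, so each cached batch is copied into the last `max_sequence_length` columns of the combined mask. The sketch below uses made-up mask values (they are not taken from the patch) to show the right-alignment for two cached batches of different lengths.

```python
# Illustration of the right-aligned attention-mask merge done in Batch.concatenate.
# Values are hypothetical; only the slicing pattern mirrors the patch.
import torch

mask_a = torch.tensor([[0, 1, 1, 1, 1]])            # batch A: 1 request, max_sequence_length = 5
mask_b = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1],
                       [0, 0, 0, 1, 1, 1, 1, 1]])   # batch B: 2 requests, max_sequence_length = 8

total_batch_size = mask_a.shape[0] + mask_b.shape[0]
max_sequence_length = max(mask_a.shape[1], mask_b.shape[1])

merged = torch.zeros((total_batch_size, max_sequence_length), dtype=mask_a.dtype)
merged[0:1, -mask_a.shape[1]:] = mask_a             # rows for batch A, right-aligned
merged[1:3, -mask_b.shape[1]:] = mask_b             # rows for batch B

print(merged)
# tensor([[0, 0, 0, 0, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1, 1, 1]])
```

The `past_key_values` slices in the patch apply the same right-alignment with `max_sequence_length - 1`, since the cached keys and values cover every position except the newest input token.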
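
To exercise the refactored router end to end, a hypothetical smoke test is sketched below, assuming the router is running locally on `127.0.0.1:3000` as configured in `router/src/main.rs`. The parameter names are inferred from the proto `LogitsWarperParameters` and the limits enforced in `router/src/validation.rs`; the exact JSON schema of `GenerateParameters` is not shown in this diff, so treat the payload shape as an assumption.

```python
# Hypothetical request against the router's /generate route (payload shape assumed).
import requests

payload = {
    "inputs": "Hello, I am a language model and",
    "parameters": {
        "temperature": 0.7,    # validation rejects values below 0.0
        "top_k": 50,           # assumed field, mirroring LogitsWarperParameters
        "top_p": 0.9,          # validation requires a value in (0.0, 1.0]
        "max_new_tokens": 64,  # validation caps this at 512
    },
}

# The prompt itself must tokenize to at most 512 tokens (checked by validation_task).
response = requests.post("http://127.0.0.1:3000/generate", json=payload, timeout=60)
print(response.status_code)
print(response.text)
```

Values outside these bounds (a negative temperature, `top_p` outside `(0, 1]`, more than 512 new tokens, or a prompt longer than 512 tokens) cause `validation_task` to answer with a `ValidationError` instead of queueing the request.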