diff --git a/Cargo.lock b/Cargo.lock index 4c44820b..33f5d181 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1829,7 +1829,6 @@ dependencies = [ name = "text-generation-router" version = "0.1.0" dependencies = [ - "async-stream", "axum", "clap 4.0.22", "futures", @@ -1841,7 +1840,6 @@ dependencies = [ "thiserror", "tokenizers", "tokio", - "tokio-stream", "tracing", "tracing-subscriber", ] diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index 58df28d9..21d5d3ee 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -16,4 +16,4 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] } [dev-dependencies] float_eq = "1.0.1" reqwest = { version = "0.11.13", features = ["blocking", "json"] } -serde = { version = "1.0.150", features = ["derive"] } +serde = "1.0.150" diff --git a/launcher/tests/bloom_560m.json b/launcher/tests/bloom_560m.json index a81d1982..d17f1ed4 100644 --- a/launcher/tests/bloom_560m.json +++ b/launcher/tests/bloom_560m.json @@ -3,7 +3,7 @@ "details": { "finish_reason": "length", "generated_tokens": 20, - "prefill": [ + "tokens": [ [ 10264, "Test", @@ -13,9 +13,7 @@ 8821, " request", -11.895094 - ] - ], - "tokens": [ + ], [ 17, ".", diff --git a/launcher/tests/mt0_base.json b/launcher/tests/mt0_base.json index 51cb8b5c..1b772282 100644 --- a/launcher/tests/mt0_base.json +++ b/launcher/tests/mt0_base.json @@ -3,14 +3,12 @@ "details": { "finish_reason": "length", "generated_tokens": 20, - "prefill": [ + "tokens": [ [ 0, "", null - ] - ], - "tokens": [ + ], [ 259, "", diff --git a/proto/generate.proto b/proto/generate.proto index 32ec9681..921bd5c0 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -7,10 +7,10 @@ service TextGenerationService { rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} /// Empties batch cache rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); - /// Prefill batch and decode first token - rpc Prefill (PrefillRequest) returns (PrefillResponse); - /// Decode token for a list of prefilled batches - rpc Decode (DecodeRequest) returns (DecodeResponse); + /// Generate tokens for a batch + rpc Generate (GenerateRequest) returns (GenerateResponse); + /// Generate tokens for a list of cached batches + rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse); } /// Empty request @@ -70,60 +70,44 @@ message Batch { } message GeneratedText { + /// Request + Request request = 1; /// Output - string text = 1; + string output_text = 2; /// Number of generated tokens - uint32 generated_tokens = 2; + uint32 generated_tokens = 3; + /// Tokens + repeated string tokens = 4; + /// Token IDs + repeated uint32 token_ids = 5; + /// Logprobs + repeated float logprobs = 6; /// Finish reason - string finish_reason = 3; + string finish_reason = 7; /// Seed - optional uint64 seed = 4; + optional uint64 seed = 8; } -message PrefillTokens { - /// Prefill Token IDs - repeated uint32 ids = 1; - /// Prefill Logprobs - repeated float logprobs = 2; - /// Prefill tokens - repeated string texts = 3; -} - -message Generation { - /// Request ID - uint64 request_id = 1; - /// Prefill tokens (optional) - PrefillTokens prefill_tokens = 2; - /// Token ID - uint32 token_id = 3; - /// Logprob - float token_logprob = 4; - /// Text - string token_text = 5; - /// Complete generated text - GeneratedText generated_text = 6; -} - -message PrefillRequest { +message GenerateRequest { /// Batch Batch batch = 1; } -message PrefillResponse { - /// Generation - repeated Generation generations = 1; 
+message GenerateResponse { + /// Finished requests + repeated GeneratedText generated_texts = 1; /// Next batch (cached) optional Batch batch = 2; } -message DecodeRequest { +message GenerateWithCacheRequest { /// Cached batches repeated Batch batches = 1; } -message DecodeResponse { - /// Decodes - repeated Generation generations = 1; +message GenerateWithCacheResponse { + /// Finished requests + repeated GeneratedText generated_texts = 1; /// Next batch (cached) optional Batch batch = 2; -} \ No newline at end of file +} diff --git a/router/Cargo.toml b/router/Cargo.toml index d30d3b48..546f127f 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -13,7 +13,6 @@ name = "text-generation-router" path = "src/main.rs" [dependencies] -async-stream = "0.3.3" axum = { version = "0.5.16", features = ["json", "serde_json"] } text-generation-client = { path = "client" } clap = { version = "4.0.15", features = ["derive", "env"] } @@ -25,7 +24,6 @@ serde_json = "1.0.85" thiserror = "1.0.37" tokenizers = "0.13.0" tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } -tokio-stream = "0.1.11" tracing = "0.1.36" tracing-subscriber = { version = "0.3.15", features = ["json"] } diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 77a43110..172d0bf7 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -70,36 +70,36 @@ impl Client { /// Generate one token for each request in the given batch /// - /// Returns Generation for each request in batch + /// Returns a list of generated texts of request that met their stopping criteria /// and the next cached batch #[instrument(skip(self))] - pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec, Option)> { - let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }); + pub async fn generate(&mut self, batch: Batch) -> Result<(Vec, Option)> { + let request = tonic::Request::new(GenerateRequest { batch: Some(batch) }); let response = self .stub - .prefill(request) - .instrument(info_span!("prefill")) + .generate(request) + .instrument(info_span!("generate")) .await? .into_inner(); - Ok((response.generations, response.batch)) + Ok((response.generated_texts, response.batch)) } - /// Generate one token for each request in the given cached batches + /// Generate one token for each request in the given cached batch /// - /// Returns Generation for each request in batches + /// Returns a list of generated texts of request that met their stopping criteria /// and the next cached batch #[instrument(skip(self))] - pub async fn decode( + pub async fn generate_with_cache( &mut self, batches: Vec, - ) -> Result<(Vec, Option)> { - let request = tonic::Request::new(DecodeRequest { batches }); + ) -> Result<(Vec, Option)> { + let request = tonic::Request::new(GenerateWithCacheRequest { batches }); let response = self .stub - .decode(request) - .instrument(info_span!("decode")) + .generate_with_cache(request) + .instrument(info_span!("generate_with_cache")) .await? 
.into_inner(); - Ok((response.generations, response.batch)) + Ok((response.generated_texts, response.batch)) } } diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index e0546b16..295b009b 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -7,8 +7,7 @@ mod sharded_client; pub use client::Client; pub use pb::generate::v1::{ - Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request, - StoppingCriteriaParameters, + Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 56335f92..6c70afca 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,6 +1,6 @@ /// Multi shard Client use crate::Result; -use crate::{Batch, Client, Generation}; +use crate::{Batch, Client, GeneratedText}; use futures::future::join_all; use futures::future::select_all; use tonic::transport::Uri; @@ -37,6 +37,39 @@ impl ShardedClient { Self::from_master_client(master_client).await } + /// Generate one token for each request in the given batch + /// + /// Returns a list of generated texts of request that met their stopping criteria + /// and the next cached batch + pub async fn generate(&mut self, batch: Batch) -> Result<(Vec, Option)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.generate(batch.clone()))) + .collect(); + // As soon as we receive one response, we can return as all shards will return the same + let (result, _, _) = select_all(futures).await; + result + } + + /// Generate one token for each request in the given cached batch + /// + /// Returns a list of generated texts of request that met their stopping criteria + /// and the next cached batch + pub async fn generate_with_cache( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.generate_with_cache(batches.clone()))) + .collect(); + // As soon as we receive one response, we can return as all shards will return the same + let (result, _, _) = select_all(futures).await; + result + } + /// Clear the past generations cache pub async fn clear_cache(&mut self) -> Result<()> { let futures: Vec<_> = self @@ -46,37 +79,4 @@ impl ShardedClient { .collect(); join_all(futures).await.into_iter().collect() } - - /// Generate one token for each request in the given batch - /// - /// Returns Generation for each request in batch - /// and the next cached batch - pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec, Option)> { - let futures: Vec<_> = self - .clients - .iter_mut() - .map(|client| Box::pin(client.prefill(batch.clone()))) - .collect(); - // As soon as we receive one response, we can return as all shards will return the same - let (result, _, _) = select_all(futures).await; - result - } - - /// Generate one token for each request in the given cached batches - /// - /// Returns Generation for each request in batches - /// and the next cached batch - pub async fn decode( - &mut self, - batches: Vec, - ) -> Result<(Vec, Option)> { - let futures: Vec<_> = self - .clients - .iter_mut() - .map(|client| Box::pin(client.decode(batches.clone()))) - .collect(); - // As soon as we receive one response, we can return as all shards will return the same - let (result, _, _) = select_all(futures).await; - result - } } diff --git 
a/router/src/batcher.rs b/router/src/batcher.rs new file mode 100644 index 00000000..baf58af4 --- /dev/null +++ b/router/src/batcher.rs @@ -0,0 +1,236 @@ +/// Batching and inference logic +use crate::{Db, Entry}; +use crate::{ErrorResponse, GenerateRequest}; +use axum::http::StatusCode; +use axum::Json; +use nohash_hasher::IntMap; +use std::future::Future; +use std::sync::Arc; +use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient}; +use thiserror::Error; +use tokio::sync::{oneshot, Notify}; +use tokio::time::Instant; +use tracing::instrument; + +/// Batcher +#[derive(Clone)] +pub struct Batcher { + /// Request database + db: Db, + /// Shared state + shared: Arc, +} + +/// Batcher shared state +struct Shared { + /// Batching background Tokio task notifier + batching_task: Notify, +} + +impl Batcher { + pub(crate) fn new( + client: ShardedClient, + max_batch_size: usize, + max_waiting_tokens: usize, + ) -> Self { + // Batcher shared state + let db = Db::new(); + let shared = Arc::new(Shared { + batching_task: Notify::new(), + }); + + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + client, + max_batch_size, + max_waiting_tokens, + db.clone(), + shared.clone(), + )); + + Self { db, shared } + } + + /// Add a new request to the database and return a future that will generate the text + pub(crate) async fn infer( + &self, + input_length: usize, + request: GenerateRequest, + ) -> Result { + // One shot channel to communicate with the background batching task + let (response_tx, response_rx) = oneshot::channel(); + + // Try to append the request to the database + self.db.append(Entry { + request, + response_tx, + input_length, + time: Instant::now(), + batch_time: None, + }); + + // Notify the background task that we have a new entry in the database that needs + // to be batched + self.shared.batching_task.notify_one(); + + // Await on the response from the background task + // We can safely unwrap as the background task will never drop the sender + response_rx + .await + .unwrap() + .map_err(|err| InferError::GenerationError(err.to_string())) + } +} + +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[instrument(skip(client, db, shared))] +async fn batching_task( + mut client: ShardedClient, + max_batch_size: usize, + max_waiting_tokens: usize, + db: Db, + shared: Arc, +) { + // Minimum batch size after which we try to add more requests + let limit_min_batch_size = (max_batch_size / 2) as u32; + + // Infinite loop + loop { + // Wait for a notification from the Batcher struct + shared.batching_task.notified().await; + + // Get the next batch from the DB + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the DB + while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) { + let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await; + let mut waiting_tokens = 1; + + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) + while let Some(batch) = cached_batch { + // Get current batch info + let batch_size = batch.size; + let mut batches = vec![batch]; + + // If the current batch is too small, we try to add more requests to it + if batch_size <= limit_min_batch_size { + let min_size = match waiting_tokens { + // If we didn't onboard any new requests since >= 
max_waiting_tokens, we try + // to add a new batch even though its size might be small + _ if waiting_tokens >= max_waiting_tokens => None, + // Minimum size criteria + _ => Some(limit_min_batch_size as usize), + }; + + // Try to get a new batch + if let Some((mut new_entries, new_batch)) = + db.next_batch(min_size, max_batch_size - batch_size as usize) + { + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = + wrap_future(client.generate(new_batch), &mut new_entries).await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); + } + } + } + + cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await; + waiting_tokens += 1; + } + } + } +} + +/// Wrap a future inside a match statement to handle errors and send the response to the Batcher +async fn wrap_future( + future: impl Future, Option), ClientError>>, + entries: &mut IntMap, +) -> Option { + match future.await { + Ok((generated_texts, next_batch)) => { + send_generated(generated_texts, entries); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + send_error(err, entries); + None + } + } +} + +/// Send errors to the Batcher for all `entries` +fn send_error(error: ClientError, entries: &mut IntMap) { + entries.drain().for_each(|(_, entry)| { + // unwrap_or is valid here as we don't care if the receiver is gone. + entry.response_tx.send(Err(error.clone())).unwrap_or(()); + }); +} + +/// Send `generated_text` to the Batcher for all `finished` +fn send_generated(finished: Vec, entries: &mut IntMap) { + finished.into_iter().for_each(|output| { + // We can `expect` here as the request id should always be in the entries + let entry = entries + .remove(&output.request.unwrap().id) + .expect("ID not found in entries. This is a bug."); + + let response = InferResponse { + output_text: output.output_text, + generated_tokens: output.generated_tokens, + token_ids: output.token_ids, + tokens: output.tokens, + logprobs: output.logprobs, + finish_reason: output.finish_reason, + seed: output.seed, + queued: entry.time, + start: entry.batch_time.unwrap(), // unwrap is always valid + end: Instant::now(), + }; + // unwrap_or is valid here as we don't care if the receiver is gone. 
+ entry.response_tx.send(Ok(response)).unwrap_or(()); + }); +} + +#[derive(Debug)] +pub(crate) struct InferResponse { + pub(crate) output_text: String, + pub(crate) generated_tokens: u32, + pub(crate) token_ids: Vec, + pub(crate) tokens: Vec, + pub(crate) logprobs: Vec, + pub(crate) finish_reason: String, + pub(crate) seed: Option, + pub(crate) queued: Instant, + pub(crate) start: Instant, + pub(crate) end: Instant, +} + +#[derive(Debug, Error)] +pub enum InferError { + #[error("Request failed during generation: {0}")] + GenerationError(String), +} + +/// Convert to Axum supported format +impl From for (StatusCode, Json) { + fn from(err: InferError) -> Self { + match err { + InferError::GenerationError(_) => ( + StatusCode::FAILED_DEPENDENCY, + Json(ErrorResponse { + error: err.to_string(), + }), + ), + } + } +} diff --git a/router/src/db.rs b/router/src/db.rs index f0a62d65..15007b64 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -1,16 +1,14 @@ /// This code is massively inspired by Tokio mini-redis -use crate::infer::InferError; -use crate::infer::InferStreamResponse; +use crate::InferResponse; use crate::{GenerateParameters, GenerateRequest}; use nohash_hasher::{BuildNoHashHasher, IntMap}; use parking_lot::Mutex; use std::collections::BTreeMap; use std::sync::Arc; use text_generation_client::{ - Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters, + Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; -use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::OwnedSemaphorePermit; +use tokio::sync::oneshot::Sender; use tokio::time::Instant; /// Database entry @@ -18,16 +16,14 @@ use tokio::time::Instant; pub(crate) struct Entry { /// Request pub request: GenerateRequest, - /// Response sender to communicate between the Infer struct and the batching_task - pub response_tx: UnboundedSender>, + /// Response sender to communicate between the Batcher and the batching_task + pub response_tx: Sender>, /// Number of tokens in the input pub input_length: usize, /// Instant when this entry was created pub time: Instant, /// Instant when this entry was added to a batch pub batch_time: Option, - /// Permit - pub _permit: OwnedSemaphorePermit, } /// Request Database diff --git a/router/src/infer.rs b/router/src/infer.rs deleted file mode 100644 index 4c4a7eb8..00000000 --- a/router/src/infer.rs +++ /dev/null @@ -1,354 +0,0 @@ -/// Batching and inference logic -use crate::validation::{Validation, ValidationError}; -use crate::GenerateRequest; -use crate::{Db, Entry, Token}; -use nohash_hasher::IntMap; -use std::future::Future; -use std::sync::Arc; -use text_generation_client::{ - Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient, -}; -use thiserror::Error; -use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError}; -use tokio::time::Instant; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tokio_stream::StreamExt; -use tracing::instrument; - -/// Inference struct -#[derive(Clone)] -pub struct Infer { - /// Validation - validation: Validation, - /// Request database - db: Db, - /// Shared state - shared: Arc, - /// Inference limit - limit_concurrent_requests: Arc, -} - -/// Infer shared state -struct Shared { - /// Batching background Tokio task notifier - batching_task: Notify, -} - -impl Infer { - pub(crate) fn new( - client: ShardedClient, - validation: Validation, - max_batch_size: usize, - max_waiting_tokens: usize, - max_concurrent_requests: usize, - ) -> Self { - // Infer shared state - let db = 
Db::new(); - let shared = Arc::new(Shared { - batching_task: Notify::new(), - }); - - // Spawn batching background task that contains all the inference logic - tokio::spawn(batching_task( - client, - max_batch_size, - max_waiting_tokens, - db.clone(), - shared.clone(), - )); - - // Inference limit with a semaphore - let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); - - Self { - validation, - db, - shared, - limit_concurrent_requests: semaphore, - } - } - - /// Add a new request to the database and return a stream of InferStreamResponse - pub(crate) async fn generate_stream( - &self, - request: GenerateRequest, - ) -> Result>, InferError> { - // Limit concurrent requests by acquiring a permit from the semaphore - // This permit will live as long as Entry - let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?; - - // Validate request - let (input_length, validated_request) = self.validation.validate(request).await?; - - // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = mpsc::unbounded_channel(); - - // Append the request to the database - self.db.append(Entry { - request: validated_request, - response_tx, - input_length, - time: Instant::now(), - batch_time: None, - _permit: permit, - }); - - // Notify the background task that we have a new entry in the database that needs - // to be batched - self.shared.batching_task.notify_one(); - - // Return stream - Ok(UnboundedReceiverStream::new(response_rx)) - } - - /// Add a new request to the database and return a InferResponse - pub(crate) async fn generate( - &self, - request: GenerateRequest, - ) -> Result { - // Create stream - let mut stream = self.generate_stream(request).await?; - - // Return values - let mut result_prefill = Vec::new(); - let mut result_tokens = Vec::new(); - let mut result_generated_text = None; - let mut result_start = None; - let mut result_queued = None; - - // Iterate on stream - while let Some(response) = stream.next().await { - match response? 
{ - // Add prefill tokens - InferStreamResponse::Prefill(tokens) => { - // Create Token objects - // We do that here instead of in the Python code as Rust for loops are faster - result_prefill = tokens - .ids - .into_iter() - .zip(tokens.logprobs.into_iter()) - .zip(tokens.texts.into_iter()) - .map(|((id, logprob), text)| Token(id, text, logprob)) - .collect(); - } - // Push last token - InferStreamResponse::Token(token) => result_tokens.push(token), - // Final message - // Set return values - InferStreamResponse::End { - token, - generated_text, - start, - queued, - } => { - result_tokens.push(token); - result_generated_text = Some(generated_text); - result_start = Some(start); - result_queued = Some(queued) - } - } - } - - // Check that we received a `InferStreamResponse::End` message - if let (Some(generated_text), Some(queued), Some(start)) = - (result_generated_text, result_queued, result_start) - { - Ok(InferResponse { - prefill: result_prefill, - tokens: result_tokens, - generated_text, - queued, - start, - }) - } else { - Err(InferError::IncompleteGeneration) - } - } -} - -/// Batching logic -/// Will be launched in a background Tokio task -/// -/// Batches requests and sends them to the inference server -#[instrument(skip(client, db, shared))] -async fn batching_task( - mut client: ShardedClient, - max_batch_size: usize, - max_waiting_tokens: usize, - db: Db, - shared: Arc, -) { - // Minimum batch size after which we try to add more requests - let limit_min_batch_size = (max_batch_size / 2) as u32; - - // Infinite loop - loop { - // Wait for a notification from the Infer struct - shared.batching_task.notified().await; - - // Get the next batch from the DB - // This batch might be smaller than the maximum batch size if there are not enough requests - // waiting in the DB - while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) { - let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await; - let mut waiting_tokens = 1; - - // We loop until we do not receive any cached batch from the inference server (== until - // all requests have met their stopping criteria) - while let Some(batch) = cached_batch { - // Get current batch info - let batch_size = batch.size; - let mut batches = vec![batch]; - - // If the current batch is too small, we try to add more requests to it - if batch_size <= limit_min_batch_size { - let min_size = match waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - _ if waiting_tokens >= max_waiting_tokens => None, - // Minimum size criteria - _ => Some(limit_min_batch_size as usize), - }; - - // Try to get a new batch - if let Some((mut new_entries, new_batch)) = - db.next_batch(min_size, max_batch_size - batch_size as usize) - { - // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = - wrap_future(client.prefill(new_batch), &mut new_entries).await; - // Reset waiting counter - waiting_tokens = 1; - // Extend current batch with the new batch - if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); - batches.push(new_cached_batch); - } - } - } - - cached_batch = wrap_future(client.decode(batches), &mut entries).await; - waiting_tokens += 1; - } - } - } -} - -/// Wrap a future inside a match statement to handle errors and send the responses to Infer -async fn wrap_future( - future: impl Future, Option), ClientError>>, - entries: &mut IntMap, -) -> 
Option { - match future.await { - Ok((generations, next_batch)) => { - send_generations(generations, entries); - next_batch - } - // If we have an error, we discard the whole batch - Err(err) => { - send_error(err, entries); - None - } - } -} - -/// Send errors to Infer for all `entries` -fn send_error(error: ClientError, entries: &mut IntMap) { - entries.drain().for_each(|(_, entry)| { - // unwrap_or is valid here as we don't care if the receiver is gone. - entry - .response_tx - .send(Err(InferError::GenerationError(error.to_string()))) - .unwrap_or(()); - }); -} - -/// Send one or multiple `InferStreamResponse` to Infer for all `entries` -fn send_generations(generations: Vec, entries: &mut IntMap) { - generations.into_iter().for_each(|generation| { - // Get entry - // We can `expect` here as the request id should always be in the entries - let entry = entries - .get(&generation.request_id) - .expect("ID not found in entries. This is a bug."); - - if let Some(prefill_tokens) = generation.prefill_tokens { - // Send message - // unwrap_or is valid here as we don't care if the receiver is gone. - entry - .response_tx - .send(Ok(InferStreamResponse::Prefill(prefill_tokens))) - .unwrap_or(()); - } - - // Create last Token - let token = Token( - generation.token_id, - generation.token_text, - generation.token_logprob, - ); - - if let Some(generated_text) = generation.generated_text { - // Remove entry as this is the last message - // We can `expect` here as the request id should always be in the entries - let entry = entries - .remove(&generation.request_id) - .expect("ID not found in entries. This is a bug."); - - // Send message - // unwrap_or is valid here as we don't care if the receiver is gone. - entry - .response_tx - .send(Ok(InferStreamResponse::End { - token, - generated_text, - queued: entry.time, - start: entry.batch_time.unwrap(), - })) - .unwrap_or(()); - } else { - // Send message - // unwrap_or is valid here as we don't care if the receiver is gone. 
- entry - .response_tx - .send(Ok(InferStreamResponse::Token(token))) - .unwrap_or(()); - } - }); -} - -#[derive(Debug)] -pub(crate) enum InferStreamResponse { - // Optional first message - Prefill(PrefillTokens), - // Intermediate messages - Token(Token), - // Last message - End { - token: Token, - generated_text: GeneratedText, - start: Instant, - queued: Instant, - }, -} - -#[derive(Debug)] -pub(crate) struct InferResponse { - pub(crate) prefill: Vec, - pub(crate) tokens: Vec, - pub(crate) generated_text: GeneratedText, - pub(crate) queued: Instant, - pub(crate) start: Instant, -} - -#[derive(Debug, Error)] -pub enum InferError { - #[error("Request failed during generation: {0}")] - GenerationError(String), - #[error("Model is overloaded")] - Overloaded(#[from] TryAcquireError), - #[error("Input validation error: {0}")] - ValidationError(#[from] ValidationError), - #[error("Incomplete generation")] - IncompleteGeneration, -} diff --git a/router/src/lib.rs b/router/src/lib.rs index beab7138..1aeac302 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,11 +1,11 @@ /// Text Generation Inference Webserver +mod batcher; mod db; -mod infer; pub mod server; mod validation; +use batcher::{Batcher, InferResponse}; use db::{Db, Entry}; -use infer::Infer; use serde::{Deserialize, Serialize}; use validation::Validation; @@ -69,34 +69,21 @@ pub(crate) struct GenerateRequest { pub parameters: GenerateParameters, } -#[derive(Debug, Serialize)] -pub struct Token(u32, String, f32); - #[derive(Serialize)] pub(crate) struct Details { pub finish_reason: String, pub generated_tokens: u32, pub seed: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub prefill: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tokens: Option>, + pub tokens: Vec<(u32, String, f32)>, } #[derive(Serialize)] -pub(crate) struct GenerateResponse { +pub(crate) struct GeneratedText { pub generated_text: String, #[serde(skip_serializing_if = "Option::is_none")] pub details: Option
, } -#[derive(Serialize)] -pub(crate) struct StreamResponse { - pub token: Token, - pub generated_text: Option, - pub details: Option
, -} - #[derive(Serialize)] pub(crate) struct ErrorResponse { pub error: String, diff --git a/router/src/server.rs b/router/src/server.rs index ef3782d6..86041b96 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,54 +1,71 @@ -/// HTTP Server logic -use crate::infer::{InferError, InferStreamResponse}; use crate::{ - Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer, - StreamResponse, Validation, + Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, StatusCode}; -use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::IntoResponse; use axum::routing::{get, post}; use axum::{Json, Router}; -use futures::Stream; -use std::convert::Infallible; use std::net::SocketAddr; +use std::sync::Arc; use text_generation_client::ShardedClient; use tokenizers::Tokenizer; use tokio::signal; +use tokio::sync::Semaphore; use tokio::time::Instant; -use tokio_stream::StreamExt; use tracing::instrument; +// Server shared state +#[derive(Clone)] +struct ServerState { + validation: Validation, + batcher: Batcher, + limit_concurrent_requests: Arc, +} + /// Health check method -#[instrument(skip(infer))] -async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { +#[instrument(skip(state), fields(time, time_per_token))] +async fn health(state: Extension) -> Result<(), (StatusCode, Json)> { // TODO: while this is the best health check we can do, it is a bit on the heavy side and might // be a bit too slow for a health check. // What we should do instead if check if the gRPC channels are still healthy. + // Limit concurrent requests by acquiring a permit from the semaphore + let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| { + ( + StatusCode::TOO_MANY_REQUESTS, + Json(ErrorResponse { + error: "Model is overloaded".to_string(), + }), + ) + })?; + // Send a small inference request - infer - .generate(GenerateRequest { - inputs: "liveness".to_string(), - parameters: GenerateParameters { - temperature: 1.0, - top_k: 0, - top_p: 1.0, - do_sample: false, - max_new_tokens: 1, - stop: vec![], - details: false, - seed: None, + state + .batcher + .infer( + 1, + GenerateRequest { + inputs: "liveness".to_string(), + parameters: GenerateParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + do_sample: false, + max_new_tokens: 1, + stop: vec![], + details: false, + seed: None, + }, }, - }) + ) .await?; Ok(()) } /// Generate method #[instrument( - skip(infer), + skip(state), fields( total_time, validation_time, @@ -59,28 +76,56 @@ async fn health(infer: Extension) -> Result<(), (StatusCode, Json, + state: Extension, req: Json, ) -> Result)> { - let span = tracing::Span::current(); let start_time = Instant::now(); + // Limit concurrent requests by acquiring a permit from the semaphore + let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| { + tracing::error!("Model is overloaded"); + ( + StatusCode::TOO_MANY_REQUESTS, + Json(ErrorResponse { + error: "Model is overloaded".to_string(), + }), + ) + })?; + + // Validate request + let details = req.0.parameters.details; + let (input_length, validated_request) = + state.validation.validate(req.0).await.map_err(|err| { + tracing::error!("{}", err.to_string()); + err + })?; // Inference - let details = req.0.parameters.details; - let response = infer.generate(req.0).await.map_err(|err| { - tracing::error!("{}", err.to_string()); - err - })?; + let response = state + 
.batcher + .infer(input_length, validated_request) + .await + .map_err(|err| { + tracing::error!("{}", err.to_string()); + err + })?; // Token details let details = match details { - true => Some(Details { - finish_reason: response.generated_text.finish_reason, - generated_tokens: response.generated_text.generated_tokens, - prefill: Some(response.prefill), - tokens: Some(response.tokens), - seed: response.generated_text.seed, - }), + true => { + let tokens = response + .token_ids + .into_iter() + .zip(response.tokens.into_iter()) + .zip(response.logprobs.into_iter()) + .map(|((id, text), logprob)| (id, text, logprob)) + .collect(); + Some(Details { + seed: response.seed, + finish_reason: response.finish_reason, + generated_tokens: response.generated_tokens, + tokens, + }) + } false => None, }; @@ -88,8 +133,8 @@ async fn generate( let total_time = start_time.elapsed(); let validation_time = response.queued - start_time; let queue_time = response.start - response.queued; - let inference_time = Instant::now() - response.start; - let time_per_token = inference_time / response.generated_text.generated_tokens; + let inference_time = response.end - response.start; + let time_per_token = inference_time / response.generated_tokens; // Headers let mut headers = HeaderMap::new(); @@ -115,143 +160,22 @@ async fn generate( ); // Tracing metadata - span.record("total_time", format!("{:?}", total_time)); - span.record("validation_time", format!("{:?}", validation_time)); - span.record("queue_time", format!("{:?}", queue_time)); - span.record("inference_time", format!("{:?}", inference_time)); - span.record("time_per_token", format!("{:?}", time_per_token)); - span.record("seed", format!("{:?}", response.generated_text.seed)); - tracing::info!("Output: {}", response.generated_text.text); + tracing::Span::current().record("total_time", format!("{:?}", total_time)); + tracing::Span::current().record("validation_time", format!("{:?}", validation_time)); + tracing::Span::current().record("queue_time", format!("{:?}", queue_time)); + tracing::Span::current().record("inference_time", format!("{:?}", inference_time)); + tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token)); + tracing::Span::current().record("seed", format!("{:?}", response.seed)); + tracing::info!("Output: {}", response.output_text); // Send response - let response = vec![GenerateResponse { - generated_text: response.generated_text.text, + let response = vec![GeneratedText { + generated_text: response.output_text, details, }]; Ok((headers, Json(response))) } -/// Generate stream method -#[instrument( - skip(infer), - fields( - total_time, - validation_time, - queue_time, - inference_time, - time_per_token - ) -)] -async fn generate_stream( - infer: Extension, - req: Json, -) -> Sse>> { - let span = tracing::Span::current(); - let start_time = Instant::now(); - - let stream = async_stream::stream! 
{ - // Inference - let mut end_reached = false; - let mut error = false; - let details = req.0.parameters.details; - - match infer.generate_stream(req.0).await { - Ok(mut response_stream) => { - // Server Side Event stream - while let Some(response) = response_stream.next().await { - match response { - Ok(response) => { - match response { - // Prefill is ignored - InferStreamResponse::Prefill(_) => {} - // Yield event for every new token - InferStreamResponse::Token(token) => { - // StreamResponse - let stream_token = StreamResponse { - token, - generated_text: None, - details: None, - }; - - yield Ok(Event::default().json_data(stream_token).unwrap()) - } - // Yield event for last token and compute timings - InferStreamResponse::End { - token, - generated_text, - start, - queued, - } => { - // Token details - let details = match details { - true => Some(Details { - finish_reason: generated_text.finish_reason, - generated_tokens: generated_text.generated_tokens, - prefill: None, - tokens: None, - seed: generated_text.seed, - }), - false => None, - }; - - // Timings - let total_time = start_time.elapsed(); - let validation_time = queued - start_time; - let queue_time = start - queued; - let inference_time = Instant::now() - start; - let time_per_token = inference_time / generated_text.generated_tokens; - - // Tracing metadata - span.record("total_time", format!("{:?}", total_time)); - span - .record("validation_time", format!("{:?}", validation_time)); - span.record("queue_time", format!("{:?}", queue_time)); - span - .record("inference_time", format!("{:?}", inference_time)); - span - .record("time_per_token", format!("{:?}", time_per_token)); - tracing::info!(parent: &span, "Output: {}", generated_text.text); - - // StreamResponse - end_reached = true; - let stream_token = StreamResponse { - token, - generated_text: Some(generated_text.text), - details - }; - - yield Ok(Event::default().json_data(stream_token).unwrap()) - } - } - } - // Trace and yield error - Err(err) => { - error = true; - tracing::error!("{}", err.to_string()); - yield Ok(Event::from(err)) - } - } - } - }, - // Trace and yield error - Err(err) => { - error = true; - tracing::error!("{}", err.to_string()); - yield Ok(Event::from(err)) - } - } - // Check if generation reached the end - // Skip if we already sent an error - if !end_reached && !error { - let err = InferError::IncompleteGeneration; - tracing::error!("{}", err.to_string()); - yield Ok(Event::from(err)) - } - }; - - Sse::new(stream).keep_alive(KeepAlive::default()) -} - /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( @@ -265,23 +189,21 @@ pub async fn run( addr: SocketAddr, ) { // Create state + let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens); let validation = Validation::new(validation_workers, tokenizer, max_input_length); - let infer = Infer::new( - client, + let shared_state = ServerState { validation, - max_batch_size, - max_waiting_tokens, - max_concurrent_requests, - ); + batcher, + limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)), + }; // Create router let app = Router::new() .route("/", post(generate)) .route("/generate", post(generate)) - .route("/generate_stream", post(generate_stream)) .route("/", get(health)) .route("/health", get(health)) - .layer(Extension(infer)); + .layer(Extension(shared_state.clone())); // Run server axum::Server::bind(&addr) @@ -318,32 +240,3 @@ async fn shutdown_signal() { tracing::info!("signal received, starting graceful shutdown"); } - -/// 
Convert to Axum supported formats -impl From for (StatusCode, Json) { - fn from(err: InferError) -> Self { - let status_code = match err { - InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY, - InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS, - InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, - InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, - }; - - ( - status_code, - Json(ErrorResponse { - error: err.to_string(), - }), - ) - } -} - -impl From for Event { - fn from(err: InferError) -> Self { - Event::default() - .json_data(ErrorResponse { - error: err.to_string(), - }) - .unwrap() - } -} diff --git a/router/src/validation.rs b/router/src/validation.rs index d2287168..aabc82a6 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,5 +1,7 @@ /// Payload validation logic -use crate::GenerateRequest; +use crate::{ErrorResponse, GenerateRequest}; +use axum::http::StatusCode; +use axum::Json; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::{mpsc, oneshot}; @@ -159,3 +161,14 @@ pub enum ValidationError { #[error("tokenizer error {0}")] Tokenizer(String), } + +impl From for (StatusCode, Json) { + fn from(err: ValidationError) -> Self { + ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: err.to_string(), + }), + ) + } +} diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 9f96efc3..1a788ce5 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -91,9 +91,9 @@ def test_causal_lm_batch_type(default_bloom): def test_causal_lm_generate_token(default_bloom, default_bloom_batch): sequence_length = len(default_bloom_batch.all_input_ids[0]) - generations, next_batch = default_bloom.generate_token(default_bloom_batch) + generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch) - assert len(generations) == len(default_bloom_batch) + assert generated_texts == [] assert isinstance(next_batch, CausalLMBatch) assert not next_batch.keys_head_dim_last @@ -122,30 +122,24 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch): assert all( [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values] ) - assert all([generation.generated_text is None for generation in generations]) - assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([generation.token_id.item() == 10264 for generation in generations]) - assert all([generation.token_text == "Test" for generation in generations]) - assert generations[0].request_id == 0 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch): next_batch = default_bloom_batch for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1): - generations, next_batch = default_bloom.generate_token(next_batch) - assert len(generations) == len(default_bloom_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_bloom.generate_token(next_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) assert next_batch is None - assert len(generations) == 1 + assert len(generated_texts) == 1 assert ( - generations[0].generated_text.text - == "TestTestTestTestTestTestTestTestTestTestTest" + generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" ) - assert generations[0].request_id == default_bloom_batch.requests[0].id + 
assert generated_texts[0].request == default_bloom_batch.requests[0] assert ( - generations[0].generated_text.generated_tokens + generated_texts[0].generated_tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -158,19 +152,17 @@ def test_causal_lm_generate_token_completion_multi( for i in range( default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1 ): - generations, next_batch = default_bloom.generate_token(next_batch) - assert len(generations) == len(default_multi_requests_bloom_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_bloom.generate_token(next_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) assert next_batch is not None - assert len(generations) == 2 - assert generations[1].generated_text.text == "TestTestTestTestTestTest" + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "TestTestTestTestTestTest" + assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1] assert ( - generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id - ) - assert ( - generations[1].generated_text.generated_tokens + generated_texts[0].generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) @@ -179,22 +171,19 @@ def test_causal_lm_generate_token_completion_multi( - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1 ): - generations, next_batch = default_bloom.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_bloom.generate_token(next_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) assert next_batch is None - assert len(generations) == 1 + assert len(generated_texts) == 1 assert ( - generations[0].generated_text.text - == "TestTestTestTestTestTestTestTestTestTestTest" + generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" ) + assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0] assert ( - generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id - ) - assert ( - generations[0].generated_text.generated_tokens + generated_texts[0].generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -254,19 +243,17 @@ def test_batch_concatenate( for _ in range( default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2 ): - generations, next_batch = default_bloom.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_bloom.generate_token(next_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) assert next_batch is not None - assert len(generations) == 3 - assert generations[2].generated_text.text == "TestTestTestTestTestTest" + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "TestTestTestTestTestTest" + assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1] assert ( - generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id - ) - assert ( - generations[2].generated_text.generated_tokens + generated_texts[0].generated_tokens == 
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) @@ -275,20 +262,19 @@ def test_batch_concatenate( - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2 ): - generations, next_batch = default_bloom.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_bloom.generate_token(next_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) assert next_batch is not None - assert len(generations) == 2 + assert len(generated_texts) == 1 assert ( - generations[0].generated_text.text - == "TestTestTestTestTestTestTestTestTestTestTest" + generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" ) - assert generations[0].request_id == default_bloom_batch.requests[0].id + assert generated_texts[0].request == default_bloom_batch.requests[0] assert ( - generations[0].generated_text.generated_tokens + generated_texts[0].generated_tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -298,21 +284,18 @@ def test_batch_concatenate( - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 4 ): - generations, next_batch = default_bloom.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_bloom.generate_token(next_batch) + generated_texts, next_batch = default_bloom.generate_token(next_batch) assert next_batch is None - assert len(generations) == 1 + assert len(generated_texts) == 1 assert ( - generations[0].generated_text.text - == "TestTestTestTestTestTestTestTestTestTestTest" + generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" ) + assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0] assert ( - generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id - ) - assert ( - generations[0].generated_text.generated_tokens + generated_texts[0].generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens ) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index f9762b30..bedb65ba 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -88,9 +88,11 @@ def test_causal_lm_batch_type(default_causal_lm): def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): sequence_length = len(default_causal_lm_batch.all_input_ids[0]) - generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch) + generated_texts, next_batch = default_causal_lm.generate_token( + default_causal_lm_batch + ) - assert len(generations) == len(next_batch) + assert generated_texts == [] assert isinstance(next_batch, CausalLMBatch) assert len(next_batch.all_input_ids) == next_batch.size @@ -119,11 +121,6 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): assert all( [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values] ) - assert all([generation.generated_text is None for generation in generations]) - assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([generation.token_id.item() == 13 for generation in generations]) - assert all([generation.token_text == "." 
for generation in generations]) - assert generations[0].request_id == 0 def test_causal_lm_generate_token_completion( @@ -131,17 +128,18 @@ def test_causal_lm_generate_token_completion( ): next_batch = default_causal_lm_batch for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1): - generations, next_batch = default_causal_lm.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_causal_lm.generate_token(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is None - assert len(generations) == 1 - assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." - assert generations[0].request_id == default_causal_lm_batch.requests[0].id + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." + assert generated_texts[0].request == default_causal_lm_batch.requests[0] + assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs) assert ( - generations[0].generated_text.generated_tokens + generated_texts[0].generated_tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -154,20 +152,19 @@ def test_causal_lm_generate_token_completion_multi( for i in range( default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1 ): - generations, next_batch = default_causal_lm.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_causal_lm.generate_token(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is not None - assert len(generations) == 2 - assert generations[1].generated_text.text == "Test.java:784)" + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "Test.java:784)" assert ( - generations[1].request_id - == default_multi_requests_causal_lm_batch.requests[1].id + generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] ) assert ( - generations[1].generated_text.generated_tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) @@ -176,20 +173,19 @@ def test_causal_lm_generate_token_completion_multi( - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1 ): - generations, next_batch = default_causal_lm.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_causal_lm.generate_token(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is None - assert len(generations) == 1 - assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." 
assert ( - generations[0].request_id - == default_multi_requests_causal_lm_batch.requests[0].id + generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] ) assert ( - generations[0].generated_text.generated_tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -248,20 +244,19 @@ def test_batch_concatenate( for _ in range( default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 ): - generations, next_batch = default_causal_lm.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_causal_lm.generate_token(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is not None - assert len(generations) == 3 - assert generations[2].generated_text.text == "Test.java:784)" + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "Test.java:784)" assert ( - generations[2].request_id - == default_multi_requests_causal_lm_batch.requests[1].id + generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] ) assert ( - generations[2].generated_text.generated_tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) @@ -270,17 +265,17 @@ def test_batch_concatenate( - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 ): - generations, next_batch = default_causal_lm.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_causal_lm.generate_token(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is not None - assert len(generations) == 2 - assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." - assert generations[0].request_id == default_causal_lm_batch.requests[0].id + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." + assert generated_texts[0].request == default_causal_lm_batch.requests[0] assert ( - generations[0].generated_text.generated_tokens + generated_texts[0].generated_tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -290,19 +285,18 @@ def test_batch_concatenate( - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 4 ): - generations, next_batch = default_causal_lm.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_causal_lm.generate_token(next_batch) + generated_texts, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is None - assert len(generations) == 1 - assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." 
assert ( - generations[0].request_id - == default_multi_requests_causal_lm_batch.requests[0].id + generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] ) assert ( - generations[0].generated_text.generated_tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index 1b69477d..acebec04 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -50,17 +50,18 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat next_batch = batch for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): - generations, next_batch = default_santacoder.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_santacoder.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_santacoder.generate_token(next_batch) + generated_texts, next_batch = default_santacoder.generate_token(next_batch) assert next_batch is None - assert len(generations) == 1 - assert generations[0].generated_text.text == "def test_get_all_users_with_" - assert generations[0].request_id == batch.requests[0].id + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "def test_get_all_users_with_" + assert generated_texts[0].request == batch.requests[0] + assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs) assert ( - generations[0].generated_text.generated_tokens + generated_texts[0].generated_tokens == batch.stopping_criterias[0].max_new_tokens ) @@ -75,19 +76,20 @@ def test_fim_santacoder_generate_token_completion( next_batch = batch for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): - generations, next_batch = default_santacoder.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_santacoder.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_santacoder.generate_token(next_batch) + generated_texts, next_batch = default_santacoder.generate_token(next_batch) assert next_batch is None - assert len(generations) == 1 + assert len(generated_texts) == 1 assert ( - generations[0].generated_text.text + generated_texts[0].output_text == """defworldineProperty(exports, "__esModule", { value""" ) - assert generations[0].request_id == batch.requests[0].id + assert generated_texts[0].request == batch.requests[0] + assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs) assert ( - generations[0].generated_text.generated_tokens + generated_texts[0].generated_tokens == batch.stopping_criterias[0].max_new_tokens ) diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 22c6ac9c..de1a4829 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -99,11 +99,11 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm): def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch): sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) - generations, next_batch = default_seq2seq_lm.generate_token( + generated_texts, next_batch = default_seq2seq_lm.generate_token( default_seq2seq_lm_batch ) - assert len(generations) == len(next_batch) + assert generated_texts == [] assert isinstance(next_batch, Seq2SeqLMBatch) assert 
torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids) @@ -145,11 +145,6 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) for p in next_batch.past_key_values ] ) - assert all([generation.generated_text is None for generation in generations]) - assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([generation.token_id.item() == 259 for generation in generations]) - assert all([generation.token_text == "" for generation in generations]) - assert generations[0].request_id == 0 def test_seq2seq_lm_generate_token_completion( @@ -157,16 +152,16 @@ def test_seq2seq_lm_generate_token_completion( ): next_batch = default_seq2seq_lm_batch for _ in range(6): - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None - assert len(generations) == 1 - assert generations[0].generated_text.text == "a few weeks" - assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id - assert generations[0].generated_text.generated_tokens == 7 + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "a few weeks" + assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0] + assert generated_texts[0].generated_tokens == 7 def test_seq2seq_lm_generate_token_completion_multi( @@ -175,33 +170,33 @@ def test_seq2seq_lm_generate_token_completion_multi( next_batch = default_multi_requests_seq2seq_lm_batch for i in range(4): - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None - assert len(generations) == 2 - assert generations[1].generated_text.text == "a few " + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "a few " assert ( - generations[1].request_id - == default_multi_requests_seq2seq_lm_batch.requests[1].id + generated_texts[0].request + == default_multi_requests_seq2seq_lm_batch.requests[1] ) - assert generations[1].generated_text.generated_tokens == 5 + assert generated_texts[0].generated_tokens == 5 - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None - assert len(generations) == 1 - assert generations[0].generated_text.text == "a few weeks" + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "a few weeks" assert ( - generations[0].request_id - == default_multi_requests_seq2seq_lm_batch.requests[0].id + generated_texts[0].request + == default_multi_requests_seq2seq_lm_batch.requests[0] ) - assert generations[0].generated_text.generated_tokens == 7 + assert 
generated_texts[0].generated_tokens == 7 def test_batch_concatenate( @@ -296,35 +291,35 @@ def test_batch_concatenate( ) for _ in range(3): - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) - assert len(generations) == len(next_batch) + generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) + assert generated_texts == [] - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None - assert len(generations) == 3 - assert generations[2].generated_text.text == "a few " + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "a few " assert ( - generations[2].request_id - == default_multi_requests_seq2seq_lm_batch.requests[1].id + generated_texts[0].request + == default_multi_requests_seq2seq_lm_batch.requests[1] ) - assert generations[2].generated_text.generated_tokens == 5 + assert generated_texts[0].generated_tokens == 5 - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None - assert len(generations) == 2 - assert generations[0].generated_text.text == "a few weeks" - assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id - assert generations[0].generated_text.generated_tokens == 7 + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "a few weeks" + assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0] + assert generated_texts[0].generated_tokens == 7 - generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None - assert len(generations) == 1 - assert generations[0].generated_text.text == "a few weeks" + assert len(generated_texts) == 1 + assert generated_texts[0].output_text == "a few weeks" assert ( - generations[0].request_id - == default_multi_requests_seq2seq_lm_batch.requests[0].id + generated_texts[0].request + == default_multi_requests_seq2seq_lm_batch.requests[0] ) - assert generations[0].generated_text.generated_tokens == 7 + assert generated_texts[0].generated_tokens == 7 diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index a8fc23fe..ccd4c3ba 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize from typing import Optional, Tuple, List, Type from text_generation.models import Model -from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText +from text_generation.models.types import GeneratedText, Batch from text_generation.pb import generate_pb2 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -23,6 +23,7 @@ class CausalLMBatch(Batch): # All tokens all_input_ids: List[torch.Tensor] + all_logprobs: List[Optional[torch.Tensor]] # Lengths of all generations present in the batch input_lengths: List[int] @@ -56,6 +57,7 @@ class CausalLMBatch(Batch): next_token_choosers = [] stopping_criterias = [] input_lengths = [] + all_logprobs = [] # Parse batch for r in pb.requests: @@ -65,6 +67,7 @@ class CausalLMBatch(Batch): stopping_criterias.append( StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) + all_logprobs.append(None) 
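The all_logprobs entry added here is stored as None per request when the batch is built and is only materialized on the first forward pass (further below in this file), after which one scalar is appended per generated token. A minimal sketch of the scoring that materialization performs, assuming logprobs is a [seq_len, vocab] log-softmax matrix whose row i is the distribution predicted after reading token i; the function name is illustrative:

import math
import torch


def prompt_logprobs(logprobs: torch.Tensor, input_ids: torch.Tensor) -> list:
    """Per-token logprobs for a sequence, with NaN for the first token.

    logprobs: [seq_len, vocab] log-softmax scores, where row i is the
    distribution predicted after reading token i.
    input_ids: [seq_len] token ids of the same sequence.
    Token i (for i >= 1) is scored by row i - 1; the first token has no
    conditioning distribution, hence the NaN placeholder.
    """
    picked = logprobs[:-1].gather(1, input_ids[1:].unsqueeze(1)).squeeze(1)
    return [math.nan] + picked.tolist()

The NaN placeholder is what lets the tests earlier in this diff assert len(tokens) == len(logprobs) even though the first prompt token has no score.
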
pad_to_multiple_of = 8 if device.type == "cuda" else None tokenized_inputs = tokenizer( @@ -86,6 +89,7 @@ class CausalLMBatch(Batch): position_ids=position_ids, past_key_values=None, all_input_ids=all_input_ids, + all_logprobs=all_logprobs, input_lengths=input_lengths, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, @@ -103,6 +107,7 @@ class CausalLMBatch(Batch): requests = [] input_lengths = [] all_input_ids = [] + all_logprobs = [] next_token_choosers = [] stopping_criterias = [] @@ -119,6 +124,7 @@ class CausalLMBatch(Batch): requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) all_input_ids.extend(batch.all_input_ids) + all_logprobs.extend(batch.all_logprobs) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -219,6 +225,7 @@ class CausalLMBatch(Batch): position_ids=position_ids, past_key_values=past_key_values, all_input_ids=all_input_ids, + all_logprobs=all_logprobs, input_lengths=input_lengths, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, @@ -227,9 +234,6 @@ class CausalLMBatch(Batch): keys_head_dim_last=batches[0].keys_head_dim_last, ) - def __len__(self): - return len(self.requests) - class CausalLM(Model): def __init__(self, model_name: str, quantize=False): @@ -285,7 +289,7 @@ class CausalLM(Model): def generate_token( self, batch: CausalLMBatch - ) -> Tuple[List[Generation], Optional[CausalLMBatch]]: + ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]: # For some reason, inference_mode does not work well with GLOO which we use on CPU context_manager = ( torch.no_grad if self.device.type == "cpu" else torch.inference_mode @@ -305,13 +309,14 @@ class CausalLM(Model): next_batch_input_lengths = [] next_batch_input_ids = [] next_batch_all_input_ids = [] + next_batch_all_logprobs = [] # Metadata next_batch_size = 0 next_batch_max_sequence_length = 0 - # Results - generations: List[Generation] = [] + # Finished requests + generated_texts: List[GeneratedText] = [] # Zipped iterator iterator = zip( @@ -321,6 +326,7 @@ class CausalLM(Model): batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, + batch.all_logprobs, ) # For each member of the batch @@ -331,36 +337,44 @@ class CausalLM(Model): next_token_chooser, stopping_criteria, all_input_ids, + all_logprobs, ) in enumerate(iterator): # Select next token tokens, logprobs = next_token_chooser(all_input_ids, logits) - next_token_id = tokens[-1].view(1, 1) + next_token = tokens[-1].view(1, 1) # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token_id]) + all_input_ids = torch.cat([all_input_ids, next_token]) new_input_length = input_length + 1 - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text = self.tokenizer.decode( - next_token_id_squeezed, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) + if all_logprobs is None: + # logprobs of all prompt tokens (except the first one) and the generated token + all_logprobs = logprobs.gather(1, all_input_ids[1:]) + else: + # logprob of the generated token + next_token_logprob = logprobs[-1, next_token] + all_logprobs = torch.cat([all_logprobs, next_token_logprob]) # Evaluate stopping criteria stop, reason = stopping_criteria( - next_token_id_squeezed, - next_token_text, + next_token.squeeze(), + self.tokenizer.decode( + next_token.squeeze(), clean_up_tokenization_spaces=False + ), ) - if 
stop: # Decode generated tokens generated_text = self.decode( all_input_ids[-stopping_criteria.current_tokens :, 0] ) output_text = request.inputs + generated_text + # Slice with input_length to remove padding + token_ids = all_input_ids[-new_input_length:] + tokens = self.tokenizer.batch_decode(token_ids) + # Add NaN for the first prompt token + logprobs = [float("nan")] + all_logprobs[-input_length:].squeeze( + 1 + ).tolist() # Get seed if isinstance(next_token_chooser.choice, Sampling): @@ -368,58 +382,39 @@ class CausalLM(Model): else: seed = None - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed + # Add to the list of finished generations with the original request + generated_texts.append( + GeneratedText( + request=request, + output_text=output_text, + generated_tokens=stopping_criteria.current_tokens, + tokens=tokens, + token_ids=token_ids.squeeze(1).tolist(), + logprobs=logprobs, + reason=reason, + seed=seed, + ) ) + # add to the next batch else: - # Keep request in the batch - generated_text = None next_batch_keep_indices.append(i) - next_batch_input_ids.append(next_token_id) + next_batch_input_ids.append(next_token) next_batch_all_input_ids.append(all_input_ids) + next_batch_all_logprobs.append(all_logprobs) next_batch_size += 1 next_batch_input_lengths.append(new_input_length) next_batch_max_sequence_length = max( next_batch_max_sequence_length, new_input_length ) - # Prefill - if stopping_criteria.current_tokens == 1: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + logprobs.gather( - 1, all_input_ids[1:] - ).squeeze(1)[-new_input_length:-1].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = PrefillTokens( - prefill_token_ids, prefill_logprobs, prefill_texts - ) - else: - prefill_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - next_token_id_squeezed, - next_token_logprob, - next_token_text, - generated_text, - ) - - generations.append(generation) - # We finished all generations in the batch; there is no next batch if not next_batch_keep_indices: - return generations, None + return generated_texts, None next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0) # If we finished at least one generation, we need to evict the indices of the generations that finished # from the values of the next batch - if len(next_batch_keep_indices) != len(batch): + if generated_texts: # Apply indices to attention mask, past key values and other items that need to be cached next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] next_batch_position_ids = batch.position_ids[next_batch_keep_indices] @@ -466,6 +461,7 @@ class CausalLM(Model): position_ids=next_batch_position_ids, past_key_values=next_batch_past_key_values, all_input_ids=next_batch_all_input_ids, + all_logprobs=next_batch_all_logprobs, input_lengths=next_batch_input_lengths, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, @@ -473,4 +469,4 @@ class CausalLM(Model): max_sequence_length=next_batch_max_sequence_length, keys_head_dim_last=batch.keys_head_dim_last, ) - return generations, next_batch + return generated_texts, next_batch diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 
41657af2..f965ea88 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokeniz from typing import Optional, Tuple, List, Type from text_generation.models import Model -from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens +from text_generation.models.types import GeneratedText, Batch from text_generation.pb import generate_pb2 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -30,6 +30,7 @@ class Seq2SeqLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] decoder_input_lengths: List[int] + decoder_logprobs: List[Optional[torch.Tensor]] # Generation helpers next_token_choosers: List[NextTokenChooser] @@ -63,6 +64,7 @@ class Seq2SeqLMBatch(Batch): decoder_input_ids = [] decoder_input_lengths = [] + decoder_logprobs = [] # Parse batch for r in pb.requests: @@ -75,6 +77,7 @@ class Seq2SeqLMBatch(Batch): stopping_criterias.append( StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) + decoder_logprobs.append(None) # Tokenize batch pad_to_multiple_of = 8 if device.type == "cuda" else None @@ -99,6 +102,7 @@ class Seq2SeqLMBatch(Batch): past_key_values=None, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, + decoder_logprobs=decoder_logprobs, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=len(pb.requests), @@ -121,6 +125,7 @@ class Seq2SeqLMBatch(Batch): requests = [] input_lengths = [] decoder_input_lengths = [] + decoder_logprobs = [] next_token_choosers = [] stopping_criterias = [] @@ -141,6 +146,7 @@ class Seq2SeqLMBatch(Batch): requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths) + decoder_logprobs.extend(batch.decoder_logprobs) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -277,6 +283,7 @@ class Seq2SeqLMBatch(Batch): past_key_values=past_key_values, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, + decoder_logprobs=decoder_logprobs, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=total_batch_size, @@ -284,9 +291,6 @@ class Seq2SeqLMBatch(Batch): max_decoder_input_length=max_decoder_input_length, ) - def __len__(self): - return len(self.requests) - class Seq2SeqLM(Model): def __init__(self, model_name: str, quantize=False): @@ -360,7 +364,7 @@ class Seq2SeqLM(Model): def generate_token( self, batch: Seq2SeqLMBatch - ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]: + ) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]: # For some reason, inference_mode does not work well with GLOO which we use on CPU context_manager = ( torch.no_grad if self.device.type == "cpu" else torch.inference_mode @@ -382,6 +386,7 @@ class Seq2SeqLM(Model): next_batch_input_lengths = [] next_batch_decoder_input_ids = [] next_batch_decoder_input_lengths = [] + next_batch_decoder_logprobs = [] # Metadata next_batch_size = 0 @@ -389,13 +394,14 @@ class Seq2SeqLM(Model): next_batch_max_decoder_input_length = 0 # Finished requests - generations: List[Generation] = [] + generated_texts: List[GeneratedText] = [] # Zipped iterator iterator = zip( batch.requests, batch.input_lengths, batch.decoder_input_lengths, + batch.decoder_logprobs, logits, batch.next_token_choosers, batch.stopping_criterias, 
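When cached batches are concatenated, the list-typed per-request state (requests, lengths, and now decoder_logprobs) is simply extended in order, while the padded tensors are rebuilt separately. A toy sketch of that list bookkeeping, using an illustrative MiniBatchState stand-in rather than the real Seq2SeqLMBatch:

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class MiniBatchState:
    """Toy stand-in for the list-typed state merged by concatenate above.

    The real Seq2SeqLMBatch additionally re-pads input_ids, attention masks
    and past key values when merging.
    """
    requests: List[object]
    decoder_input_lengths: List[int]
    decoder_logprobs: List[Optional[torch.Tensor]]

    @classmethod
    def concatenate(cls, batches: List["MiniBatchState"]) -> "MiniBatchState":
        merged = cls(requests=[], decoder_input_lengths=[], decoder_logprobs=[])
        for batch in batches:
            merged.requests.extend(batch.requests)
            merged.decoder_input_lengths.extend(batch.decoder_input_lengths)
            merged.decoder_logprobs.extend(batch.decoder_logprobs)
        return merged

Each request's logprob accumulator can have a different length (one entry per token generated so far), which is why it lives in a plain Python list alongside the padded batch tensors rather than inside them.
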
@@ -408,6 +414,7 @@ class Seq2SeqLM(Model): request, input_length, decoder_input_length, + decoder_logprobs, logits, next_token_chooser, stopping_criteria, @@ -415,28 +422,35 @@ class Seq2SeqLM(Model): decoder_input_ids, ) in enumerate(iterator): # Select next token - next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits) + next_token, logprobs = next_token_chooser(decoder_input_ids, logits) # Append next token to decoder tokens - decoder_input_ids = torch.cat([decoder_input_ids, next_token_id]) + decoder_input_ids = torch.cat([decoder_input_ids, next_token]) new_decoder_input_length = decoder_input_length + 1 - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text = self.tokenizer.decode( - next_token_id_squeezed, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) + next_token_logprob = logprobs[-1, next_token] + if decoder_logprobs is None: + decoder_logprobs = next_token_logprob + else: + decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob]) # Evaluate stopping criteria - stop, reason = stopping_criteria(next_token_id, next_token_text) - + stop, reason = stopping_criteria( + next_token.squeeze(), + self.tokenizer.decode( + next_token.squeeze(), clean_up_tokenization_spaces=False + ), + ) if stop: # Slice with decoder_input_length to remove padding # Decode all tokens - output_text = self.decode(decoder_input_ids[-new_decoder_input_length:]) + token_ids = decoder_input_ids[-new_decoder_input_length:] + output_text = self.decode(token_ids) + tokens = self.tokenizer.batch_decode(token_ids) + # Add NaN for the bos token + logprobs = [float("nan")] + decoder_logprobs[ + -decoder_input_length: + ].tolist() # Get seed if isinstance(next_token_chooser.choice, Sampling): @@ -444,17 +458,27 @@ class Seq2SeqLM(Model): else: seed = None - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed + # Add to the list of finished generations with the original request + generated_texts.append( + GeneratedText( + request=request, + output_text=output_text, + generated_tokens=stopping_criteria.current_tokens, + tokens=tokens, + token_ids=token_ids.tolist(), + logprobs=logprobs, + reason=reason, + seed=seed, + ) ) + # add to the next batch else: - # Keep request in the batch - generated_text = None next_batch_keep_indices.append(i) next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0)) next_batch_size += 1 next_batch_input_lengths.append(input_length) next_batch_decoder_input_lengths.append(new_decoder_input_length) + next_batch_decoder_logprobs.append(decoder_logprobs) next_batch_max_input_length = max( next_batch_max_input_length, input_length ) @@ -462,39 +486,14 @@ class Seq2SeqLM(Model): next_batch_max_decoder_input_length, new_decoder_input_length ) - # Prefill - if stopping_criteria.current_tokens == 1: - prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = PrefillTokens( - prefill_token_ids, [float("nan")], prefill_texts - ) - else: - prefill_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - next_token_id_squeezed, - next_token_logprob, - next_token_text, - generated_text, - ) - - generations.append(generation) - # We finished all generations in the batch; there is no next batch if not next_batch_keep_indices: - return 
generations, None + return generated_texts, None next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids) # If we finished at least one generation, we need to evict the indices of the generations that finished # from the values of the next batch - if len(next_batch_keep_indices) != len(batch): + if generated_texts: # Apply indices to attention mask, past key values and other items that need to be cached next_batch_input_ids = batch.input_ids[next_batch_keep_indices] next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] @@ -552,10 +551,11 @@ class Seq2SeqLM(Model): past_key_values=next_batch_past_key_values, input_lengths=next_batch_input_lengths, decoder_input_lengths=next_batch_decoder_input_lengths, + decoder_logprobs=next_batch_decoder_logprobs, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, size=next_batch_size, max_input_length=next_batch_max_input_length, max_decoder_input_length=next_batch_max_decoder_input_length, ) - return generations, next_batch + return generated_texts, next_batch diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index 30cd716a..4ee3cb32 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -29,61 +29,26 @@ class Batch(ABC): def concatenate(cls, batches: List["Batch"]) -> "Batch": raise NotImplementedError - @abstractmethod - def __len__(self): - raise NotImplementedError - @dataclass class GeneratedText: - text: str + request: generate_pb2.Request + output_text: str generated_tokens: int - finish_reason: str + tokens: List[str] + token_ids: List[int] + logprobs: List[float] + reason: str seed: Optional[int] def to_pb(self) -> generate_pb2.GeneratedText: return generate_pb2.GeneratedText( - text=self.text, + request=self.request, + output_text=self.output_text, generated_tokens=self.generated_tokens, - finish_reason=self.finish_reason, + tokens=self.tokens, + token_ids=self.token_ids, + logprobs=self.logprobs, + finish_reason=self.reason, seed=self.seed, ) - - -@dataclass -class PrefillTokens: - token_ids: List[int] - logprobs: List[float] - texts: List[str] - - def to_pb(self) -> generate_pb2.PrefillTokens: - return generate_pb2.PrefillTokens( - ids=self.token_ids, logprobs=self.logprobs, texts=self.texts - ) - - def __len__(self): - return len(self.token_ids) - - -@dataclass -class Generation: - request_id: int - prefill_tokens: Optional[PrefillTokens] - token_id: int - token_logprob: float - token_text: str - generated_text: Optional[GeneratedText] - - def to_pb(self) -> generate_pb2.Generation: - return generate_pb2.Generation( - request_id=self.request_id, - prefill_tokens=self.prefill_tokens.to_pb() - if self.prefill_tokens is not None - else None, - token_id=self.token_id, - token_logprob=self.token_logprob, - token_text=self.token_text, - generated_text=self.generated_text.to_pb() - if self.generated_text is not None - else None, - ) diff --git a/server/text_generation/server.py b/server/text_generation/server.py index a2bad8a7..5fd3072e 100644 --- a/server/text_generation/server.py +++ b/server/text_generation/server.py @@ -27,20 +27,22 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): self.cache.clear() return generate_pb2.ClearCacheResponse() - async def Prefill(self, request, context): + async def Generate(self, request, context): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.device ) - generations, 
next_batch = self.model.generate_token(batch) + generated_texts, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) - return generate_pb2.PrefillResponse( - generations=[generation.to_pb() for generation in generations], + return generate_pb2.GenerateResponse( + generated_texts=[ + generated_text.to_pb() for generated_text in generated_texts + ], batch=next_batch.to_pb() if next_batch else None, ) - async def Decode(self, request, context): + async def GenerateWithCache(self, request, context): if len(request.batches) == 0: raise ValueError("Must provide at least one batch") @@ -56,11 +58,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): else: batch = batches[0] - generations, next_batch = self.model.generate_token(batch) + generated_texts, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) - return generate_pb2.DecodeResponse( - generations=[generation.to_pb() for generation in generations], + return generate_pb2.GenerateWithCacheResponse( + generated_texts=[ + generated_text.to_pb() for generated_text in generated_texts + ], batch=next_batch.to_pb() if next_batch else None, )
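
From the caller's perspective the two handlers above form a simple loop: Generate runs a fresh batch, and GenerateWithCache keeps feeding back the cached batch descriptor until a response no longer carries one. A sketch of that loop, assuming a grpc.aio stub generated from generate.proto (the Rust router side of this diff does the equivalent); the helper name and stub class are illustrative:

from text_generation.pb import generate_pb2


async def generate_until_finished(stub, batch_pb):
    """Drive Generate / GenerateWithCache until the batch is fully decoded.

    stub: a grpc.aio TextGenerationServiceStub (assumed, for illustration).
    batch_pb: a generate_pb2.Batch describing the new requests.
    Returns every GeneratedText message emitted along the way.
    """
    finished = []
    response = await stub.Generate(generate_pb2.GenerateRequest(batch=batch_pb))
    finished.extend(response.generated_texts)
    # The optional `batch` field is only set while some requests are unfinished;
    # the router may also pass several cached batches here after concatenation.
    while response.HasField("batch"):
        response = await stub.GenerateWithCache(
            generate_pb2.GenerateWithCacheRequest(batches=[response.batch])
        )
        finished.extend(response.generated_texts)
    return finished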