From 73c39032142f7bc7243402e7ec9fa9f25e4bbcb3 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 11 Jun 2024 12:38:07 +0200
Subject: [PATCH] FlashCausalLM implem

---
 proto/v3/generate.proto                        |   9 +-
 router/client/src/v3/client.rs                 |   4 +-
 router/client/src/v3/mod.rs                    |   2 +-
 router/client/src/v3/sharded_client.rs         |   4 +-
 router/src/infer/v3/scheduler.rs               | 351 ++++++++++++------
 .../models/flash_causal_lm.py                  |  23 +-
 server/text_generation_server/models/types.py  |   2 +-
 server/text_generation_server/server.py        |   6 +-
 8 files changed, 273 insertions(+), 128 deletions(-)

diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index 3c9b1d71..8138e4fb 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -224,11 +224,18 @@ message FilterBatchRequest {
     repeated uint64 terminated_request_ids = 3;
 }
 
+message TerminatedGeneration {
+    // Request ID
+    uint64 id = 1;
+    // Generated text
+    GeneratedText generated_text = 2;
+}
+
 message FilterBatchResponse {
     /// Filtered Batch (cached)
     CachedBatch batch = 1;
     /// Terminated generations
-    repeated GeneratedText terminated_generations = 2;
+    repeated TerminatedGeneration terminated_generations = 2;
 }
 
diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs
index 90f43270..1f8070ca 100644
--- a/router/client/src/v3/client.rs
+++ b/router/client/src/v3/client.rs
@@ -92,7 +92,7 @@ impl Client {
         batch_id: u64,
         kept_requests: Vec<KeptRequest>,
         terminated_request_ids: Vec<u64>,
-    ) -> Result<Option<CachedBatch>> {
+    ) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
             kept_requests,
@@ -100,7 +100,7 @@ impl Client {
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
-        Ok(filtered_batch.batch)
+        Ok((filtered_batch.batch, filtered_batch.terminated_generations))
     }
 
     /// Warmup on a max size batch
diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs
index ea7486ee..9df17c50 100644
--- a/router/client/src/v3/mod.rs
+++ b/router/client/src/v3/mod.rs
@@ -8,6 +8,6 @@ pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, TerminatedGeneration, Tokens,
 };
 pub use sharded_client::ShardedClient;
diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs
index e1b35a21..3f11e101 100644
--- a/router/client/src/v3/sharded_client.rs
+++ b/router/client/src/v3/sharded_client.rs
@@ -2,7 +2,7 @@
 use crate::{v3, Health, ShardInfo};
 use crate::{ClientError, Result};
 
-use crate::v3::{Chunk, InfoResponse, Input};
+use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration};
 use async_trait::async_trait;
 use futures::future::join_all;
 use tonic::transport::Uri;
@@ -86,7 +86,7 @@ impl ShardedClient {
         batch_id: u64,
         kept_requests: Vec<KeptRequest>,
         terminated_request_ids: Vec<u64>,
-    ) -> Result<Option<CachedBatch>> {
+    ) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs
index 5fb9e11d..ee93c20a 100644
--- a/router/src/infer/v3/scheduler.rs
+++ b/router/src/infer/v3/scheduler.rs
@@ -5,12 +5,14 @@ use crate::infer::{
 };
 use crate::validation::ValidGenerateRequest;
 use crate::{FinishReason, PrefillToken, Token};
-use nohash_hasher::IntMap;
+use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use text_generation_client::v3::{Batch, CachedBatch, Generation, KeptRequest, ShardedClient};
+use text_generation_client::v3::{
+    Batch, CachedBatch, Generation, KeptRequest, ShardedClient, TerminatedGeneration,
+};
 use text_generation_client::ClientError;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
@@ -243,11 +245,38 @@ async fn prefill(
     generation_health.store(true, Ordering::SeqCst);
 
     let start_filtering_time = Instant::now();
-    // Send generated tokens and filter stopped entries
-    filter_send_generations(generations, entries);
+    // Filter and send finished generations
+    let filtered_stream_responses = filter_send_ended_generations(generations, entries);
+
+    // Iterate on intermediate generations
+    for (id, stream_responses) in filtered_stream_responses {
+        // Get entry
+        let entry = entries
+            .get_mut(&id)
+            .expect("ID not found in entries. This is a bug.");
+
+        // Send intermediate responses
+        if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| {
+            tracing::error!("Entry response channel error.");
+            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+            err
+        }) {
+            // Sending failed, remove entry
+            entries
+                .remove(&id)
+                .expect("ID not found in entries. This is a bug.");
+        }
+    }
 
     // Filter next batch and remove requests that were stopped
-    let next_batch = filter_batch(client, next_batch, entries, false).await;
+    let next_batch = match next_batch {
+        Some(batch) if batch.size as usize != entries.len() => {
+            let (filtered_batch, _) =
+                filter_batch(client, batch, entries, &IntMap::default()).await;
+            filtered_batch
+        }
+        batch => batch,
+    };
 
     metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
     metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
@@ -285,13 +314,32 @@ async fn decode(
     generation_health.store(true, Ordering::SeqCst);
 
     let start_filtering_time = Instant::now();
-    // Send generated tokens and filter stopped entries
-    filter_send_generations(generations, entries);
-    let updated = filter_update_allocations(entries).await;
+    // Filter and send finished generations
+    let mut filtered_stream_responses = filter_send_ended_generations(generations, entries);
+    // Send `InferStreamResponse::Intermediate` messages for entries that don't need to be
+    // re-allocated.
+    // Allocate new blocks for entries that go over their allocation.
+    // Filter entries that couldn't be re-allocated and add them to `terminated_entries`.
+    let (force_update, terminated_entries) =
+        filter_send_update_allocations(entries, &mut filtered_stream_responses);
 
-    // Filter next batch and remove requests that were stopped
-    let next_batch = filter_batch(client, next_batch, entries, updated).await;
+    let next_batch = match next_batch {
+        // Run only on re-allocation or if entries were filtered
+        Some(batch) if batch.size as usize != entries.len() || force_update => {
+            // Filter next batch: remove requests that were stopped and update blocks/slots
+            let (filtered_batch, terminated_generations) =
+                filter_batch(client, batch, entries, &terminated_entries).await;
+            send_terminated_generations(
+                terminated_generations,
+                terminated_entries,
+                filtered_stream_responses,
+            );
+
+            filtered_batch
+        }
+        batch => batch,
+    };
 
     if let Some(concat_duration) = timings.concat {
         metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
     }
@@ -320,27 +368,20 @@
 #[instrument(skip_all)]
 async fn filter_batch(
     client: &mut ShardedClient,
-    next_batch: Option<CachedBatch>,
+    batch: CachedBatch,
     entries: &IntMap<u64, Entry>,
-    force_update: bool,
-) -> Option<CachedBatch> {
-    let batch = next_batch?;
-
-    // No need to filter
-    if batch.size as usize == entries.len() && !force_update {
-        return Some(batch);
-    }
-
+    terminated_entries: &IntMap<u64, Entry>,
+) -> (Option<CachedBatch>, Vec<TerminatedGeneration>) {
     let id = batch.id;
-    if entries.is_empty() {
+    if entries.is_empty() && terminated_entries.is_empty() {
         // All requests have been filtered out
         // Next batch is now empty
         // Clear it from the Python shards cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
         client.clear_cache(Some(id)).await.unwrap();
-        None
+        Default::default()
     } else {
-        // Filter Python shard cache
+        // Collect new blocks/slots
         let updated_requests = entries
             .iter()
             .map(|(request_id, entry)| {
@@ -348,7 +389,7 @@
                     .block_allocation
                     .as_ref()
                     .map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
-                    .unwrap_or((Vec::new(), Vec::new()));
+                    .unwrap_or_default();
 
                 KeptRequest {
                     id: *request_id,
@@ -358,111 +399,207 @@
             })
             .collect();
 
+        // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
         client
-            .filter_batch(id, updated_requests, Vec::new())
+            .filter_batch(
+                id,
+                updated_requests,
+                terminated_entries.keys().map(|v| *v).collect(),
+            )
             .await
            .unwrap()
    }
 }
 
-/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
-/// and filter entries
+/// Send the final `InferStreamResponse::End` of each terminated generation to its `Entry`
 #[instrument(skip_all)]
-fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
-    generations.into_iter().for_each(|generation| {
+fn send_terminated_generations(
+    terminated_generations: Vec<TerminatedGeneration>,
+    terminated_entries: IntMap<u64, Entry>,
+    mut stream_responses: IntMap<u64, Vec<InferStreamResponse>>,
+) {
+    // Receive final message for terminated generations
+    'terminated_generations: for terminated_generation in terminated_generations {
+        let id = terminated_generation.id;
+        // Get entry for this generation
+        let entry = terminated_entries
+            .get(&id)
+            .expect("ID not found in entries. This is a bug.");
+        // Get previous `InferStreamResponse` for this generation
+        let stream_responses = stream_responses
+            .remove(&id)
+            .expect("ID not found in stream_responses. This is a bug.");
+
+        // Peekable iterator to know when we are at the last `InferStreamResponse`
+        let mut iterator = stream_responses.into_iter().peekable();
+
+        while let Some(stream_response) = iterator.next() {
+            let response = if iterator.peek().is_none() {
+                // Last `InferStreamResponse::Intermediate`
+                let (token, top_tokens) = match stream_response {
+                    InferStreamResponse::Intermediate { token, top_tokens } => (token, top_tokens),
+                    _ => unreachable!(),
+                };
+                // Modify it to be an `InferStreamResponse::End` with the new OutOfResources finish
+                // reason
+                InferStreamResponse::End {
+                    token,
+                    top_tokens,
+                    generated_text: GeneratedText::from(
+                        terminated_generation
+                            .generated_text
+                            .clone()
+                            .expect("Generated Text is None. This is a bug."),
+                    ),
+                    queued: entry.queue_time,
+                    start: entry.batch_time.unwrap(),
+                }
+            } else {
+                stream_response
+            };
+
+            // Send responses
+            if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| {
+                tracing::error!("Entry response channel error.");
+                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+                err
+            }) {
+                continue 'terminated_generations;
+            }
+        }
+    }
+}
+
+/// Send `InferStreamResponse::End` to `Infer` for finished entries and remove them from `entries`
+/// Returns filtered `InferStreamResponse::Intermediate` generations
+#[instrument(skip_all)]
+fn filter_send_ended_generations(
+    generations: Vec<Generation>,
+    entries: &mut IntMap<u64, Entry>,
+) -> IntMap<u64, Vec<InferStreamResponse>> {
+    generations.into_iter().filter_map(|generation| {
         let id = generation.request_id;
         // Get entry
         // We can `expect` here as the request id should always be in the entries
         let entry = entries
             .get_mut(&id)
             .expect("ID not found in entries. This is a bug.");
-        entry.cache_length = generation.cache_length;
 
         // Create and enter a span to link this function back to the entry
         let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
-        // Send generation responses back to the infer task
-        // If the receive an error from the Flume channel, it means that the client dropped the
-        // request and we need to stop generating hence why we unwrap_or(true)
-        let stopped = send_responses(generation, entry).map_err(|err| {
-            tracing::error!("Entry response channel error.");
+
+        // Return directly if the channel is disconnected
+        if entry.response_tx.is_closed() {
             metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-            err
-        }).unwrap_or(true);
-        if stopped {
+            // Remove from entries and filter
             entries.remove(&id).expect("ID not found in entries. This is a bug.");
+            return None;
         }
-    });
+
+        // Update cache length
+        entry.cache_length = generation.cache_length;
+
+        let (finished, stream_responses) = map_generation(generation, entry);
+        // If the generation has ended for this request, we send the responses to the channel and
+        // remove the entry to drop it and free its blocks
+        if finished {
+            let _ = send_stream_responses(stream_responses, entry).map_err(|err| {
+                tracing::error!("Entry response channel error.");
+                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+                err
+            });
+            // Remove from entries and filter
+            entries.remove(&id).expect("ID not found in entries. This is a bug.");
+            return None;
+        }
+
+        Some((id, stream_responses))
+    }).collect()
+}
+
+/// Send `InferStreamResponse` to `Infer` through an `Entry` response channel
+#[instrument(skip_all)]
+fn send_stream_responses(
+    stream_responses: Vec<InferStreamResponse>,
+    entry: &Entry,
+) -> Result<(), Box<SendError<Result<InferStreamResponse, InferError>>>> {
+    for response in stream_responses {
+        entry.response_tx.send(Ok(response))?;
+    }
+    Ok(())
 }
 
 /// Check if block allocations need to be extended
-/// If we don't have enough blocks, request will be filtered with an OutOfPages error
+/// If we don't have enough blocks, the request will be filtered and added to an IntMap of
+/// terminated entries.
+/// If at least one entry allocation was extended, we return true to force an update
 #[instrument(skip_all)]
-async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
-    let ids: Vec<u64> = entries
-        .iter()
-        .filter_map(|(id, entry)| {
-            entry
-                .block_allocation
-                .as_ref()
-                .map(|block_allocation| {
-                    if entry.cache_length > block_allocation.len() as u32 {
-                        // We need to re-allocate
-                        Some(*id)
-                    } else {
-                        None
-                    }
-                })
-                .unwrap_or(None)
-        })
-        .collect();
+fn filter_send_update_allocations(
+    entries: &mut IntMap<u64, Entry>,
+    stream_responses: &mut IntMap<u64, Vec<InferStreamResponse>>,
+) -> (bool, IntMap<u64, Entry>) {
+    let mut updated = false;
 
-    for id in ids.iter() {
-        // Get entry
-        // We can `expect` here as the request id should always be in the entries
-        let extension = {
-            let entry = entries
-                .get_mut(id)
-                .expect("ID not found in entries. This is a bug.");
-            entry
-                .block_allocation
-                .as_mut()
-                .expect("We checked that the block allocation exists above")
-                .extend()
-        };
+    let ids: Vec<u64> = entries.keys().map(|v| *v).collect();
+    let mut terminated_entries =
+        IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default());
 
-        if extension.is_err() {
-            let entry = entries
+    for id in &ids {
+        let entry = entries
+            .get_mut(id)
+            .expect("ID not found in entries. This is a bug.");
+
+        if let Some(block_allocation) = entry.block_allocation.as_mut() {
+            // Check if allocation can handle the current cache_length
+            if entry.cache_length > block_allocation.len() as u32 {
+                updated = true;
+
+                // Extend allocation by asking for a new block
+                if let Err(err) = block_allocation.extend() {
+                    // Failed to extend allocation
+                    tracing::error!("Failed to extend allocation: {err}");
+                    metrics::increment_counter!("tgi_request_failure", "err" => "out_of_resources");
+
+                    // Remove entry
+                    let mut entry = entries
+                        .remove(id)
+                        .expect("ID not found in entries. This is a bug.");
+                    // Clear block allocation
+                    entry.block_allocation = None;
+                    // Add it to terminated entries
+                    terminated_entries.insert(*id, entry);
+                    // Skip the rest of the logic to not send the intermediate messages
+                    // This entry will be terminated and we will need to edit the last intermediate
+                    // response to add the complete generated text
+                    continue;
+                }
+            }
+        }
+        let stream_response = stream_responses
+            .remove(id)
+            .expect("ID not found in stream_responses. This is a bug.");
+
+        // Send intermediate responses
+        if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| {
+            tracing::error!("Entry response channel error.");
+            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+            err
+        }) {
+            // Sending failed, remove entry
+            entries
                 .remove(id)
                 .expect("ID not found in entries. This is a bug.");
-
-            // Create and enter a span to link this function back to the entry
-            let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
-            let err = InferError::OutOfPages;
-            metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
-            tracing::error!("{err}");
-
-            // unwrap_or is valid here as we don't care if the receiver is gone.
-            entry.response_tx.send(Err(err)).unwrap_or(());
         }
     }
 
-    // If ids is not empty, we need to update
-    !ids.is_empty()
+    (updated, terminated_entries)
 }
 
-/// Send responses through the `entry` response channel
-fn send_responses(
-    generation: Generation,
-    entry: &Entry,
-) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
-    // Return directly if the channel is disconnected
-    if entry.response_tx.is_closed() {
-        metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-        return Ok(true);
-    }
-
-    let mut stopped = false;
+/// Map a `Generation` to `(bool, Vec<InferStreamResponse>)`
+fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec<InferStreamResponse>) {
+    let mut finished = false;
+    let mut stream_responses = Vec::with_capacity(16);
 
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Create Token objects
@@ -475,10 +612,8 @@
             .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
             .collect();
 
-        // Send message
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
+        // Push to stream_responses
+        stream_responses.push(InferStreamResponse::Prefill(prefill_tokens));
     }
 
     // Create last Token
@@ -520,26 +655,24 @@
         match (&generation.generated_text, iterator.peek()) {
             (Some(generated_text), None) => {
                 // Generation has ended
-                stopped = true;
-                // Send message
-                entry.response_tx.send(Ok(InferStreamResponse::End {
+                finished = true;
+                // Push to stream_responses
+                stream_responses.push(InferStreamResponse::End {
                     token,
                     top_tokens,
                     generated_text: GeneratedText::from(generated_text.clone()),
                     queued: entry.queue_time,
                     start: entry.batch_time.unwrap(),
-                }))?;
+                });
             }
             _ => {
-                // Send message
-                entry
-                    .response_tx
-                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
+                // Push to stream_responses
+                stream_responses.push(InferStreamResponse::Intermediate { token, top_tokens });
             }
         }
     }
 
-    Ok(stopped)
+    (finished, stream_responses)
 }
 
 /// Send errors to Infer for all `entries`
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 0bd9357f..e8fd8b16 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -402,10 +402,7 @@ class FlashCausalLMBatch(Batch):
         model: "FlashCausalLM",
         kept_requests: List[generate_pb2.KeptRequest],
         terminated_request_ids: List[int],
-    ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.GeneratedText]]:
-        if len(kept_requests) == 0:
-            raise ValueError("Batch must have at least one request")
-
+    ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
         terminated_generations = []
         for request_id in terminated_request_ids:
             idx = self.requests_idx_mapping[request_id]
@@ -421,13 +418,19 @@ class FlashCausalLMBatch(Batch):
                 read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                 skip_special_tokens=True,
             )
-            generated_text = GeneratedText(
-                output_text,
-                stopping_criteria.current_tokens,
-                generate_pb2.FINISH_REASON_TERMINATED,
-                seed if do_sample else None,
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed if do_sample else None,
+                    ),
+                )
             )
-            terminated_generations.append(generated_text)
+        if not kept_requests:
+            return None, terminated_generations
 
         device = self.input_ids.device
 
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index c19f804e..0b7868fc 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -33,7 +33,7 @@ class Batch(ABC):
         model,
         kept_requests: List[generate_pb2.KeptRequest],
         terminated_request_ids: List[int],
-    ) -> Tuple["Batch", List[generate_pb2.GeneratedText]]:
+    ) -> Tuple[Optional["Batch"], List[generate_pb2.TerminatedGeneration]]:
         raise NotImplementedError
 
     @classmethod
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 86df66e7..14297669 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -86,10 +86,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         filtered_batch, terminated_generations = batch.filter(
             self.model, request.kept_requests, request.terminated_request_ids
         )
-        self.cache.set(filtered_batch)
+        if filtered_batch is not None:
+            self.cache.set(filtered_batch)
 
         return generate_pb2.FilterBatchResponse(
-            batch=filtered_batch.to_pb(), terminated_generations=terminated_generations
+            batch=filtered_batch.to_pb() if filtered_batch is not None else None,
+            terminated_generations=terminated_generations,
        )
 
     async def Warmup(self, request, context):
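
Note (not part of the patch): a minimal sketch of how router-side code might consume the new two-part return value of `ShardedClient::filter_batch` introduced above. It assumes `client`, `batch_id`, `kept_requests` and `terminated_request_ids` are already in scope inside a function returning a client `Result`.

    // Illustrative only: the filtered batch and the generations that were
    // terminated server-side now come back together.
    let (filtered_batch, terminated_generations) = client
        .filter_batch(batch_id, kept_requests, terminated_request_ids)
        .await?;

    // Each TerminatedGeneration carries the request id and its final GeneratedText.
    for terminated in &terminated_generations {
        if let Some(generated_text) = &terminated.generated_text {
            tracing::info!("request {} terminated early: {}", terminated.id, generated_text.text);
        }
    }

    // `None` means every request was filtered out and the shard cache was cleared.
    if let Some(batch) = filtered_batch {
        // continue decoding with `batch`
    }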