FlashCausalLM implem

This commit is contained in:
OlivierDehaene 2024-06-11 12:38:07 +02:00
parent 6983ec9537
commit 73c3903214
8 changed files with 273 additions and 128 deletions

View File

@ -224,11 +224,18 @@ message FilterBatchRequest {
repeated uint64 terminated_request_ids = 3;
}
message TerminatedGeneration {
// Request ID
uint64 id = 1;
// Generated text
GeneratedText generated_text = 2;
}
message FilterBatchResponse {
/// Filtered Batch (cached)
CachedBatch batch = 1;
/// Terminated generations
repeated GeneratedText terminated_generations = 2;
repeated TerminatedGeneration terminated_generations = 2;
}

View File

@ -92,7 +92,7 @@ impl Client {
batch_id: u64,
kept_requests: Vec<KeptRequest>,
terminated_request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
kept_requests,
@ -100,7 +100,7 @@ impl Client {
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
Ok((filtered_batch.batch, filtered_batch.terminated_generations))
}
/// Warmup on a max size batch

View File

@ -8,6 +8,6 @@ pub use client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, TerminatedGeneration, Tokens,
};
pub use sharded_client::ShardedClient;

View File

@ -2,7 +2,7 @@
use crate::{v3, Health, ShardInfo};
use crate::{ClientError, Result};
use crate::v3::{Chunk, InfoResponse, Input};
use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
@ -86,7 +86,7 @@ impl ShardedClient {
batch_id: u64,
kept_requests: Vec<KeptRequest>,
terminated_request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
let futures: Vec<_> = self
.clients
.iter_mut()

View File

@ -5,12 +5,14 @@ use crate::infer::{
};
use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token};
use nohash_hasher::IntMap;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use text_generation_client::v3::{Batch, CachedBatch, Generation, KeptRequest, ShardedClient};
use text_generation_client::v3::{
Batch, CachedBatch, Generation, KeptRequest, ShardedClient, TerminatedGeneration,
};
use text_generation_client::ClientError;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
@ -243,11 +245,38 @@ async fn prefill(
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter and send finished generations
let filtered_stream_responses = filter_send_ended_generations(generations, entries);
// Iterate on intermediate generations
for (id, stream_responses) in filtered_stream_responses {
// Get entry
let entry = entries
.get_mut(&id)
.expect("ID not found in entries. This is a bug.");
// Send intermediate responses
if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}) {
// Sending failed, remove entry
entries
.remove(&id)
.expect("ID not found in entries. This is a bug.");
}
}
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries, false).await;
let next_batch = match next_batch {
Some(batch) if batch.size as usize != entries.len() => {
let (filtered_batch, _) =
filter_batch(client, batch, entries, &IntMap::default()).await;
filtered_batch
}
batch => batch,
};
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
@ -285,13 +314,32 @@ async fn decode(
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
let updated = filter_update_allocations(entries).await;
// Filter and send finished generations
let mut filtered_stream_responses = filter_send_ended_generations(generations, entries);
// Send `StreamResponseInfer::Intermediate` messages for entries that don't need to be
// re-allocated,
// Allocated new blocks for entries that go over their allocation
// Filter entries that couldn't be re-allocated and add them to `terminated_entries`
let (force_update, terminated_entries) =
filter_send_update_allocations(entries, &mut filtered_stream_responses);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries, updated).await;
let next_batch = match next_batch {
// Run Only on re-allocation or if entries were filtered
Some(batch) if batch.size as usize != entries.len() || force_update => {
// Filter next batch: remove requests that were stopped and update blocks/slots
let (filtered_batch, terminated_generations) =
filter_batch(client, batch, entries, &terminated_entries).await;
send_terminated_generations(
terminated_generations,
terminated_entries,
filtered_stream_responses,
);
filtered_batch
}
batch => batch,
};
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
@ -320,27 +368,20 @@ async fn decode(
#[instrument(skip_all)]
async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<CachedBatch>,
batch: CachedBatch,
entries: &IntMap<u64, Entry>,
force_update: bool,
) -> Option<CachedBatch> {
let batch = next_batch?;
// No need to filter
if batch.size as usize == entries.len() && !force_update {
return Some(batch);
}
terminated_entries: &IntMap<u64, Entry>,
) -> (Option<CachedBatch>, Vec<TerminatedGeneration>) {
let id = batch.id;
if entries.is_empty() {
if entries.is_empty() && terminated_entries.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.clear_cache(Some(id)).await.unwrap();
None
Default::default()
} else {
// Filter Python shard cache
// Collect new blocks/slots
let updated_requests = entries
.iter()
.map(|(request_id, entry)| {
@ -348,7 +389,7 @@ async fn filter_batch(
.block_allocation
.as_ref()
.map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
.unwrap_or((Vec::new(), Vec::new()));
.unwrap_or_default();
KeptRequest {
id: *request_id,
@ -358,111 +399,207 @@ async fn filter_batch(
})
.collect();
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client
.filter_batch(id, updated_requests, Vec::new())
.filter_batch(
id,
updated_requests,
terminated_entries.keys().map(|v| *v).collect(),
)
.await
.unwrap()
}
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
///
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
fn send_terminated_generations(
terminated_generations: Vec<TerminatedGeneration>,
terminated_entries: IntMap<u64, Entry>,
mut stream_responses: IntMap<u64, Vec<InferStreamResponse>>,
) {
// Receive final message for terminated generations
'terminated_generations: for terminated_generation in terminated_generations {
let id = terminated_generation.id;
// Get entry for this generation
let entry = terminated_entries
.get(&id)
.expect("ID not found in entries. This is a bug.");
// Get previous `InferStreamResponse` for this generation
let stream_responses = stream_responses
.remove(&id)
.expect("ID not found in stream_responses. This is a bug.");
// Peekable iterator to know when we are at the last `InferStreamResponse`
let mut iterator = stream_responses.into_iter().peekable();
while let Some(stream_response) = iterator.next() {
let response = if iterator.peek().is_none() {
// Last `InferStreamResponse::Intermediate`
let (token, top_tokens) = match stream_response {
InferStreamResponse::Intermediate { token, top_tokens } => (token, top_tokens),
_ => unreachable!(),
};
// Modify it to be a `InferStreamResponse::End` with the new OutOfResources finish
// reason
InferStreamResponse::End {
token,
top_tokens,
generated_text: GeneratedText::from(
terminated_generation
.generated_text
.clone()
.expect("Generated Text is None. This is a bug."),
),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}
} else {
stream_response
};
// Send responses
if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}) {
continue 'terminated_generations;
}
}
}
}
/// Send `InferStreamResponse::End` to `Infer` for finished entries and remove them from `entries`
/// Returns filtered `InferStreamResponse::Intermediate` generations
#[instrument(skip_all)]
fn filter_send_ended_generations(
generations: Vec<Generation>,
entries: &mut IntMap<u64, Entry>,
) -> IntMap<u64, Vec<InferStreamResponse>> {
generations.into_iter().filter_map(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get_mut(&id)
.expect("ID not found in entries. This is a bug.");
entry.cache_length = generation.cache_length;
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation responses back to the infer task
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}).unwrap_or(true);
if stopped {
// Remove from entries and filter
entries.remove(&id).expect("ID not found in entries. This is a bug.");
return None;
}
});
// Update cache length
entry.cache_length = generation.cache_length;
let (finished, stream_responses) = map_generation(generation, entry);
// If the generation has ended for this request, we send the responses to the channel and
// remove the entry to drop it and free its blocks
if finished {
let _ = send_stream_responses(stream_responses, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
});
// Remove from entries and filter
entries.remove(&id).expect("ID not found in entries. This is a bug.");
return None;
}
Some((id, stream_responses))
}).collect()
}
/// Send `InferStreamResponse` to `Infer` through an `Entry` response channel
#[instrument(skip_all)]
fn send_stream_responses(
stream_responses: Vec<InferStreamResponse>,
entry: &Entry,
) -> Result<(), Box<SendError<Result<InferStreamResponse, InferError>>>> {
for response in stream_responses {
entry.response_tx.send(Ok(response))?;
}
Ok(())
}
/// Check if block allocations need to be extended
/// If we don't have enough blocks, request will be filtered with an OutOfPages error
/// If we don't have enough blocks, request will be filtered with be added to an IntMap of
/// terminated entries.
/// If at least one entry allocation was extended, we return true to force an update
#[instrument(skip_all)]
async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
let ids: Vec<u64> = entries
.iter()
.filter_map(|(id, entry)| {
entry
.block_allocation
.as_ref()
.map(|block_allocation| {
if entry.cache_length > block_allocation.len() as u32 {
// We need to re-allocate
Some(*id)
} else {
None
}
})
.unwrap_or(None)
})
.collect();
fn filter_send_update_allocations(
entries: &mut IntMap<u64, Entry>,
stream_responses: &mut IntMap<u64, Vec<InferStreamResponse>>,
) -> (bool, IntMap<u64, Entry>) {
let mut updated = false;
for id in ids.iter() {
// Get entry
// We can `expect` here as the request id should always be in the entries
let extension = {
let entry = entries
.get_mut(id)
.expect("ID not found in entries. This is a bug.");
entry
.block_allocation
.as_mut()
.expect("We checked that the block allocation exists above")
.extend()
};
let ids: Vec<u64> = entries.keys().map(|v| *v).collect();
let mut terminated_entries =
IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default());
if extension.is_err() {
let entry = entries
for id in &ids {
let entry = entries
.get_mut(id)
.expect("ID not found in entries. This is a bug.");
if let Some(block_allocation) = entry.block_allocation.as_mut() {
// Check if allocation can handle the current cache_length
if entry.cache_length > block_allocation.len() as u32 {
updated = true;
// Extend allocation by asking for a new block
if let Err(err) = block_allocation.extend() {
// Failed to extend allocation
tracing::error!("Failed to extend allocation: {err}");
metrics::increment_counter!("tgi_request_failure", "err" => "out_of_resources");
// Remove entry
let mut entry = entries
.remove(id)
.expect("ID not found in entries. This is a bug.");
// Clear block allocation
entry.block_allocation = None;
// Add it to terminated entries
terminated_entries.insert(*id, entry);
// Skip the rest of the logic to not send the intermediate messages
// This entry will be terminated and we will need to edit the last intermediate
// response to add the complete generated text
continue;
}
}
}
let stream_response = stream_responses
.remove(id)
.expect("ID not found in stream_responses. This is a bug.");
// Send intermediate responses
if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}) {
// Sending failed, remove entry
entries
.remove(id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::OutOfPages;
metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages");
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Err(err)).unwrap_or(());
}
}
// If ids is not empty, we need to update
!ids.is_empty()
(updated, terminated_entries)
}
/// Send responses through the `entry` response channel
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
return Ok(true);
}
let mut stopped = false;
/// Map `Generation` to `<(bool, Vec<(u64, InferStreamResponse)>)>`
fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec<InferStreamResponse>) {
let mut finished = false;
let mut stream_responses = Vec::with_capacity(16);
if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
@ -475,10 +612,8 @@ fn send_responses(
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
// Push to stream_responses
stream_responses.push(InferStreamResponse::Prefill(prefill_tokens));
}
// Create last Token
@ -520,26 +655,24 @@ fn send_responses(
match (&generation.generated_text, iterator.peek()) {
(Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
finished = true;
// Push to stream_responses
stream_responses.push(InferStreamResponse::End {
token,
top_tokens,
generated_text: GeneratedText::from(generated_text.clone()),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
});
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
// Push to stream_responses
stream_responses.push(InferStreamResponse::Intermediate { token, top_tokens });
}
}
}
Ok(stopped)
(finished, stream_responses)
}
/// Send errors to Infer for all `entries`

View File

@ -402,10 +402,7 @@ class FlashCausalLMBatch(Batch):
model: "FlashCausalLM",
kept_requests: List[generate_pb2.KeptRequest],
terminated_request_ids: List[int],
) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.GeneratedText]]:
if len(kept_requests) == 0:
raise ValueError("Batch must have at least one request")
) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
terminated_generations = []
for request_id in terminated_request_ids:
idx = self.requests_idx_mapping[request_id]
@ -421,13 +418,19 @@ class FlashCausalLMBatch(Batch):
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
skip_special_tokens=True,
)
generated_text = GeneratedText(
output_text,
stopping_criteria.current_tokens,
generate_pb2.FINISH_REASON_TERMINATED,
seed if do_sample else None,
terminated_generations.append(
generate_pb2.TerminatedGeneration(
id=request_id,
generated_text=generate_pb2.GeneratedText(
text=output_text,
generated_tokens=stopping_criteria.current_tokens,
finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
seed=seed if do_sample else None,
),
)
)
terminated_generations.append(generated_text)
if not kept_requests:
return None, terminated_generations
device = self.input_ids.device

View File

@ -33,7 +33,7 @@ class Batch(ABC):
model,
kept_requests: List[generate_pb2.KeptRequest],
terminated_request_ids: List[int],
) -> Tuple["Batch", List[generate_pb2.GeneratedText]]:
) -> Tuple[Optional["Batch"], List[generate_pb2.TerminatedGeneration]]:
raise NotImplementedError
@classmethod

View File

@ -86,10 +86,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
filtered_batch, terminated_generations = batch.filter(
self.model, request.kept_requests, request.terminated_request_ids
)
self.cache.set(filtered_batch)
if filtered_batch is not None:
self.cache.set(filtered_batch)
return generate_pb2.FilterBatchResponse(
batch=filtered_batch.to_pb(), terminated_generations=terminated_generations
batch=filtered_batch.to_pb() if filtered_batch is not None else None,
terminated_generations=terminated_generations,
)
async def Warmup(self, request, context):