add terminated_generations

OlivierDehaene 2024-06-07 11:26:17 +02:00
parent 3c596983ba
commit 298bf31e69
16 changed files with 107 additions and 60 deletions

View File

@@ -164,6 +164,7 @@ enum FinishReason {
     FINISH_REASON_LENGTH = 0;
     FINISH_REASON_EOS_TOKEN = 1;
     FINISH_REASON_STOP_SEQUENCE = 2;
+    FINISH_REASON_TERMINATED = 3;
 }

 message GeneratedText {
@@ -198,11 +199,11 @@ message Generation {
     optional GeneratedText generated_text = 4;
     /// Top tokens
     repeated Tokens top_tokens = 5;
-    /// Current length of the request: prompt tokens + number of generated tokens until this point
-    uint32 current_length = 6;
+    /// Current length of the cache: prompt tokens + number of generated tokens until this point
+    uint32 cache_length = 6;
 }

-message UpdatedRequest {
+message KeptRequest {
     /// Request ID
     uint64 id = 1;
     /// Paged attention blocks
@@ -211,16 +212,23 @@ message UpdatedRequest {
     repeated uint32 slots = 3;
 }

+/// kept_requests + terminated_request_ids might not cover all requests from the
+/// cached batch as some requests can be filtered out without requiring to generate text
+/// for example if the client dropped its connection to the router
 message FilterBatchRequest {
     /// Batch ID
     uint64 batch_id = 1;
     /// Requests to keep
-    repeated UpdatedRequest updated_requests = 2;
+    repeated KeptRequest kept_requests = 2;
+    /// Requests to terminate and generate text for
+    repeated uint64 terminated_request_ids = 3;
 }

 message FilterBatchResponse {
     /// Filtered Batch (cached)
     CachedBatch batch = 1;
+    /// Terminated generations
+    repeated GeneratedText terminated_generations = 2;
 }
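
Note (illustration, not part of the commit): a minimal Rust sketch of how a caller could assemble the new FilterBatchRequest, assuming the prost-generated KeptRequest and FilterBatchRequest types from the proto above are in scope (their module path depends on the generated pb crate). Requests omitted from both lists are dropped without generating text, matching the comment above; ids in terminated_request_ids are stopped server-side and their final text comes back in FilterBatchResponse.terminated_generations.

// Sketch only: field names follow the proto definitions above;
// the containing module path of the generated types is an assumption.
fn build_filter_request(
    batch_id: u64,
    kept: Vec<(u64, Vec<u32>, Vec<u32>)>, // (request id, blocks, slots)
    terminated_request_ids: Vec<u64>,
) -> FilterBatchRequest {
    FilterBatchRequest {
        batch_id,
        kept_requests: kept
            .into_iter()
            .map(|(id, blocks, slots)| KeptRequest { id, blocks, slots })
            .collect(),
        // Requests listed here are terminated server-side and their final
        // text is returned in FilterBatchResponse.terminated_generations.
        terminated_request_ids,
    }
}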

View File

@@ -90,11 +90,13 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        updated_requests: Vec<UpdatedRequest>,
+        kept_requests: Vec<KeptRequest>,
+        terminated_request_ids: Vec<u64>,
     ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            updated_requests,
+            kept_requests,
+            terminated_request_ids,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();

View File

@@ -7,7 +7,7 @@ mod sharded_client;
 pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens, UpdatedRequest,
+    HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;

View File

@@ -9,8 +9,8 @@ use tonic::transport::Uri;
 use tracing::instrument;
 use v3::client::{DecodeTimings, PrefillTimings};
 use v3::{
-    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest,
+    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, KeptRequest,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };

 #[derive(Debug, Clone)]
@@ -84,12 +84,19 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        updated_requests: Vec<UpdatedRequest>,
+        kept_requests: Vec<KeptRequest>,
+        terminated_request_ids: Vec<u64>,
     ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone())))
+            .map(|client| {
+                Box::pin(client.filter_batch(
+                    batch_id,
+                    kept_requests.clone(),
+                    terminated_request_ids.clone(),
+                ))
+            })
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()

View File

@@ -1,4 +1,4 @@
-use std::cmp::{max, min};
+use std::cmp::min;
 use thiserror::Error;
 use tokio::sync::{mpsc, oneshot};
@@ -16,8 +16,9 @@ impl BlockAllocation {
         self.slots.len()
     }

-    pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> {
-        let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1);
+    pub(crate) async fn extend(&mut self, cache_length: u32) -> Result<(), AllocationError> {
+        let remaining_tokens =
+            (self.prompt_tokens + self.decode_tokens).saturating_sub(cache_length);
         self.block_allocator
             .clone()
             .extend(self, remaining_tokens)
@@ -131,6 +132,7 @@ async fn block_allocator_task(
         let decode_tokens = min(decode_tokens, block_size);
         let tokens = prompt_tokens + decode_tokens;

+        // FIXME: window size is not working
         // Apply window size
         let (required_blocks, repeats) = {
             let (tokens, repeats) = match window_size {
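
Note (illustration, not part of the commit): the extend change above swaps max(prompt_tokens + decode_tokens - cache_length, 1) for saturating_sub, so a cache length larger than the allocation budget clamps to zero instead of underflowing the u32. A standalone sketch with hypothetical names:

// Standalone illustration of the new arithmetic; names are hypothetical.
fn remaining_tokens(prompt_tokens: u32, decode_tokens: u32, cache_length: u32) -> u32 {
    // Plain subtraction would underflow (and panic in debug builds) when
    // cache_length exceeds the budget; saturating_sub clamps the result to 0.
    (prompt_tokens + decode_tokens).saturating_sub(cache_length)
}

fn main() {
    assert_eq!(remaining_tokens(10, 20, 25), 5);
    assert_eq!(remaining_tokens(10, 20, 40), 0); // no longer forced to at least 1
}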

View File

@@ -5,7 +5,7 @@ use crate::validation::{
     ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::{max, min};
+use std::cmp::max;
 use std::collections::VecDeque;
 use text_generation_client::v3::{
     Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
@@ -33,8 +33,8 @@ pub(crate) struct Entry {
     pub batch_time: Option<Instant>,
     /// Block Allocation
     pub block_allocation: Option<BlockAllocation>,
-    /// Current length (in tokens) of the request (prompt tokens + generated_tokens)
-    pub current_length: u32,
+    /// Cache length (in tokens) of the request (prompt tokens + generated_tokens)
+    pub cache_length: u32,
 }

 /// Request Queue
@@ -164,9 +164,6 @@ struct State {
     /// Paged Attention block size
     block_size: u32,

-    /// Sliding window
-    window_size: Option<u32>,
-
     /// Speculation amount
     speculate: u32,
@@ -190,7 +187,6 @@ impl State {
             next_id: 0,
             next_batch_id: 0,
             block_size,
-            window_size,
             speculate,
             block_allocator,
         }
@@ -276,18 +272,7 @@ impl State {
                 }
                 Some(block_allocator) => {
                     prefill_tokens += entry.request.input_length;
-                    let max_new_tokens = match self.window_size {
-                        None => entry.request.stopping_parameters.max_new_tokens,
-                        Some(window_size) => min(
-                            window_size.saturating_sub(entry.request.input_length),
-                            entry.request.stopping_parameters.max_new_tokens,
-                        ),
-                    };
-                    decode_tokens += max_new_tokens;
-                    if prefill_tokens > prefill_token_budget
-                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-                    {
+                    if prefill_tokens > prefill_token_budget {
                         // Entry is over budget
                         // Add it back to the front
                         tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@@ -296,7 +281,7 @@ impl State {
                 }

                 let decode_tokens =
-                    entry.request.stopping_parameters.max_new_tokens + self.speculate - 1;
+                    entry.request.stopping_parameters.max_new_tokens + self.speculate;
                 match block_allocator
                     .allocate(entry.request.input_length, decode_tokens)
                     .await
@@ -500,7 +485,7 @@ mod tests {
            queue_time: Instant::now(),
            batch_time: None,
            block_allocation: None,
-            current_length: 0,
+            cache_length: 0,
        };
        (entry, receiver_tx)
    }

View File

@@ -10,7 +10,7 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest};
+use text_generation_client::v3::{Batch, CachedBatch, Generation, KeptRequest, ShardedClient};
 use text_generation_client::ClientError;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
@@ -88,7 +88,7 @@ impl Scheduler for SchedulerV3 {
             queue_time: Instant::now(),
             batch_time: None,
             block_allocation: None,
-            current_length: input_length,
+            cache_length: 0,
         });

         // Notify the background task that we have a new entry in the queue that needs
@@ -350,7 +350,7 @@ async fn filter_batch(
                     .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
                     .unwrap_or((Vec::new(), Vec::new()));

-                UpdatedRequest {
+                KeptRequest {
                     id: *request_id,
                     blocks,
                     slots,
@@ -359,7 +359,10 @@ async fn filter_batch(
            .collect();

        // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, updated_requests).await.unwrap()
+        client
+            .filter_batch(id, updated_requests, Vec::new())
+            .await
+            .unwrap()
    }
 }
@@ -374,7 +377,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
        let entry = entries
            .get_mut(&id)
            .expect("ID not found in entries. This is a bug.");
-        entry.current_length = generation.current_length;
+        entry.cache_length = generation.cache_length;

        // Create and enter a span to link this function back to the entry
        let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
@@ -403,7 +406,7 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
            .block_allocation
            .as_ref()
            .map(|block_allocation| {
-                if entry.current_length > block_allocation.len() as u32 {
+                if entry.cache_length > block_allocation.len() as u32 {
                    // We need to re-allocate
                    Some(*id)
                } else {
@@ -424,8 +427,8 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
            entry
                .block_allocation
                .as_mut()
-                .unwrap()
-                .extend(entry.current_length)
+                .expect("We checked that the block allocation exists above")
+                .extend(entry.cache_length)
                .await
        };
@@ -563,6 +566,7 @@ impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
        let v3_finish_reason =
            text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
        let finish_reason = match v3_finish_reason {
+            text_generation_client::v3::FinishReason::Terminated => FinishReason::OutOfResources,
            text_generation_client::v3::FinishReason::Length => FinishReason::Length,
            text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,

View File

@@ -1085,6 +1085,8 @@ pub(crate) enum FinishReason {
     EndOfSequenceToken,
     #[schema(rename = "stop_sequence")]
     StopSequence,
+    #[schema(rename = "out_of_resources")]
+    OutOfResources,
 }

 impl std::fmt::Display for FinishReason {
@@ -1093,6 +1095,7 @@ impl std::fmt::Display for FinishReason {
             FinishReason::Length => write!(f, "length"),
             FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
             FinishReason::StopSequence => write!(f, "stop_sequence"),
+            FinishReason::OutOfResources => write!(f, "out_of_resources"),
         }
     }
 }
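
Note (illustration, not part of the commit): assuming the router's FinishReason enum above is in scope, the Display impl makes the new variant render as the same string used in the out_of_resources schema rename:

// Illustrative only; FinishReason here refers to the router enum above.
fn main() {
    assert_eq!(FinishReason::OutOfResources.to_string(), "out_of_resources");
    assert_eq!(FinishReason::StopSequence.to_string(), "stop_sequence");
}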

View File

@@ -159,7 +159,7 @@ class CausalLMBatch(Batch):
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["CausalLMBatch"]:
         request_ids = [r.id for r in updated_requests]

View File

@@ -398,11 +398,37 @@ class FlashCausalLMBatch(Batch):
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
-    ) -> Optional["FlashCausalLMBatch"]:
-        if len(updated_requests) == 0:
+        self,
+        model: "FlashCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.GeneratedText]]:
+        if len(kept_requests) == 0:
             raise ValueError("Batch must have at least one request")

+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            do_sample = self.next_token_chooser.do_sample[idx]
+            seed = self.next_token_chooser.seeds[idx]
+
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids,
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+
+            generated_text = GeneratedText(
+                output_text,
+                stopping_criteria.current_tokens,
+                generate_pb2.FINISH_REASON_TERMINATED,
+                seed if do_sample else None,
+            )
+            terminated_generations.append(generated_text)
+
         device = self.input_ids.device

         # New values after filtering
@@ -429,7 +455,7 @@ class FlashCausalLMBatch(Batch):
         num_blocks = 0
         max_blocks = 0

-        for i, request in enumerate(updated_requests):
+        for i, request in enumerate(kept_requests):
             request_id = request.id
             idx = self.requests_idx_mapping[request_id]
@@ -491,7 +517,7 @@ class FlashCausalLMBatch(Batch):
         # Move to GPU
         block_tables_tensor = block_tables_tensor.to(device)

-        return type(self)(
+        filtered_batch = type(self)(
             batch_id=self.batch_id,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
@@ -520,6 +546,7 @@ class FlashCausalLMBatch(Batch):
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
         )
+        return filtered_batch, terminated_generations

     @classmethod
     @tracer.start_as_current_span("concatenate")

View File

@@ -215,7 +215,7 @@ class IdeficsCausalLMBatch(Batch):
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["IdeficsCausalLMBatch"]:
         request_ids = [r.id for r in updated_requests]

View File

@@ -196,7 +196,7 @@ class MambaBatch(Batch):
     )
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["MambaBatch"]:
         request_ids = [r.id for r in updated_requests]

View File

@@ -167,7 +167,7 @@ class Seq2SeqLMBatch(Batch):
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["Seq2SeqLMBatch"]:
         request_ids = [r.id for r in updated_requests]

View File

@@ -3,7 +3,7 @@ import torch
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Tuple

 from transformers import PreTrainedTokenizerBase
@@ -28,7 +28,12 @@ class Batch(ABC):
         raise NotImplementedError

     @abstractmethod
-    def filter(self, updated_requests: List[generate_pb2.UpdatedRequest]) -> "Batch":
+    def filter(
+        self,
+        model,
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple["Batch", List[generate_pb2.GeneratedText]]:
         raise NotImplementedError

     @classmethod
@@ -84,7 +89,7 @@ class Generation:
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
     top_tokens: Optional[List[Tokens]]
-    current_length: int
+    cache_length: int

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -101,5 +106,5 @@ class Generation:
                 if self.top_tokens is not None
                 else None
             ),
-            current_length=self.current_length,
+            cache_length=self.cache_length,
         )

View File

@@ -123,7 +123,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["VlmCausalLMBatch"]:
         batch = super().filter(updated_requests)
         batch.pixel_values = None

View File

@@ -83,10 +83,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.cache.pop(request.batch_id)
         if batch is None:
             raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.updated_requests)
+        filtered_batch, terminated_generations = batch.filter(
+            self.model, request.kept_requests, request.terminated_request_ids
+        )

         self.cache.set(filtered_batch)
-        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
+        return generate_pb2.FilterBatchResponse(
+            batch=filtered_batch.to_pb(), terminated_generations=terminated_generations
+        )

     async def Warmup(self, request, context):
         if self.quantize in {"exl2", "gptq"}: