add terminated_generations

OlivierDehaene 2024-06-07 11:26:17 +02:00
parent 3c596983ba
commit 298bf31e69
16 changed files with 107 additions and 60 deletions

View File

@ -164,6 +164,7 @@ enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2;
FINISH_REASON_TERMINATED = 3;
}
message GeneratedText {
@ -198,11 +199,11 @@ message Generation {
optional GeneratedText generated_text = 4;
/// Top tokens
repeated Tokens top_tokens = 5;
/// Current length of the request: prompt tokens + number of generated tokens until this point
uint32 current_length = 6;
/// Current length of the cache: prompt tokens + number of generated tokens until this point
uint32 cache_length = 6;
}
message UpdatedRequest {
message KeptRequest {
/// Request ID
uint64 id = 1;
/// Paged attention blocks
@ -211,16 +212,23 @@ message UpdatedRequest {
repeated uint32 slots = 3;
}
/// kept_requests + terminated_request_ids might not cover all requests from the
/// cached batch, as some requests can be filtered out without any text generation
/// being required, for example when the client dropped its connection to the router
message FilterBatchRequest {
/// Batch ID
uint64 batch_id = 1;
/// Requests to keep
repeated UpdatedRequest updated_requests = 2;
repeated KeptRequest kept_requests = 2;
/// Requests to terminate and generate text for
repeated uint64 terminated_request_ids = 3;
}
message FilterBatchResponse {
/// Filtered Batch (cached)
CachedBatch batch = 1;
/// Terminated generations
repeated GeneratedText terminated_generations = 2;
}
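
As an illustration of the comment above, here is a minimal sketch (using the tonic-generated Rust types that appear in the client changes below; the ids, blocks and slots are hypothetical) of a filter call in which request 0 is kept, request 1 is terminated but still gets its text generated, and a request whose client disconnected appears in neither list:

let request = FilterBatchRequest {
    batch_id: 0,
    // Request 0 stays in the batch and keeps its paged attention blocks/slots.
    kept_requests: vec![KeptRequest {
        id: 0,
        blocks: vec![0, 1],
        slots: vec![0, 1, 2, 3],
    }],
    // Request 1 is evicted, but the server still decodes and returns its text
    // in FilterBatchResponse.terminated_generations.
    terminated_request_ids: vec![1],
    // A request whose client dropped the connection is listed in neither field.
};

The v3 client then wraps this in tonic::Request::new(...) and injects the tracing context, exactly as shown in the client diff below.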

View File

@ -90,11 +90,13 @@ impl Client {
pub async fn filter_batch(
&mut self,
batch_id: u64,
updated_requests: Vec<UpdatedRequest>,
kept_requests: Vec<KeptRequest>,
terminated_request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
updated_requests,
kept_requests,
terminated_request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();

View File

@ -7,7 +7,7 @@ mod sharded_client;
pub use client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters, Tokens, UpdatedRequest,
HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;

View File

@ -9,8 +9,8 @@ use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
use v3::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest,
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, KeptRequest,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
#[derive(Debug, Clone)]
@ -84,12 +84,19 @@ impl ShardedClient {
pub async fn filter_batch(
&mut self,
batch_id: u64,
updated_requests: Vec<UpdatedRequest>,
kept_requests: Vec<KeptRequest>,
terminated_request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone())))
.map(|client| {
Box::pin(client.filter_batch(
batch_id,
kept_requests.clone(),
terminated_request_ids.clone(),
))
})
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
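
For reference, a hedged sketch of a call site for the updated method (the batch id, request ids, blocks and slots are placeholders); the kept and terminated lists are cloned for every shard, and all shards return the same message:

// Hypothetical call site inside the router.
let kept = vec![KeptRequest { id: 0, blocks: vec![0], slots: vec![0, 1] }];
let filtered_batch = sharded_client.filter_batch(batch_id, kept, vec![1]).await?;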

View File

@ -1,4 +1,4 @@
use std::cmp::{max, min};
use std::cmp::min;
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
@ -16,8 +16,9 @@ impl BlockAllocation {
self.slots.len()
}
pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> {
let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1);
pub(crate) async fn extend(&mut self, cache_length: u32) -> Result<(), AllocationError> {
let remaining_tokens =
(self.prompt_tokens + self.decode_tokens).saturating_sub(cache_length);
self.block_allocator
.clone()
.extend(self, remaining_tokens)
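
The change above replaces max(prompt_tokens + decode_tokens - current_length, 1) with a saturating subtraction. A self-contained sketch (with hypothetical token counts) of the underflow the new code guards against, assuming the cache length can grow past the reserved prompt + decode budget:

fn remaining_tokens(prompt_tokens: u32, decode_tokens: u32, cache_length: u32) -> u32 {
    // A plain `prompt_tokens + decode_tokens - cache_length` would panic in debug
    // builds (and wrap around in release) whenever cache_length is the larger value.
    (prompt_tokens + decode_tokens).saturating_sub(cache_length)
}

fn main() {
    assert_eq!(remaining_tokens(8, 4, 10), 2);
    // The old formula would have underflowed here; the new one requests 0 extra tokens.
    assert_eq!(remaining_tokens(8, 4, 13), 0);
}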
@ -131,6 +132,7 @@ async fn block_allocator_task(
let decode_tokens = min(decode_tokens, block_size);
let tokens = prompt_tokens + decode_tokens;
// FIXME: window size is not working
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match window_size {

View File

@ -5,7 +5,7 @@ use crate::validation::{
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::cmp::max;
use std::collections::VecDeque;
use text_generation_client::v3::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
@ -33,8 +33,8 @@ pub(crate) struct Entry {
pub batch_time: Option<Instant>,
/// Block Allocation
pub block_allocation: Option<BlockAllocation>,
/// Current length (in tokens) of the request (prompt tokens + generated_tokens)
pub current_length: u32,
/// Cache length (in tokens) of the request (prompt tokens + generated_tokens)
pub cache_length: u32,
}
/// Request Queue
@ -164,9 +164,6 @@ struct State {
/// Paged Attention block size
block_size: u32,
/// Sliding window
window_size: Option<u32>,
/// Speculation amount
speculate: u32,
@ -190,7 +187,6 @@ impl State {
next_id: 0,
next_batch_id: 0,
block_size,
window_size,
speculate,
block_allocator,
}
@ -276,18 +272,7 @@ impl State {
}
Some(block_allocator) => {
prefill_tokens += entry.request.input_length;
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
Some(window_size) => min(
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
};
decode_tokens += max_new_tokens;
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
if prefill_tokens > prefill_token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@ -296,7 +281,7 @@ impl State {
}
let decode_tokens =
entry.request.stopping_parameters.max_new_tokens + self.speculate - 1;
entry.request.stopping_parameters.max_new_tokens + self.speculate;
match block_allocator
.allocate(entry.request.input_length, decode_tokens)
.await
@ -500,7 +485,7 @@ mod tests {
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
current_length: 0,
cache_length: 0,
};
(entry, receiver_tx)
}

View File

@ -10,7 +10,7 @@ use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest};
use text_generation_client::v3::{Batch, CachedBatch, Generation, KeptRequest, ShardedClient};
use text_generation_client::ClientError;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
@ -88,7 +88,7 @@ impl Scheduler for SchedulerV3 {
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
current_length: input_length,
cache_length: 0,
});
// Notify the background task that we have a new entry in the queue that needs
@ -350,7 +350,7 @@ async fn filter_batch(
.map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
.unwrap_or((Vec::new(), Vec::new()));
UpdatedRequest {
KeptRequest {
id: *request_id,
blocks,
slots,
@ -359,7 +359,10 @@ async fn filter_batch(
.collect();
// We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, updated_requests).await.unwrap()
client
.filter_batch(id, updated_requests, Vec::new())
.await
.unwrap()
}
}
@ -374,7 +377,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
let entry = entries
.get_mut(&id)
.expect("ID not found in entries. This is a bug.");
entry.current_length = generation.current_length;
entry.cache_length = generation.cache_length;
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
@ -403,7 +406,7 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
.block_allocation
.as_ref()
.map(|block_allocation| {
if entry.current_length > block_allocation.len() as u32 {
if entry.cache_length > block_allocation.len() as u32 {
// We need to re-allocate
Some(*id)
} else {
@ -424,8 +427,8 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
entry
.block_allocation
.as_mut()
.unwrap()
.extend(entry.current_length)
.expect("We checked that the block allocation exists above")
.extend(entry.cache_length)
.await
};
@ -563,6 +566,7 @@ impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
let v3_finish_reason =
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v3_finish_reason {
text_generation_client::v3::FinishReason::Terminated => FinishReason::OutOfResources,
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,

View File

@ -1085,6 +1085,8 @@ pub(crate) enum FinishReason {
EndOfSequenceToken,
#[schema(rename = "stop_sequence")]
StopSequence,
#[schema(rename = "out_of_resources")]
OutOfResources,
}
impl std::fmt::Display for FinishReason {
@ -1093,6 +1095,7 @@ impl std::fmt::Display for FinishReason {
FinishReason::Length => write!(f, "length"),
FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
FinishReason::StopSequence => write!(f, "stop_sequence"),
FinishReason::OutOfResources => write!(f, "out_of_resources"),
}
}
}
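
Clients that branch on finish_reason now have one more value to handle. A self-contained sketch mirroring the Display impl above, only to show the string the HTTP layer emits:

use std::fmt;

// Minimal mirror of the enum above, for illustration only.
enum FinishReason {
    Length,
    EndOfSequenceToken,
    StopSequence,
    OutOfResources,
}

impl fmt::Display for FinishReason {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            FinishReason::Length => write!(f, "length"),
            FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
            FinishReason::StopSequence => write!(f, "stop_sequence"),
            FinishReason::OutOfResources => write!(f, "out_of_resources"),
        }
    }
}

fn main() {
    // A generation that ended with FINISH_REASON_TERMINATED surfaces as this value.
    assert_eq!(FinishReason::OutOfResources.to_string(), "out_of_resources");
}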

View File

@ -159,7 +159,7 @@ class CausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(
self, updated_requests: List[generate_pb2.UpdatedRequest]
self, updated_requests: List[generate_pb2.KeptRequest]
) -> Optional["CausalLMBatch"]:
request_ids = [r.id for r in updated_requests]

View File

@ -398,11 +398,37 @@ class FlashCausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(
self, updated_requests: List[generate_pb2.UpdatedRequest]
) -> Optional["FlashCausalLMBatch"]:
if len(updated_requests) == 0:
self,
model: "FlashCausalLM",
kept_requests: List[generate_pb2.KeptRequest],
terminated_request_ids: List[int],
) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.GeneratedText]]:
if len(kept_requests) == 0:
raise ValueError("Batch must have at least one request")
terminated_generations = []
for request_id in terminated_request_ids:
idx = self.requests_idx_mapping[request_id]
all_input_ids = self.all_input_ids[idx]
stopping_criteria = self.stopping_criterias[idx]
do_sample = self.next_token_chooser.do_sample[idx]
seed = self.next_token_chooser.seeds[idx]
# Decode generated tokens
output_text, _, _ = model.decode_token(
all_input_ids,
prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
skip_special_tokens=True,
)
generated_text = GeneratedText(
output_text,
stopping_criteria.current_tokens,
generate_pb2.FINISH_REASON_TERMINATED,
seed if do_sample else None,
)
terminated_generations.append(generated_text)
device = self.input_ids.device
# New values after filtering
@ -429,7 +455,7 @@ class FlashCausalLMBatch(Batch):
num_blocks = 0
max_blocks = 0
for i, request in enumerate(updated_requests):
for i, request in enumerate(kept_requests):
request_id = request.id
idx = self.requests_idx_mapping[request_id]
@ -491,7 +517,7 @@ class FlashCausalLMBatch(Batch):
# Move to GPU
block_tables_tensor = block_tables_tensor.to(device)
return type(self)(
filtered_batch = type(self)(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
@ -520,6 +546,7 @@ class FlashCausalLMBatch(Batch):
max_blocks=max_blocks,
speculative_ids=speculative_ids,
)
return filtered_batch, terminated_generations
@classmethod
@tracer.start_as_current_span("concatenate")

View File

@ -215,7 +215,7 @@ class IdeficsCausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(
self, updated_requests: List[generate_pb2.UpdatedRequest]
self, updated_requests: List[generate_pb2.KeptRequest]
) -> Optional["IdeficsCausalLMBatch"]:
request_ids = [r.id for r in updated_requests]

View File

@ -196,7 +196,7 @@ class MambaBatch(Batch):
)
def filter(
self, updated_requests: List[generate_pb2.UpdatedRequest]
self, updated_requests: List[generate_pb2.KeptRequest]
) -> Optional["MambaBatch"]:
request_ids = [r.id for r in updated_requests]

View File

@ -167,7 +167,7 @@ class Seq2SeqLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(
self, updated_requests: List[generate_pb2.UpdatedRequest]
self, updated_requests: List[generate_pb2.KeptRequest]
) -> Optional["Seq2SeqLMBatch"]:
request_ids = [r.id for r in updated_requests]

View File

@ -3,7 +3,7 @@ import torch
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
@ -28,7 +28,12 @@ class Batch(ABC):
raise NotImplementedError
@abstractmethod
def filter(self, updated_requests: List[generate_pb2.UpdatedRequest]) -> "Batch":
def filter(
self,
model,
kept_requests: List[generate_pb2.KeptRequest],
terminated_request_ids: List[int],
) -> Tuple["Batch", List[generate_pb2.GeneratedText]]:
raise NotImplementedError
@classmethod
@ -84,7 +89,7 @@ class Generation:
generated_text: Optional[GeneratedText]
# Optional for now, since it's not yet supported for every model.
top_tokens: Optional[List[Tokens]]
current_length: int
cache_length: int
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
@ -101,5 +106,5 @@ class Generation:
if self.top_tokens is not None
else None
),
current_length=self.current_length,
cache_length=self.cache_length,
)

View File

@ -123,7 +123,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
@tracer.start_as_current_span("filter")
def filter(
self, updated_requests: List[generate_pb2.UpdatedRequest]
self, updated_requests: List[generate_pb2.KeptRequest]
) -> Optional["VlmCausalLMBatch"]:
batch = super().filter(updated_requests)
batch.pixel_values = None

View File

@ -83,10 +83,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch = self.cache.pop(request.batch_id)
if batch is None:
raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
filtered_batch = batch.filter(request.updated_requests)
filtered_batch, terminated_generations = batch.filter(
self.model, request.kept_requests, request.terminated_request_ids
)
self.cache.set(filtered_batch)
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
return generate_pb2.FilterBatchResponse(
batch=filtered_batch.to_pb(), terminated_generations=terminated_generations
)
async def Warmup(self, request, context):
if self.quantize in {"exl2", "gptq"}: