add terminated_generations
Parent: 3c596983ba
Commit: 298bf31e69
@@ -164,6 +164,7 @@ enum FinishReason {
     FINISH_REASON_LENGTH = 0;
     FINISH_REASON_EOS_TOKEN = 1;
     FINISH_REASON_STOP_SEQUENCE = 2;
+    FINISH_REASON_TERMINATED = 3;
 }

 message GeneratedText {
@@ -198,11 +199,11 @@ message Generation {
     optional GeneratedText generated_text = 4;
     /// Top tokens
     repeated Tokens top_tokens = 5;
-    /// Current length of the request: prompt tokens + number of generated tokens until this point
-    uint32 current_length = 6;
+    /// Current length of the cache: prompt tokens + number of generated tokens until this point
+    uint32 cache_length = 6;
 }

-message UpdatedRequest {
+message KeptRequest {
     /// Request ID
     uint64 id = 1;
     /// Paged attention blocks
@@ -211,16 +212,23 @@ message UpdatedRequest {
     repeated uint32 slots = 3;
 }

+/// kept_requests + terminated_request_ids might not cover all requests from the
+/// cached batch as some requests can be filtered out without requiring to generate text
+/// for example if the client dropped its connection to the router
 message FilterBatchRequest {
     /// Batch ID
     uint64 batch_id = 1;
     /// Requests to keep
-    repeated UpdatedRequest updated_requests = 2;
+    repeated KeptRequest kept_requests = 2;
+    /// Requests to terminate and generate text for
+    repeated uint64 terminated_request_ids = 3;
 }

 message FilterBatchResponse {
     /// Filtered Batch (cached)
     CachedBatch batch = 1;
+    /// Terminated generations
+    repeated GeneratedText terminated_generations = 2;
 }

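For orientation, here is a small Python sketch (not part of the commit) that builds the new messages with the generated generate_pb2 bindings and reads terminated_generations back out. The import path and the GeneratedText field names (text, generated_tokens, finish_reason) are assumptions based on the surrounding code; the IDs, blocks and slots are made-up values.

from text_generation_server.pb import generate_pb2

# Keep request 0, handing back its paged-attention blocks and slots,
# and ask the shard to terminate request 1 and decode its text one last time.
request = generate_pb2.FilterBatchRequest(
    batch_id=0,
    kept_requests=[generate_pb2.KeptRequest(id=0, blocks=[0, 1], slots=[0, 1, 2, 3])],
    terminated_request_ids=[1],
)

# What the shard sends back: the filtered cached batch plus one final
# GeneratedText per terminated request, flagged FINISH_REASON_TERMINATED.
response = generate_pb2.FilterBatchResponse(
    terminated_generations=[
        generate_pb2.GeneratedText(
            text="...",
            generated_tokens=3,
            finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
        )
    ],
)
print([g.finish_reason for g in response.terminated_generations])
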
@@ -90,11 +90,13 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        updated_requests: Vec<UpdatedRequest>,
+        kept_requests: Vec<KeptRequest>,
+        terminated_request_ids: Vec<u64>,
     ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            updated_requests,
+            kept_requests,
+            terminated_request_ids,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();

@@ -7,7 +7,7 @@ mod sharded_client;
 pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens, UpdatedRequest,
+    HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;

@@ -9,8 +9,8 @@ use tonic::transport::Uri;
 use tracing::instrument;
 use v3::client::{DecodeTimings, PrefillTimings};
 use v3::{
-    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest,
+    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, KeptRequest,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };

 #[derive(Debug, Clone)]

@@ -84,12 +84,19 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        updated_requests: Vec<UpdatedRequest>,
+        kept_requests: Vec<KeptRequest>,
+        terminated_request_ids: Vec<u64>,
     ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone())))
+            .map(|client| {
+                Box::pin(client.filter_batch(
+                    batch_id,
+                    kept_requests.clone(),
+                    terminated_request_ids.clone(),
+                ))
+            })
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()

@@ -1,4 +1,4 @@
-use std::cmp::{max, min};
+use std::cmp::min;
 use thiserror::Error;
 use tokio::sync::{mpsc, oneshot};

@@ -16,8 +16,9 @@ impl BlockAllocation {
         self.slots.len()
     }

-    pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> {
-        let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1);
+    pub(crate) async fn extend(&mut self, cache_length: u32) -> Result<(), AllocationError> {
+        let remaining_tokens =
+            (self.prompt_tokens + self.decode_tokens).saturating_sub(cache_length);
         self.block_allocator
             .clone()
             .extend(self, remaining_tokens)
@@ -131,6 +132,7 @@ async fn block_allocator_task(
         let decode_tokens = min(decode_tokens, block_size);
         let tokens = prompt_tokens + decode_tokens;

+        // FIXME: window size is not working
         // Apply window size
         let (required_blocks, repeats) = {
             let (tokens, repeats) = match window_size {

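The extend change above swaps max(prompt_tokens + decode_tokens - current_length, 1) for a saturating subtraction, so a fully consumed allocation now asks for zero extra tokens instead of at least one, and can no longer underflow. A plain-Python sketch of the two behaviours, with made-up token counts:

def remaining_tokens_old(prompt_tokens: int, decode_tokens: int, current_length: int) -> int:
    # Old rule: never less than one, even when the budget is already exhausted.
    return max(prompt_tokens + decode_tokens - current_length, 1)

def remaining_tokens_new(prompt_tokens: int, decode_tokens: int, cache_length: int) -> int:
    # New rule: clamp at zero, mirroring Rust's saturating_sub.
    return max(prompt_tokens + decode_tokens - cache_length, 0)

print(remaining_tokens_old(16, 32, 48), remaining_tokens_new(16, 32, 48))  # 1 0
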
@@ -5,7 +5,7 @@ use crate::validation::{
     ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::{max, min};
+use std::cmp::max;
 use std::collections::VecDeque;
 use text_generation_client::v3::{
     Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,

@@ -33,8 +33,8 @@ pub(crate) struct Entry {
     pub batch_time: Option<Instant>,
     /// Block Allocation
    pub block_allocation: Option<BlockAllocation>,
-    /// Current length (in tokens) of the request (prompt tokens + generated_tokens)
-    pub current_length: u32,
+    /// Cache length (in tokens) of the request (prompt tokens + generated_tokens)
+    pub cache_length: u32,
 }

 /// Request Queue
@@ -164,9 +164,6 @@ struct State {
     /// Paged Attention block size
     block_size: u32,

-    /// Sliding window
-    window_size: Option<u32>,
-
     /// Speculation amount
     speculate: u32,

@@ -190,7 +187,6 @@ impl State {
             next_id: 0,
             next_batch_id: 0,
             block_size,
-            window_size,
             speculate,
             block_allocator,
         }
@@ -276,18 +272,7 @@ impl State {
                 }
                 Some(block_allocator) => {
                     prefill_tokens += entry.request.input_length;
-                    let max_new_tokens = match self.window_size {
-                        None => entry.request.stopping_parameters.max_new_tokens,
-                        Some(window_size) => min(
-                            window_size.saturating_sub(entry.request.input_length),
-                            entry.request.stopping_parameters.max_new_tokens,
-                        ),
-                    };
-                    decode_tokens += max_new_tokens;
-
-                    if prefill_tokens > prefill_token_budget
-                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-                    {
+                    if prefill_tokens > prefill_token_budget {
                         // Entry is over budget
                         // Add it back to the front
                         tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@@ -296,7 +281,7 @@ impl State {
                     }

                     let decode_tokens =
-                        entry.request.stopping_parameters.max_new_tokens + self.speculate - 1;
+                        entry.request.stopping_parameters.max_new_tokens + self.speculate;
                     match block_allocator
                         .allocate(entry.request.input_length, decode_tokens)
                         .await
@@ -500,7 +485,7 @@ mod tests {
             queue_time: Instant::now(),
             batch_time: None,
             block_allocation: None,
-            current_length: 0,
+            cache_length: 0,
         };
         (entry, receiver_tx)
     }

@@ -10,7 +10,7 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest};
+use text_generation_client::v3::{Batch, CachedBatch, Generation, KeptRequest, ShardedClient};
 use text_generation_client::ClientError;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};

@@ -88,7 +88,7 @@ impl Scheduler for SchedulerV3 {
             queue_time: Instant::now(),
             batch_time: None,
             block_allocation: None,
-            current_length: input_length,
+            cache_length: 0,
         });

         // Notify the background task that we have a new entry in the queue that needs
@@ -350,7 +350,7 @@ async fn filter_batch(
                 .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
                 .unwrap_or((Vec::new(), Vec::new()));

-            UpdatedRequest {
+            KeptRequest {
                 id: *request_id,
                 blocks,
                 slots,
@@ -359,7 +359,10 @@ async fn filter_batch(
             .collect();

         // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, updated_requests).await.unwrap()
+        client
+            .filter_batch(id, updated_requests, Vec::new())
+            .await
+            .unwrap()
     }
 }

@@ -374,7 +377,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
         let entry = entries
             .get_mut(&id)
             .expect("ID not found in entries. This is a bug.");
-        entry.current_length = generation.current_length;
+        entry.cache_length = generation.cache_length;

         // Create and enter a span to link this function back to the entry
         let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
@@ -403,7 +406,7 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
             .block_allocation
             .as_ref()
             .map(|block_allocation| {
-                if entry.current_length > block_allocation.len() as u32 {
+                if entry.cache_length > block_allocation.len() as u32 {
                     // We need to re-allocate
                     Some(*id)
                 } else {
@@ -424,8 +427,8 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
            entry
                 .block_allocation
                 .as_mut()
-                .unwrap()
-                .extend(entry.current_length)
+                .expect("We checked that the block allocation exists above")
+                .extend(entry.cache_length)
                 .await
         };

@@ -563,6 +566,7 @@ impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
         let v3_finish_reason =
             text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
         let finish_reason = match v3_finish_reason {
+            text_generation_client::v3::FinishReason::Terminated => FinishReason::OutOfResources,
             text_generation_client::v3::FinishReason::Length => FinishReason::Length,
             text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
             text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,

@@ -1085,6 +1085,8 @@ pub(crate) enum FinishReason {
     EndOfSequenceToken,
     #[schema(rename = "stop_sequence")]
     StopSequence,
+    #[schema(rename = "out_of_resources")]
+    OutOfResources,
 }

 impl std::fmt::Display for FinishReason {
@@ -1093,6 +1095,7 @@ impl std::fmt::Display for FinishReason {
             FinishReason::Length => write!(f, "length"),
             FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
             FinishReason::StopSequence => write!(f, "stop_sequence"),
+            FinishReason::OutOfResources => write!(f, "out_of_resources"),
         }
     }
 }

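Putting the last few hunks together, the new finish reason travels proto -> router enum -> API string. A tiny illustrative mapping (not code from the commit), with the strings taken from the schema renames and the Display impl above:

# proto value -> string exposed in the HTTP API
FINISH_REASON_TO_API_STRING = {
    "FINISH_REASON_LENGTH": "length",
    "FINISH_REASON_EOS_TOKEN": "eos_token",
    "FINISH_REASON_STOP_SEQUENCE": "stop_sequence",
    "FINISH_REASON_TERMINATED": "out_of_resources",  # new in this change
}

print(FINISH_REASON_TO_API_STRING["FINISH_REASON_TERMINATED"])
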
@@ -159,7 +159,7 @@ class CausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["CausalLMBatch"]:
         request_ids = [r.id for r in updated_requests]

@@ -398,11 +398,37 @@ class FlashCausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
-    ) -> Optional["FlashCausalLMBatch"]:
-        if len(updated_requests) == 0:
+        self,
+        model: "FlashCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.GeneratedText]]:
+        if len(kept_requests) == 0:
             raise ValueError("Batch must have at least one request")

+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            do_sample = self.next_token_chooser.do_sample[idx]
+            seed = self.next_token_chooser.seeds[idx]
+
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids,
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            generated_text = GeneratedText(
+                output_text,
+                stopping_criteria.current_tokens,
+                generate_pb2.FINISH_REASON_TERMINATED,
+                seed if do_sample else None,
+            )
+            terminated_generations.append(generated_text)
+
         device = self.input_ids.device

         # New values after filtering
@@ -429,7 +455,7 @@ class FlashCausalLMBatch(Batch):
         num_blocks = 0
         max_blocks = 0

-        for i, request in enumerate(updated_requests):
+        for i, request in enumerate(kept_requests):
             request_id = request.id

             idx = self.requests_idx_mapping[request_id]
@@ -491,7 +517,7 @@ class FlashCausalLMBatch(Batch):
         # Move to GPU
         block_tables_tensor = block_tables_tensor.to(device)

-        return type(self)(
+        filtered_batch = type(self)(
             batch_id=self.batch_id,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
@@ -520,6 +546,7 @@ class FlashCausalLMBatch(Batch):
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
         )
+        return filtered_batch, terminated_generations

     @classmethod
     @tracer.start_as_current_span("concatenate")

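The offsets passed to model.decode_token above just re-detokenize the tail of the sequence: prefix_offset points at the last prompt token (one token of context for the tokenizer) and read_offset at the first generated token. A standalone sketch of that arithmetic with made-up token IDs:

# 5 prompt tokens followed by 3 generated tokens
all_input_ids = [101, 102, 103, 104, 105, 201, 202, 203]
current_tokens = 3  # stopping_criteria.current_tokens

prefix_offset = len(all_input_ids) - current_tokens - 1  # 4 -> last prompt token
read_offset = len(all_input_ids) - current_tokens        # 5 -> first generated token

print(all_input_ids[prefix_offset:read_offset])  # [105] detokenization context
print(all_input_ids[read_offset:])               # [201, 202, 203] tokens to decode
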
@@ -215,7 +215,7 @@ class IdeficsCausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["IdeficsCausalLMBatch"]:
         request_ids = [r.id for r in updated_requests]

@@ -196,7 +196,7 @@ class MambaBatch(Batch):
         )

     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["MambaBatch"]:
         request_ids = [r.id for r in updated_requests]

@@ -167,7 +167,7 @@ class Seq2SeqLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["Seq2SeqLMBatch"]:
         request_ids = [r.id for r in updated_requests]

@@ -3,7 +3,7 @@ import torch

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Tuple

 from transformers import PreTrainedTokenizerBase

@@ -28,7 +28,12 @@ class Batch(ABC):
         raise NotImplementedError

     @abstractmethod
-    def filter(self, updated_requests: List[generate_pb2.UpdatedRequest]) -> "Batch":
+    def filter(
+        self,
+        model,
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple["Batch", List[generate_pb2.GeneratedText]]:
         raise NotImplementedError

     @classmethod
@@ -84,7 +89,7 @@ class Generation:
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
     top_tokens: Optional[List[Tokens]]
-    current_length: int
+    cache_length: int

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -101,5 +106,5 @@ class Generation:
                 if self.top_tokens is not None
                 else None
             ),
-            current_length=self.current_length,
+            cache_length=self.cache_length,
         )

@@ -123,7 +123,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["VlmCausalLMBatch"]:
         batch = super().filter(updated_requests)
         batch.pixel_values = None

@@ -83,10 +83,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.cache.pop(request.batch_id)
         if batch is None:
             raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.updated_requests)
+        filtered_batch, terminated_generations = batch.filter(
+            self.model, request.kept_requests, request.terminated_request_ids
+        )
         self.cache.set(filtered_batch)

-        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
+        return generate_pb2.FilterBatchResponse(
+            batch=filtered_batch.to_pb(), terminated_generations=terminated_generations
+        )

     async def Warmup(self, request, context):
         if self.quantize in {"exl2", "gptq"}: