add terminated_generations
Commit 298bf31e69 (parent 3c596983ba)
@@ -164,6 +164,7 @@ enum FinishReason {
     FINISH_REASON_LENGTH = 0;
     FINISH_REASON_EOS_TOKEN = 1;
     FINISH_REASON_STOP_SEQUENCE = 2;
+    FINISH_REASON_TERMINATED = 3;
 }

 message GeneratedText {
@@ -198,11 +199,11 @@ message Generation {
     optional GeneratedText generated_text = 4;
     /// Top tokens
     repeated Tokens top_tokens = 5;
-    /// Current length of the request: prompt tokens + number of generated tokens until this point
-    uint32 current_length = 6;
+    /// Current length of the cache: prompt tokens + number of generated tokens until this point
+    uint32 cache_length = 6;
 }

-message UpdatedRequest {
+message KeptRequest {
     /// Request ID
     uint64 id = 1;
     /// Paged attention blocks
@@ -211,16 +212,23 @@ message UpdatedRequest {
     repeated uint32 slots = 3;
 }

+/// kept_requests + terminated_request_ids might not cover all requests from the
+/// cached batch as some requests can be filtered out without requiring to generate text
+/// for example if the client dropped its connection to the router
 message FilterBatchRequest {
     /// Batch ID
     uint64 batch_id = 1;
     /// Requests to keep
-    repeated UpdatedRequest updated_requests = 2;
+    repeated KeptRequest kept_requests = 2;
+    /// Requests to terminate and generate text for
+    repeated uint64 terminated_request_ids = 3;
 }

 message FilterBatchResponse {
     /// Filtered Batch (cached)
     CachedBatch batch = 1;
+    /// Terminated generations
+    repeated GeneratedText terminated_generations = 2;
 }
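
The proto change splits the old updated_requests list in two: KeptRequest entries carry the re-assigned blocks and slots for requests that stay in the cached batch, while terminated_request_ids lists requests the shard should finalize and return a GeneratedText for. Below is a minimal sketch (not the router's actual code) of how a caller might partition a batch under the new schema; the KeptRequest and FilterBatchRequest structs are local stand-ins for the prost-generated types, and build_filter_request / over_budget_ids are hypothetical names.

// Stand-ins for the prost-generated v3 messages, only to illustrate the new
// kept vs. terminated split in FilterBatchRequest.
#[derive(Clone, Debug)]
struct KeptRequest {
    id: u64,
    blocks: Vec<u32>,
    slots: Vec<u32>,
}

#[derive(Debug)]
struct FilterBatchRequest {
    batch_id: u64,
    kept_requests: Vec<KeptRequest>,
    terminated_request_ids: Vec<u64>,
}

// (id, blocks, slots) tuples the router still tracks, plus the ids it decided
// to terminate (e.g. because they ran out of budget).
fn build_filter_request(
    batch_id: u64,
    live_requests: &[(u64, Vec<u32>, Vec<u32>)],
    over_budget_ids: &[u64],
) -> FilterBatchRequest {
    FilterBatchRequest {
        batch_id,
        kept_requests: live_requests
            .iter()
            .filter(|(id, _, _)| !over_budget_ids.contains(id))
            .map(|(id, blocks, slots)| KeptRequest {
                id: *id,
                blocks: blocks.clone(),
                slots: slots.clone(),
            })
            .collect(),
        terminated_request_ids: over_budget_ids.to_vec(),
    }
}

fn main() {
    let req = build_filter_request(7, &[(0, vec![1], vec![0]), (1, vec![2], vec![8])], &[1]);
    println!("{req:?}"); // request 0 is kept, request 1 is terminated
}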
@@ -90,11 +90,13 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        updated_requests: Vec<UpdatedRequest>,
+        kept_requests: Vec<KeptRequest>,
+        terminated_request_ids: Vec<u64>,
     ) -> Result<Option<CachedBatch>> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            updated_requests,
+            kept_requests,
+            terminated_request_ids,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
@@ -7,7 +7,7 @@ mod sharded_client;
 pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens, UpdatedRequest,
+    HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;
@@ -9,8 +9,8 @@ use tonic::transport::Uri;
 use tracing::instrument;
 use v3::client::{DecodeTimings, PrefillTimings};
 use v3::{
-    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest,
+    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, KeptRequest,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };

 #[derive(Debug, Clone)]
@@ -84,12 +84,19 @@ impl ShardedClient {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        updated_requests: Vec<UpdatedRequest>,
+        kept_requests: Vec<KeptRequest>,
+        terminated_request_ids: Vec<u64>,
     ) -> Result<Option<CachedBatch>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone())))
+            .map(|client| {
+                Box::pin(client.filter_batch(
+                    batch_id,
+                    kept_requests.clone(),
+                    terminated_request_ids.clone(),
+                ))
+            })
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
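
Every shard must see the same filter decision, so the sharded client fans the call out to all clients and clones both argument vectors into each pinned future. A minimal, self-contained sketch of that fan-out shape, assuming a tokio runtime and the futures crate, with fake_shard_filter standing in for the real gRPC call:

// Per-shard fan-out in the style of ShardedClient::filter_batch: each shard
// gets its own clone of the arguments, all calls run concurrently, and only
// the last (identical) response is kept.
use futures::future::join_all;

async fn fake_shard_filter(shard: usize, kept_ids: Vec<u64>, terminated_ids: Vec<u64>) -> usize {
    // Stand-in for client.filter_batch(batch_id, kept_requests, terminated_request_ids)
    println!("shard {shard}: keep {kept_ids:?}, terminate {terminated_ids:?}");
    shard
}

#[tokio::main]
async fn main() {
    let kept_ids = vec![0u64, 2];
    let terminated_ids = vec![1u64];
    let futures: Vec<_> = (0..2)
        .map(|shard| Box::pin(fake_shard_filter(shard, kept_ids.clone(), terminated_ids.clone())))
        .collect();
    // all shards return the same message
    let last = join_all(futures).await.pop().unwrap();
    println!("last shard response: {last}");
}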
@@ -1,4 +1,4 @@
-use std::cmp::{max, min};
+use std::cmp::min;
 use thiserror::Error;
 use tokio::sync::{mpsc, oneshot};

@@ -16,8 +16,9 @@ impl BlockAllocation {
         self.slots.len()
     }

-    pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> {
-        let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1);
+    pub(crate) async fn extend(&mut self, cache_length: u32) -> Result<(), AllocationError> {
+        let remaining_tokens =
+            (self.prompt_tokens + self.decode_tokens).saturating_sub(cache_length);
         self.block_allocator
             .clone()
             .extend(self, remaining_tokens)
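
The rewritten extend drops the max(..., 1) clamp and uses saturating_sub instead: the old expression subtracted before clamping, so prompt_tokens + decode_tokens - current_length would underflow a u32 (a panic in debug builds, a wrap-around in release) whenever the cached length already exceeded the planned budget. A standalone illustration with made-up token counts:

// New behaviour of the remaining-token computation: bottoms out at 0 instead
// of underflowing when the cache is already past the planned budget.
fn remaining_tokens(prompt_tokens: u32, decode_tokens: u32, cache_length: u32) -> u32 {
    (prompt_tokens + decode_tokens).saturating_sub(cache_length)
}

fn main() {
    // Normal case: 10 prompt + 20 decode tokens planned, 25 already cached.
    assert_eq!(remaining_tokens(10, 20, 25), 5);
    // Cache already past the budget: the old `prompt + decode - cache_length`
    // would underflow here; saturating_sub returns 0.
    assert_eq!(remaining_tokens(10, 20, 35), 0);
    println!("ok");
}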
@@ -131,6 +132,7 @@ async fn block_allocator_task(
     let decode_tokens = min(decode_tokens, block_size);
     let tokens = prompt_tokens + decode_tokens;

+    // FIXME: window size is not working
     // Apply window size
     let (required_blocks, repeats) = {
         let (tokens, repeats) = match window_size {
@@ -5,7 +5,7 @@ use crate::validation::{
     ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::{max, min};
+use std::cmp::max;
 use std::collections::VecDeque;
 use text_generation_client::v3::{
     Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
@@ -33,8 +33,8 @@ pub(crate) struct Entry {
     pub batch_time: Option<Instant>,
     /// Block Allocation
     pub block_allocation: Option<BlockAllocation>,
-    /// Current length (in tokens) of the request (prompt tokens + generated_tokens)
-    pub current_length: u32,
+    /// Cache length (in tokens) of the request (prompt tokens + generated_tokens)
+    pub cache_length: u32,
 }

 /// Request Queue
@@ -164,9 +164,6 @@ struct State {
     /// Paged Attention block size
     block_size: u32,

-    /// Sliding window
-    window_size: Option<u32>,
-
     /// Speculation amount
     speculate: u32,

@@ -190,7 +187,6 @@ impl State {
             next_id: 0,
             next_batch_id: 0,
             block_size,
-            window_size,
             speculate,
             block_allocator,
         }
@@ -276,18 +272,7 @@ impl State {
                 }
                 Some(block_allocator) => {
                     prefill_tokens += entry.request.input_length;
-                    let max_new_tokens = match self.window_size {
-                        None => entry.request.stopping_parameters.max_new_tokens,
-                        Some(window_size) => min(
-                            window_size.saturating_sub(entry.request.input_length),
-                            entry.request.stopping_parameters.max_new_tokens,
-                        ),
-                    };
-                    decode_tokens += max_new_tokens;
-
-                    if prefill_tokens > prefill_token_budget
-                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-                    {
+                    if prefill_tokens > prefill_token_budget {
                         // Entry is over budget
                         // Add it back to the front
                         tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
@@ -296,7 +281,7 @@
                     }

                     let decode_tokens =
-                        entry.request.stopping_parameters.max_new_tokens + self.speculate - 1;
+                        entry.request.stopping_parameters.max_new_tokens + self.speculate;
                     match block_allocator
                         .allocate(entry.request.input_length, decode_tokens)
                         .await
@@ -500,7 +485,7 @@ mod tests {
             queue_time: Instant::now(),
             batch_time: None,
             block_allocation: None,
-            current_length: 0,
+            cache_length: 0,
         };
         (entry, receiver_tx)
     }
@@ -10,7 +10,7 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest};
+use text_generation_client::v3::{Batch, CachedBatch, Generation, KeptRequest, ShardedClient};
 use text_generation_client::ClientError;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
@@ -88,7 +88,7 @@ impl Scheduler for SchedulerV3 {
             queue_time: Instant::now(),
             batch_time: None,
             block_allocation: None,
-            current_length: input_length,
+            cache_length: 0,
         });

         // Notify the background task that we have a new entry in the queue that needs
@@ -350,7 +350,7 @@ async fn filter_batch(
                 .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone()))
                 .unwrap_or((Vec::new(), Vec::new()));

-            UpdatedRequest {
+            KeptRequest {
                 id: *request_id,
                 blocks,
                 slots,
@@ -359,7 +359,10 @@ async fn filter_batch(
             .collect();

         // We unwrap here as we need to panic since we cannot recover if this method fails
-        client.filter_batch(id, updated_requests).await.unwrap()
+        client
+            .filter_batch(id, updated_requests, Vec::new())
+            .await
+            .unwrap()
     }
 }

@@ -374,7 +377,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
         let entry = entries
             .get_mut(&id)
             .expect("ID not found in entries. This is a bug.");
-        entry.current_length = generation.current_length;
+        entry.cache_length = generation.cache_length;

         // Create and enter a span to link this function back to the entry
         let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
@@ -403,7 +406,7 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
             .block_allocation
             .as_ref()
             .map(|block_allocation| {
-                if entry.current_length > block_allocation.len() as u32 {
+                if entry.cache_length > block_allocation.len() as u32 {
                     // We need to re-allocate
                     Some(*id)
                 } else {
@@ -424,8 +427,8 @@ async fn filter_update_allocations(entries: &mut IntMap<u64, Entry>) -> bool {
         entry
             .block_allocation
             .as_mut()
-            .unwrap()
-            .extend(entry.current_length)
+            .expect("We checked that the block allocation exists above")
+            .extend(entry.cache_length)
             .await
     };

@@ -563,6 +566,7 @@ impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
         let v3_finish_reason =
             text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
         let finish_reason = match v3_finish_reason {
+            text_generation_client::v3::FinishReason::Terminated => FinishReason::OutOfResources,
             text_generation_client::v3::FinishReason::Length => FinishReason::Length,
             text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
             text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
@@ -1085,6 +1085,8 @@ pub(crate) enum FinishReason {
     EndOfSequenceToken,
     #[schema(rename = "stop_sequence")]
     StopSequence,
+    #[schema(rename = "out_of_resources")]
+    OutOfResources,
 }

 impl std::fmt::Display for FinishReason {
@@ -1093,6 +1095,7 @@ impl std::fmt::Display for FinishReason {
             FinishReason::Length => write!(f, "length"),
             FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
             FinishReason::StopSequence => write!(f, "stop_sequence"),
+            FinishReason::OutOfResources => write!(f, "out_of_resources"),
         }
     }
 }
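
Together with the mapping added above, a shard-side FINISH_REASON_TERMINATED now surfaces to API clients as the string "out_of_resources". A self-contained sketch of that path, using local enums in place of the generated v3 type and the router's FinishReason:

// Local stand-ins for the v3 proto enum and the router enum, showing how a
// terminated generation ends up rendered as "out_of_resources".
#[derive(Debug)]
enum V3FinishReason {
    Length,
    EosToken,
    StopSequence,
    Terminated,
}

#[derive(Debug)]
enum FinishReason {
    Length,
    EndOfSequenceToken,
    StopSequence,
    OutOfResources,
}

impl From<V3FinishReason> for FinishReason {
    fn from(value: V3FinishReason) -> Self {
        match value {
            V3FinishReason::Terminated => FinishReason::OutOfResources,
            V3FinishReason::Length => FinishReason::Length,
            V3FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            V3FinishReason::StopSequence => FinishReason::StopSequence,
        }
    }
}

impl std::fmt::Display for FinishReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FinishReason::Length => write!(f, "length"),
            FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
            FinishReason::StopSequence => write!(f, "stop_sequence"),
            FinishReason::OutOfResources => write!(f, "out_of_resources"),
        }
    }
}

fn main() {
    let reason: FinishReason = V3FinishReason::Terminated.into();
    assert_eq!(reason.to_string(), "out_of_resources");
}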
@@ -159,7 +159,7 @@ class CausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["CausalLMBatch"]:
         request_ids = [r.id for r in updated_requests]

@@ -398,11 +398,37 @@ class FlashCausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
-    ) -> Optional["FlashCausalLMBatch"]:
-        if len(updated_requests) == 0:
+        self,
+        model: "FlashCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.GeneratedText]]:
+        if len(kept_requests) == 0:
             raise ValueError("Batch must have at least one request")

+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            do_sample = self.next_token_chooser.do_sample[idx]
+            seed = self.next_token_chooser.seeds[idx]
+
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids,
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            generated_text = GeneratedText(
+                output_text,
+                stopping_criteria.current_tokens,
+                generate_pb2.FINISH_REASON_TERMINATED,
+                seed if do_sample else None,
+            )
+            terminated_generations.append(generated_text)
+
         device = self.input_ids.device

         # New values after filtering
@@ -429,7 +455,7 @@ class FlashCausalLMBatch(Batch):
         num_blocks = 0
         max_blocks = 0

-        for i, request in enumerate(updated_requests):
+        for i, request in enumerate(kept_requests):
             request_id = request.id

             idx = self.requests_idx_mapping[request_id]
@@ -491,7 +517,7 @@ class FlashCausalLMBatch(Batch):
         # Move to GPU
         block_tables_tensor = block_tables_tensor.to(device)

-        return type(self)(
+        filtered_batch = type(self)(
             batch_id=self.batch_id,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
@@ -520,6 +546,7 @@ class FlashCausalLMBatch(Batch):
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
         )
+        return filtered_batch, terminated_generations

     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -215,7 +215,7 @@ class IdeficsCausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["IdeficsCausalLMBatch"]:
         request_ids = [r.id for r in updated_requests]

@@ -196,7 +196,7 @@ class MambaBatch(Batch):
         )

     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["MambaBatch"]:
         request_ids = [r.id for r in updated_requests]

@@ -167,7 +167,7 @@ class Seq2SeqLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["Seq2SeqLMBatch"]:
         request_ids = [r.id for r in updated_requests]

@@ -3,7 +3,7 @@ import torch

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Tuple

 from transformers import PreTrainedTokenizerBase

|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def filter(self, updated_requests: List[generate_pb2.UpdatedRequest]) -> "Batch":
|
def filter(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
kept_requests: List[generate_pb2.KeptRequest],
|
||||||
|
terminated_request_ids: List[int],
|
||||||
|
) -> Tuple["Batch", List[generate_pb2.GeneratedText]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@@ -84,7 +89,7 @@ class Generation:
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
     top_tokens: Optional[List[Tokens]]
-    current_length: int
+    cache_length: int

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
|
||||||
if self.top_tokens is not None
|
if self.top_tokens is not None
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
current_length=self.current_length,
|
cache_length=self.cache_length,
|
||||||
)
|
)
|
||||||
|
|
|
@@ -123,7 +123,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.UpdatedRequest]
+        self, updated_requests: List[generate_pb2.KeptRequest]
     ) -> Optional["VlmCausalLMBatch"]:
         batch = super().filter(updated_requests)
         batch.pixel_values = None
@@ -83,10 +83,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.cache.pop(request.batch_id)
         if batch is None:
             raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-        filtered_batch = batch.filter(request.updated_requests)
+        filtered_batch, terminated_generations = batch.filter(
+            self.model, request.kept_requests, request.terminated_request_ids
+        )
         self.cache.set(filtered_batch)

-        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
+        return generate_pb2.FilterBatchResponse(
+            batch=filtered_batch.to_pb(), terminated_generations=terminated_generations
+        )

     async def Warmup(self, request, context):
         if self.quantize in {"exl2", "gptq"}: